vp_suite.measure.fvd.fvd

class FrechetVideoDistance(device, in_channels=3)

Bases: vp_suite.base.base_measure.VPMeasure

This measure calculates the Fréchet Video Distance (FVD), as introduced in Unterthiner et al. (https://arxiv.org/abs/1812.01717). The Fréchet distance was originally defined as a similarity measure between two curves; in its probabilistic form, it measures the distance between two distributions. The FVD uses this form to assess the perceptual quality of generated videos with respect to ground truth sequences. Both sets of videos are passed through an InceptionI3D network, and the distributions of the resulting video features, modeled as multivariate Gaussians, are compared.

Code is inspired by: https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py

Note

The Fréchet Video Distance calculation code is differentiable, so this version of FVD can also be used as a loss function!
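The following sketch shows why a differentiable implementation lets FVD serve as a loss. It computes a Fréchet distance from two feature batches using only plain PyTorch operations and then backpropagates through it. Several parts are assumptions for illustration and do not reflect vp_suite's implementation: the function name `frechet_distance`, the `torch.nn.Linear` layer standing in for the InceptionI3D network, and the singular-value formulation of the trace term.

```python
import torch


def frechet_distance(feat_pred: torch.Tensor, feat_target: torch.Tensor) -> torch.Tensor:
    """Squared Fréchet distance between Gaussians fitted to (b, n) feature batches.

    Hypothetical helper built only from differentiable torch operations.
    """
    b = feat_pred.shape[0]
    mu_pred, mu_target = feat_pred.mean(dim=0), feat_target.mean(dim=0)
    xc_pred, xc_target = feat_pred - mu_pred, feat_target - mu_target
    mean_term = (mu_target - mu_pred).pow(2).sum()
    # Tr(cov_target + cov_pred) from the centered features (unbiased covariance).
    trace_cov = (xc_target.pow(2).sum() + xc_pred.pow(2).sum()) / (b - 1)
    # Tr((cov_target cov_pred)^(1/2)) equals the sum of singular values of this (b, b) matrix.
    trace_sqrt = torch.linalg.svdvals(xc_target @ xc_pred.T / (b - 1)).sum()
    return mean_term + trace_cov - 2 * trace_sqrt


torch.manual_seed(0)
extractor = torch.nn.Linear(32, 8)  # hypothetical stand-in for the InceptionI3D feature extractor
pred = torch.randn(16, 32, requires_grad=True)
target = torch.randn(16, 32)

loss = frechet_distance(extractor(pred), extractor(target))
loss.backward()  # gradients flow through the distance and the extractor back into pred
print(loss.item(), pred.grad.shape)
```

In a training loop, `pred` would be the output of a video prediction model, so an optimizer step on this loss would update that model's parameters.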

NAME: str = 'Fréchet Video Distance (FVD)'

The clear-text name of the measure.

REFERENCE: str = 'https://arxiv.org/abs/1812.01717'

The reference publication in which this measure was originally introduced (as a string).

__init__(device, in_channels=3)

Instantiates the FVD by setting the device and initializing the InceptionI3D module, which extracts the features that are compared.

Parameters
  • device (str) – A string specifying whether to use the GPU for calculations (cuda) or the CPU (cpu).

  • in_channels (int) – Number of input channels (Supported: 2 or 3)

calculate_n_chunks(num_frames)

Determines into how many chunks an input sequence has to be split if its length is too large for a single FVD calculation. Each chunk is then used for a separate FVD calculation, and the results are combined afterwards.

Parameters

num_frames (int) – The number of context frames (aka the input length).

Returns

The number of chunks the input sequence needs to be split into, as well as a boolean value indicating whether the last chunk has to be discarded.
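The exact chunking rule is not specified here. The sketch below is a minimal version under assumed values. It splits the input into chunks of at most `max_len` frames and flags a final chunk shorter than `min_len` frames for dropping. Both limits and their defaults are hypothetical and are not taken from vp_suite, as is the name `calculate_n_chunks_sketch`.

```python
import math
from typing import Tuple


def calculate_n_chunks_sketch(num_frames: int, max_len: int = 16, min_len: int = 9) -> Tuple[int, bool]:
    """Hypothetical chunking rule (illustrative limits, not vp_suite's values).

    Returns the number of chunks of at most `max_len` frames and whether the
    last chunk is too short (fewer than `min_len` frames) and should be dropped.
    """
    if num_frames <= max_len:
        # Short sequences fit into a single chunk that is always kept.
        return 1, False
    n_chunks = math.ceil(num_frames / max_len)
    last_len = num_frames - (n_chunks - 1) * max_len
    return n_chunks, last_len < min_len


# 40 frames -> chunks of 16, 16 and 8 frames; the 8-frame remainder is too short.
print(calculate_n_chunks_sketch(40))  # (3, True)
```

With these assumed limits, 42 frames would give a 10-frame remainder, which is kept.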

forward(pred, target)

The module’s forward pass takes the predicted frame sequence and the ground truth, compares them based on the derived measure’s criterion and logic, and outputs a numerical assessment of the prediction quality.

The base measure’s forward method can be used by derived classes. It simply applies the criterion to the input tensors, sums over all entries of each image, and finally averages over frames and then over batches.

Parameters
  • pred (torch.Tensor) – The predicted frame sequence as a 5D tensor (batch, frames, c, h, w).

  • target (torch.Tensor) – The ground truth frame sequence as a 5D tensor (batch, frames, c, h, w)

Returns: The calculated numerical quality assessment.
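The following sketch shows how the base reduction described above maps onto tensor operations. It is not vp_suite's code. The function name `base_measure_forward` and the default MSE criterion with `reduction="none"` are assumptions. Because FVD compares feature distributions rather than individual pixels, it presumably does not rely on this generic reduction itself.

```python
from typing import Optional

import torch
import torch.nn as nn


def base_measure_forward(pred: torch.Tensor, target: torch.Tensor,
                         criterion: Optional[nn.Module] = None) -> torch.Tensor:
    """Sketch of the described reduction for 5D tensors of shape (batch, frames, c, h, w)."""
    if criterion is None:
        # Assumed elementwise criterion; reduction="none" keeps one value per entry.
        criterion = nn.MSELoss(reduction="none")
    per_entry = criterion(pred, target)          # (b, t, c, h, w)
    per_frame = per_entry.sum(dim=(2, 3, 4))     # sum over all entries of each image -> (b, t)
    per_sequence = per_frame.mean(dim=1)         # average over frames -> (b,)
    return per_sequence.mean()                   # average over the batch -> scalar


pred = torch.ones(2, 3, 1, 2, 2)
target = torch.zeros(2, 3, 1, 2, 2)
print(base_measure_forward(pred, target))  # tensor(4.)
```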

get_distance(pred, target)

Calculates the Fréchet Video Distance between the provided chunked prediction and ground truth tensors. It first extracts perceptual features with the InceptionI3D network and then calculates the 2-Wasserstein distance between these features. The video frames are expected to have already been resized to meet the height and width requirements of the I3D network.

Parameters
  • pred (torch.Tensor) – The chunked predicted frame sequence as a 5D tensor (batch, c, chunk_l, h, w).

  • target (torch.Tensor) – The chunked ground truth frame sequence as a 5D tensor (batch, c, chunk_l, h, w).

Returns: The calculated 2-Wasserstein metric as a scalar tensor.
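One plausible way to convert the (batch, frames, c, h, w) input of forward() into the (batch, c, chunk_l, h, w) chunks expected here has three steps: resize every frame, move the channel axis in front of the time axis, and split along time. Several details are assumptions rather than vp_suite's exact preprocessing: the helper name `prepare_chunks`, the bilinear resizing, and the 224×224 target size, which is the I3D input resolution used in the FVD paper.

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F


def prepare_chunks(x: torch.Tensor, chunk_len: int,
                   size: Tuple[int, int] = (224, 224)) -> List[torch.Tensor]:
    """Hypothetical preprocessing: (b, t, c, h, w) -> list of (b, c, chunk_l, H, W) chunks."""
    b, t, c, h, w = x.shape
    # Resize every frame independently by folding time into the batch axis.
    frames = F.interpolate(x.reshape(b * t, c, h, w), size=size, mode="bilinear", align_corners=False)
    # Restore the time axis, then move channels in front of time: (b, c, t, H, W).
    video = frames.reshape(b, t, c, *size).permute(0, 2, 1, 3, 4)
    # Split along time; the last chunk may be shorter than chunk_len.
    return list(torch.split(video, chunk_len, dim=2))


chunks = prepare_chunks(torch.rand(2, 20, 3, 64, 64), chunk_len=8)
print([tuple(ch.shape) for ch in chunks])
# [(2, 3, 8, 224, 224), (2, 3, 8, 224, 224), (2, 3, 4, 224, 224)]
```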

training: bool

calculate_2_wasserstein_dist(pred, target)

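Calculates the two components of the 2-Wasserstein metric, a mean term and a covariance term. The general formula for the squared 2-Wasserstein distance is:

$$ d(P_{target}, P_{pred}) = \min_{X, Y} \mathbb{E}\left[\lVert X - Y \rVert^2\right] $$

where the minimum is taken over all joint distributions of $X \sim P_{target}$ and $Y \sim P_{pred}$.

For multivariate Gaussian inputs $x_{target} \sim \mathcal{N}(\mu_{target}, \Sigma_{target})$ and $x_{pred} \sim \mathcal{N}(\mu_{pred}, \Sigma_{pred})$, this reduces to:

$$ d = \lVert \mu_{target} - \mu_{pred} \rVert^2 + \operatorname{Tr}\left(\Sigma_{target} + \Sigma_{pred} - 2\left(\Sigma_{target}\Sigma_{pred}\right)^{1/2}\right) $$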
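A one-dimensional example, using illustrative values that do not come from the source, helps check the Gaussian closed form. Take $x_{target} \sim \mathcal{N}(0, 1)$ and $x_{pred} \sim \mathcal{N}(2, 4)$, i.e. $\mu_{target} = 0$, $\Sigma_{target} = 1$, $\mu_{pred} = 2$ and $\Sigma_{pred} = 4$:

```latex
d = \lVert 0 - 2 \rVert^2 + \operatorname{Tr}\left(1 + 4 - 2\,(1 \cdot 4)^{1/2}\right) = 4 + (5 - 4) = 5
```

This agrees with the familiar 1D result $(\mu_{target} - \mu_{pred})^2 + (\sigma_{target} - \sigma_{pred})^2 = 4 + (1 - 2)^2 = 5$, where $\sigma$ denotes the standard deviation.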

The fast computation is implemented according to the following paper: https://arxiv.org/pdf/2009.14075.pdf

Input shape: [b, n], with b = batch size and n = number of features.
Output shape: scalar.

Parameters
  • pred (torch.Tensor) – The logits of the chunked prediction extracted by the InceptionI3D network (batch, n_feat).

  • target (torch.Tensor) – The logits of the chunked ground truth extracted by the I3D network (batch, n_feat).

Returns: The calculated 2-Wasserstein metric as a scalar tensor.
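The following sketch compares a fast computation of the Gaussian closed form with a naive one. The fast version relies on an identity for unbiased covariances built from b centered samples: Tr((cov_target cov_pred)^(1/2)) equals the sum of singular values of the small (b, b) matrix xc_target xc_pred^T / (b - 1). When b is much smaller than n, this avoids an n×n matrix square root. The approach follows the spirit of the low-rank trick in the linked paper, but it is not claimed to be that paper's exact algorithm or vp_suite's code. Both function names are hypothetical.

```python
import torch


def calculate_2_wasserstein_dist_sketch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared 2-Wasserstein distance between Gaussians fitted to (b, n) feature batches.

    Only a (b, b) matrix is decomposed, instead of an (n, n) matrix square root.
    """
    b = pred.shape[0]
    mu_pred, mu_target = pred.mean(dim=0), target.mean(dim=0)
    xc_pred, xc_target = pred - mu_pred, target - mu_target
    mean_term = (mu_target - mu_pred).pow(2).sum()
    # Tr(cov_target + cov_pred) without forming the (n, n) covariances.
    trace_cov = (xc_target.pow(2).sum() + xc_pred.pow(2).sum()) / (b - 1)
    # Tr((cov_target cov_pred)^(1/2)) = sum of singular values of the (b, b) cross-product.
    trace_sqrt = torch.linalg.svdvals(xc_target @ xc_pred.T / (b - 1)).sum()
    return mean_term + trace_cov - 2 * trace_sqrt


def naive_2_wasserstein_dist(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Reference: explicit (n, n) covariances and a symmetric matrix square root."""
    mu_pred, mu_target = pred.mean(dim=0), target.mean(dim=0)
    cov_pred, cov_target = torch.cov(pred.T), torch.cov(target.T)  # unbiased, (n, n)
    # cov_target^(1/2) via eigendecomposition of a symmetric PSD matrix.
    evals, evecs = torch.linalg.eigh(cov_target)
    sqrt_t = evecs @ torch.diag(evals.clamp(min=0).sqrt()) @ evecs.T
    # Tr((cov_t cov_p)^(1/2)) = Tr((cov_t^(1/2) cov_p cov_t^(1/2))^(1/2)).
    inner = torch.linalg.eigvalsh(sqrt_t @ cov_pred @ sqrt_t)
    trace_sqrt = inner.clamp(min=0).sqrt().sum()
    return (mu_target - mu_pred).pow(2).sum() + torch.trace(cov_target + cov_pred) - 2 * trace_sqrt


torch.manual_seed(0)
pred = torch.randn(64, 10, dtype=torch.float64)
target = 1.5 * torch.randn(64, 10, dtype=torch.float64) + 0.3
print(calculate_2_wasserstein_dist_sketch(pred, target).item())
print(naive_2_wasserstein_dist(pred, target).item())  # agrees up to floating-point error
```

The fast version is preferable when the feature dimension n is much larger than the batch size b, which is the usual situation for I3D features.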