vp_suite.measure.fvd.fvd
- class FrechetVideoDistance(device, in_channels=3)
Bases:
vp_suite.base.base_measure.VPMeasure
This measure calculates the Frechet Video Distance (FVD), as introduced in Unterthiner et al. (https://arxiv.org/abs/1812.01717). Analogously to the Frechet Inception Distance for images, FVD assesses the perceptual quality of generated videos with respect to a ground truth sequence: both videos are passed through an InceptionI3D network, and the Frechet distance is computed between the resulting feature distributions.
Code is inspired by: https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
Note
The Frechet Video Distance calculation code is differentiable, meaning that this version of FVD can also be used as a loss!
- REFERENCE: str = 'https://arxiv.org/abs/1812.01717'
The reference publication where this measure is originally introduced (represented as string)
- __init__(device, in_channels=3)
Instantiates the FVD measure by setting the device and initializing the InceptionI3D module used to extract the features to be compared.
- calculate_n_chunks(num_frames)
If the given input length is too large, this function returns the number of chunks to split the input into. Each chunk is then used for a separate FVD calculation, and the results are combined afterwards.
- Parameters
num_frames (int) – The number of context frames (aka the input length).
- Returns
The number of chunks the input sequence needs to be split into, as well as a boolean value indicating whether the last chunk has to be neglected.
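As a rough illustration of this return contract, the chunking decision could look like the sketch below. The frame limits MIN_FRAMES and MAX_FRAMES are assumed values for illustration, not constants taken from the library:

```python
import math

# Assumed limits for illustration only: suppose the I3D feature extractor
# accepts clips of MIN_FRAMES up to MAX_FRAMES frames.
MIN_FRAMES, MAX_FRAMES = 9, 16

def calculate_n_chunks(num_frames):
    """Return (n_chunks, drop_last): the number of chunks to split the
    input into, and whether the last chunk is too short and must be
    neglected."""
    n_chunks = math.ceil(num_frames / MAX_FRAMES)
    remainder = num_frames % MAX_FRAMES
    drop_last = 0 < remainder < MIN_FRAMES  # final chunk too short to embed
    return n_chunks, drop_last
```

For example, a 20-frame input would yield two chunks with the 4-frame leftover discarded, while 16 frames fit in a single chunk.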
- forward(pred, target)
The module’s forward pass takes the predicted frame sequence and the ground truth, compares them based on the derived measure’s criterion and logic, and outputs a numerical assessment of the prediction quality.
The base measure’s forward method can be used by derived classes: it applies the criterion to the input tensors, sums over all entries of an image, and finally averages over frames and then over batches.
- Parameters
pred (torch.Tensor) – The predicted frame sequence as a 5D tensor (batch, frames, c, h, w).
target (torch.Tensor) – The ground truth frame sequence as a 5D tensor (batch, frames, c, h, w).
Returns: The calculated numerical quality assessment.
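The base reduction described above (per-entry criterion, sum over each image, then averaging over frames and batches) can be sketched with NumPy. The squared-error criterion here is only a placeholder assumption, not FVD's actual comparison:

```python
import numpy as np

def base_forward(pred, target, criterion=lambda a, b: (a - b) ** 2):
    """Sketch of the base measure's reduction: apply the criterion per
    entry, sum over each image (c, h, w), average over frames, then over
    batches."""
    per_entry = criterion(pred, target)        # (batch, frames, c, h, w)
    per_frame = per_entry.sum(axis=(2, 3, 4))  # sum over image entries
    per_sequence = per_frame.mean(axis=1)      # average over frames
    return per_sequence.mean(axis=0)           # average over batches
```

FVD itself overrides this with a distributional comparison of I3D features; the sketch only mirrors the generic reduction order stated in the base class docstring.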
- get_distance(pred, target)
Calculates the Frechet Video Distance between the provided chunked prediction and the ground truth tensors, by first extracting perceptual features from the InceptionI3D Network before calculating the 2-Wasserstein-Distance on these features. The video frames have been previously resized to meet the height and width constraints of the I3D Network.
- Parameters
pred (torch.Tensor) – The chunked predicted frame sequence as a 5D tensor (batch, c, chunk_l, h, w).
target (torch.Tensor) – The chunked ground truth frame sequence as a 5D tensor (batch, c, chunk_l, h, w).
Returns: The calculated 2-Wasserstein metric as a scalar tensor.
- calculate_2_wasserstein_dist(pred, target)
Calculates the 2-Wasserstein distance between the input feature distributions. The general formula is given by:
d(P_target, P_pred) = min_{X, Y} E[|X - Y|^2]
For multivariate Gaussian distributed inputs
x_target ~ MN(mu_target, cov_target)
and x_pred ~ MN(mu_pred, cov_pred),
this reduces to:
d = |mu_target - mu_pred|^2 + Tr(cov_target + cov_pred - 2(cov_target * cov_pred)^(1/2))
Fast method implemented according to the following paper: https://arxiv.org/pdf/2009.14075.pdf
Input shape: [b = batch_size, n = num_features] Output shape: scalar
- Parameters
pred (torch.Tensor) – The logits of the chunked prediction extracted by the InceptionI3D network (batch, n_feat).
target (torch.Tensor) – The logits of the chunked ground truth extracted by the I3D network (batch, n_feat).
Returns: The calculated 2-Wasserstein metric as a scalar tensor.
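The closed-form Gaussian expression above can be sketched in NumPy as a standalone re-derivation (not the library's differentiable PyTorch implementation). Computing the trace of the matrix square root via a symmetric eigendecomposition follows the trick used in the referenced TensorFlow code:

```python
import numpy as np

def _trace_sqrt_product(cov_a, cov_b):
    # Tr((cov_a @ cov_b)^(1/2)), computed via the symmetric PSD matrix
    # B^(1/2) A B^(1/2), which shares its eigenvalues with A @ B.
    w, v = np.linalg.eigh(cov_b)
    sqrt_b = (v * np.sqrt(np.clip(w, 0, None))) @ v.T
    eigs = np.linalg.eigvalsh(sqrt_b @ cov_a @ sqrt_b)
    return np.sqrt(np.clip(eigs, 0, None)).sum()

def wasserstein2_squared(pred, target):
    """Squared 2-Wasserstein distance between Gaussians fitted to the
    feature rows of pred and target, each of shape (batch, n_feat)."""
    mu_p, mu_t = pred.mean(axis=0), target.mean(axis=0)
    cov_p = np.cov(pred, rowvar=False)
    cov_t = np.cov(target, rowvar=False)
    mean_term = np.sum((mu_t - mu_p) ** 2)
    trace_term = np.trace(cov_t + cov_p) - 2.0 * _trace_sqrt_product(cov_t, cov_p)
    return mean_term + trace_term
```

With identical inputs the distance is zero; shifting every feature by a constant c leaves the covariances unchanged, so only the mean term n_feat * c^2 remains.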