vp_suite.base.base_measure

class VPMeasure(device)

Bases: torch.nn.modules.module.Module

The base class for all measures (nn.Module subclasses that take a ground truth sequence and a predicted sequence as input and return a numerical assessment of the prediction quality). Measures can be losses and/or metrics, depending on their registration status in the base package’s __init__.py file. All implemented losses and metrics should subclass this class.

device

A string specifying whether to use the GPU for calculations (cuda) or the CPU (cpu).

Type

str

Note

All measures that should be usable as losses must return values where lower means better. If higher means better for a specific measure, the actual value should be inverted in the forward() method and inverted again in the to_display() method (which prepares the value for display to humans).

BIGGER_IS_BETTER = False

Specifies whether bigger values are better.

NAME: str = NotImplemented

The clear-text name of the measure.

OPT_VALUE = 0.0

Specifies the best value attainable (e.g. when input tensors are equal).

REFERENCE: str = None

The reference publication in which this measure was originally introduced (represented as a string).

__init__(device)

Instantiates the measure class by setting the device. Additionally, for the derived measure classes, instantiates the criterion that is used to calculate the measure value.

Parameters

device (str) – A string specifying whether to use the GPU for calculations (cuda) or the CPU (cpu).
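As an illustration of the constructor contract described above, the following is a minimal sketch of a derived measure that stores the device and instantiates its criterion. It does not import vp_suite: the class name L1MeasureSketch, the choice of nn.L1Loss and the reduction="none" setting are assumptions made for this example, and the real base class may set things up differently.

```python
import torch
from torch import nn


class L1MeasureSketch(nn.Module):
    """Hypothetical stand-in for a derived measure (not part of vp_suite)."""

    NAME = "L1 sketch (hypothetical)"

    def __init__(self, device: str):
        super().__init__()
        # Remember where calculations should run ("cuda" or "cpu").
        self.device = device
        # Assumption: the criterion returns per-element values so that a
        # forward pass can reduce them itself.
        self.criterion = nn.L1Loss(reduction="none").to(device)


measure = L1MeasureSketch("cpu")
```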

forward(pred, target)

The module’s forward pass takes the predicted frame sequence and the ground truth, compares them according to the derived measure’s criterion and logic, and outputs a numerical assessment of the prediction quality.

The base measure’s forward method can be used by derived classes. It applies the criterion to the input tensors, sums over all entries of each image, and then averages first over frames and then over batches.

Parameters
  • pred (torch.Tensor) – The predicted frame sequence as a 5D tensor (batch, frames, c, h, w).

  • target (torch.Tensor) – The ground truth frame sequence as a 5D tensor (batch, frames, c, h, w).

Returns: The calculated numerical quality assessment.
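The reduction order described for the base forward pass (sum over each image, then mean over frames, then mean over batches) can be sketched as follows. This is a standalone illustration under stated assumptions, not the vp_suite source: the helper name reduce_like_base_forward is hypothetical, and the criterion is assumed to return per-element values (here nn.MSELoss with reduction="none").

```python
import torch


def reduce_like_base_forward(pred, target, criterion):
    """Hypothetical helper mirroring the documented reduction order."""
    # Per-element criterion values, shape (batch, frames, c, h, w).
    per_element = criterion(pred, target)
    # Sum over all entries of each image (c, h, w) -> (batch, frames).
    per_frame = per_element.sum(dim=(-3, -2, -1))
    # Average over frames, then over batches -> scalar tensor.
    return per_frame.mean(dim=1).mean(dim=0)


criterion = torch.nn.MSELoss(reduction="none")
pred = torch.zeros(2, 3, 1, 4, 4)
target = torch.ones(2, 3, 1, 4, 4)
value = reduce_like_base_forward(pred, target, criterion)
```

With these inputs every squared error is 1, so each 1x4x4 image sums to 16 and the averages over frames and batches leave that value unchanged.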

reshape_clamp(pred, target)

Reshapes and clamps the input tensors, returning 4D tensors in which the batch and time dimensions are combined.

Parameters
  • pred (torch.Tensor) – The predicted frame sequence as a 5D tensor (batch, frames, c, h, w).

  • target (torch.Tensor) – The ground truth frame sequence as a 5D tensor (batch, frames, c, h, w).

Returns: The reshaped and clamped pred and target tensors.
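A minimal sketch of what such a reshape-and-clamp step might look like is given below. The function name reshape_clamp_sketch and the clamp bounds lo=0.0 and hi=1.0 are assumptions for illustration only: the documentation above does not state the value range the library clamps to.

```python
import torch


def reshape_clamp_sketch(pred, target, lo=0.0, hi=1.0):
    """Hypothetical illustration; the clamp bounds are assumed."""
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    b, t, c, h, w = pred.shape
    # Merge batch and time into one leading dimension: (b * t, c, h, w).
    pred = pred.reshape(b * t, c, h, w).clamp(lo, hi)
    target = target.reshape(b * t, c, h, w).clamp(lo, hi)
    return pred, target


pred = torch.full((2, 3, 1, 4, 4), 1.5)
target = torch.full((2, 3, 1, 4, 4), -0.5)
pred_4d, target_4d = reshape_clamp_sketch(pred, target)
```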

classmethod to_display(x)

Converts a measurement value from the lower-is-better representation returned by the forward() method to the actual representation of the measure (e.g. SSIM having its best value at 1.0). If the measure was not inverted in forward(), this method just returns the input value.

Parameters

x (float) – The value to be converted.

Returns: The converted value.
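The round trip between forward() and to_display() for a higher-is-better measure can be sketched in plain Python. The class InvertedMeasureSketch is hypothetical, and the inversion used here (subtracting from an assumed optimum of 1.0) is only one possible choice: the documentation above does not specify how individual measures invert their values.

```python
class InvertedMeasureSketch:
    """Hypothetical stand-in, not the vp_suite implementation."""

    BIGGER_IS_BETTER = True
    OPT_VALUE = 1.0  # assumed optimum, as for an SSIM-like score

    def forward(self, raw_score: float) -> float:
        # Return a lower-is-better value so it can serve as a loss.
        return self.OPT_VALUE - raw_score

    @classmethod
    def to_display(cls, x: float) -> float:
        # Undo the inversion for display; pass through if none was applied.
        return cls.OPT_VALUE - x if cls.BIGGER_IS_BETTER else x


measure = InvertedMeasureSketch()
loss_value = measure.forward(0.75)
shown = InvertedMeasureSketch.to_display(loss_value)
```

Here a raw score of 0.75 becomes a loss value of 0.25, and to_display() recovers 0.75 for reporting.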