vp_suite.base.base_measure
- class VPMeasure(device)
Bases: torch.nn.modules.module.Module
The base class for all measures (nn Modules taking as input a ground truth sequence and a predicted sequence and providing a numerical assessment of the prediction quality). Measures can be losses and/or metrics, depending on their registration status in the base package’s __init__.py file. All implemented losses and metrics should subclass this class.
- device
A string specifying whether to use the GPU for calculations (cuda) or the CPU (cpu).
- Type
str
Note
All measures that should be usable as losses should return values where lower means better. If higher means better for a specific measure, the value should be inverted in the forward() method and inverted again in the to_display() method (which prepares the value for display to humans).
- BIGGER_IS_BETTER = False
Specifies whether bigger values are better.
- OPT_VALUE = 0.0
Specifies the best value attainable (e.g. when input tensors are equal).
- REFERENCE: str = None
The reference publication in which this measure was originally introduced, given as a string.
- __init__(device)
Instantiates the measure class by setting the device. Additionally, for the derived measure classes, instantiates the criterion that is used to calculate the measure value.
- Parameters
device (str) – A string specifying whether to use the GPU for calculations (cuda) or the CPU (cpu).
- forward(pred, target)
The module’s forward pass takes the predicted frame sequence and the ground truth, compares them based on the deriving measure’s criterion and logic and outputs a numerical assessment of the prediction quality.
The base measure’s forward method can be used by deriving classes and simply applies the criterion to the input tensors, sums up over all entries of an image and finally averages over frames and then batches.
- Parameters
pred (torch.Tensor) – The predicted frame sequence as a 5D tensor (batch, frames, c, h, w).
target (torch.Tensor) – The ground truth frame sequence as a 5D tensor (batch, frames, c, h, w).
Returns: The calculated numerical quality assessment.
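The reduction order described for the base forward pass can be illustrated with a minimal sketch. This is not the library's code: the function name `base_forward_sketch`, the choice of `nn.MSELoss` as criterion, and the shape check are assumptions made for illustration. The sketch also assumes the criterion is element-wise (`reduction="none"`), so that the per-image sum can be taken explicitly.

```python
import torch
from torch import nn


def base_forward_sketch(pred, target, criterion):
    """Hypothetical stand-in for the base forward pass described above."""
    if pred.shape != target.shape or pred.ndim != 5:
        raise ValueError("expected two 5D tensors of shape (batch, frames, c, h, w)")
    # Element-wise criterion values, same shape as the inputs.
    per_entry = criterion(pred, target)
    # Sum over all entries of each image (c, h, w) -> shape (batch, frames).
    per_frame = per_entry.sum(dim=(-3, -2, -1))
    # Average over frames first, then over batches -> scalar tensor.
    return per_frame.mean(dim=1).mean(dim=0)


# Assumed criterion; any element-wise loss with reduction="none" would fit.
criterion = nn.MSELoss(reduction="none")
pred = torch.zeros(2, 3, 1, 4, 4)
target = torch.ones(2, 3, 1, 4, 4)
value = base_forward_sketch(pred, target, criterion)
same = base_forward_sketch(target, target, criterion)
```

With these inputs every squared error is 1, so each 1x4x4 image sums to 16 and the averaged result is 16. Identical inputs give 0, which agrees with OPT_VALUE = 0.0 for a lower-is-better measure.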
- reshape_clamp(pred, target)
Reshapes and clamps the input tensors, returning a 4D tensor where batch and time dimension are combined.
- Parameters
pred (torch.Tensor) – The predicted frame sequence as a 5D tensor (batch, frames, c, h, w).
target (torch.Tensor) – The ground truth frame sequence as a 5D tensor (batch, frames, c, h, w).
Returns: The reshaped and clamped pred and target tensors.
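The following sketch shows one way the reshape and clamp step could look. The function name `reshape_clamp_sketch` and the clamp range [0.0, 1.0] are assumptions; the documentation above does not state which value range the library clamps to.

```python
import torch


def reshape_clamp_sketch(pred, target, lo=0.0, hi=1.0):
    """Hypothetical illustration: merge batch and time, then clamp.

    The bounds lo/hi are assumed, not taken from the library.
    """
    if pred.shape != target.shape or pred.ndim != 5:
        raise ValueError("expected two 5D tensors of shape (batch, frames, c, h, w)")
    b, t, c, h, w = pred.shape
    # Combine the batch and frame dimensions into one leading dimension.
    pred4 = pred.reshape(b * t, c, h, w).clamp(min=lo, max=hi)
    target4 = target.reshape(b * t, c, h, w).clamp(min=lo, max=hi)
    return pred4, target4


pred = torch.full((2, 3, 1, 4, 4), 1.5)     # values above the assumed range
target = torch.full((2, 3, 1, 4, 4), -0.5)  # values below the assumed range
pred4, target4 = reshape_clamp_sketch(pred, target)
```

The resulting 4D layout (batch * frames, c, h, w) is the shape that image-based criteria usually expect, which is presumably why the two leading dimensions are merged.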
- classmethod to_display(x)
Converts a measurement value from the lower-is-better representation returned by the forward() method to the actual representation of the measure (e.g. SSIM having its best value at 1.0). If the measure was not inverted in forward(), this method just returns the input value.
- Parameters
x (float) – The value to be converted.
Returns: The converted value.
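The invert-and-restore pattern can be sketched without any library dependency. Both classes below are hypothetical stand-ins, not subclasses of VPMeasure. The inversion formula (1 - x) and the use of BIGGER_IS_BETTER to decide whether to undo it are assumptions chosen for illustration.

```python
class HigherIsBetterSketch:
    """Hypothetical measure whose natural best value is 1.0 (SSIM-like)."""

    BIGGER_IS_BETTER = True
    OPT_VALUE = 1.0

    def forward(self, similarity: float) -> float:
        # Invert so that lower means better (assumed inversion: 1 - x).
        return 1.0 - similarity

    @classmethod
    def to_display(cls, x: float) -> float:
        # Undo the inversion from forward(); pass through if none was applied.
        return 1.0 - x if cls.BIGGER_IS_BETTER else x


class LowerIsBetterSketch(HigherIsBetterSketch):
    """Hypothetical measure that is already lower-is-better (error-like)."""

    BIGGER_IS_BETTER = False
    OPT_VALUE = 0.0

    def forward(self, error: float) -> float:
        # No inversion needed.
        return error


loss_value = HigherIsBetterSketch().forward(0.75)         # usable as a loss
shown_value = HigherIsBetterSketch.to_display(loss_value)  # value for humans
unchanged = LowerIsBetterSketch.to_display(0.5)
```

Keeping the inversion inside the measure lets training code minimize every measure the same way, while reported numbers stay in the range readers expect.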