vp_suite.measure.metric_provider
- class PredictionMetricProvider(config)
Bases: object
This class provides bundled access to multiple metrics. With this class’s get_metrics() method, all specified metric scores are calculated on the same input prediction and target tensors.
- device
A string specifying whether to use the GPU for calculations (cuda) or the CPU (cpu).
- Type
- available_metrics
A dictionary containing the string identifiers and corresponding metrics that the metric provider should use when provided with input tensors.
- Type
- metrics
The concrete instantiated metrics that the metric provider uses when provided with input tensors.
- Type
- __init__(config)
Initializes the provider by extracting device and metric IDs from the provided config dict and instantiating the metrics that shall be used.
- Parameters
config (dict) – A dictionary specifying the device and the metrics to use.
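The registry-and-instantiation pattern described for __init__ can be sketched roughly as follows. This is a hypothetical stand-in, not the vp_suite implementation: the config keys "device" and "metrics", the metric IDs "mse" and "mae", and the class name SketchMetricProvider are all assumptions made for illustration, and plain Python lists take the place of tensors.

```python
# Hypothetical sketch of a metric registry; not the vp_suite code.
# Assumed config keys: "device" and "metrics" (the real keys may differ).

def mse(pred, target):
    """Mean squared error over two equal-length sequences of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred, target):
    """Mean absolute error over two equal-length sequences of floats."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Plays the role of `available_metrics`: string ID -> metric callable.
AVAILABLE_METRICS = {"mse": mse, "mae": mae}


class SketchMetricProvider:
    def __init__(self, config):
        # Extract the device string ("cuda" or "cpu") from the config.
        self.device = config["device"]
        requested = config["metrics"]
        # Reject IDs that are not in the registry instead of ignoring them.
        unknown = [m for m in requested if m not in AVAILABLE_METRICS]
        if unknown:
            raise ValueError(f"Unknown metric IDs: {unknown}")
        # Plays the role of `metrics`: only the metrics actually requested.
        self.metrics = {m: AVAILABLE_METRICS[m] for m in requested}


provider = SketchMetricProvider({"device": "cpu", "metrics": ["mse"]})
print(provider.device, sorted(provider.metrics))
```

Keeping the full registry separate from the selected subset mirrors the split between available_metrics and metrics in the attribute list above.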
- get_metrics(pred, target, frames=None, all_frame_cnts=False)
Takes in tensors of predicted frames and the corresponding ground truth and calculates the metric scores for the metrics instantiated previously.
- Parameters
pred (torch.Tensor) – The predicted frame sequence as a 5D float tensor (batch, frames, c, h, w).
target (torch.Tensor) – The ground truth frame sequence as a 5D float tensor (batch, frames, c, h, w).
frames (int) – If specified, only the first ‘frames’ frames are considered.
all_frame_cnts (bool) – If set to True, metrics are calculated for every prediction horizon from 1 up to the maximum number of frames. Otherwise, metrics are calculated only for the specified number of frames.
- Returns
A list of dictionaries, where each dictionary contains the metric IDs and the corresponding result values for a specific number of prediction frames.
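The interaction between frames and all_frame_cnts can be illustrated with a minimal sketch. It is not the library's code: get_metrics_sketch is a hypothetical function, NumPy arrays stand in for torch tensors (slicing along the frame axis is written the same way), and the clamping of frames to the available sequence length as well as the plain metric-ID dictionary keys are assumptions.

```python
import numpy as np


def mse(pred, target):
    """Mean squared error over all elements, returned as a Python float."""
    return float(np.mean((pred - target) ** 2))


def get_metrics_sketch(pred, target, metrics, frames=None, all_frame_cnts=False):
    """Hypothetical stand-in mimicking the documented get_metrics() behavior."""
    if pred.ndim != 5 or pred.shape != target.shape:
        raise ValueError("expected two 5D arrays of identical shape")
    # Assumption: `frames` is clamped to the number of frames actually present.
    max_frames = pred.shape[1] if frames is None else min(frames, pred.shape[1])
    # Either every horizon 1..max_frames, or only the longest one.
    horizons = range(1, max_frames + 1) if all_frame_cnts else [max_frames]
    results = []
    for n in horizons:
        # Keep only the first n frames along axis 1 (batch, frames, c, h, w).
        p, t = pred[:, :n], target[:, :n]
        results.append({name: fn(p, t) for name, fn in metrics.items()})
    return results


pred = np.zeros((2, 4, 3, 8, 8), dtype=np.float32)
target = np.ones((2, 4, 3, 8, 8), dtype=np.float32)

single = get_metrics_sketch(pred, target, {"mse": mse}, frames=3)
per_horizon = get_metrics_sketch(pred, target, {"mse": mse}, all_frame_cnts=True)
print(len(single), len(per_horizon))
```

With all_frame_cnts left at False the returned list holds a single dictionary; with it set to True the list holds one dictionary per prediction horizon, which matches the return description above.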