vp_suite.measure.metric_provider

class PredictionMetricProvider(config)

Bases: object

This class provides bundled access to multiple metrics. Its get_metrics() method calculates all specified metric scores on the same pair of prediction and target tensors.

device

A string specifying whether calculations run on the GPU (cuda) or on the CPU (cpu).

Type

str

available_metrics

A dictionary mapping the string identifiers of the metrics the provider should use on input tensors to the corresponding (not yet instantiated) metrics.

Type

dict

metrics

The concrete instantiated metrics that the metric provider uses when provided with input tensors.

Type

dict

__init__(config)

Initializes the provider by extracting the device and the metric IDs from the provided config dict and instantiating the metrics to be used.

Parameters

config (dict) – A dictionary specifying the device and the IDs of the metrics to use.
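A minimal sketch of the initialization step, showing how the three attributes relate: available_metrics holds the requested (uninstantiated) metrics and metrics holds one instance per ID on the configured device. It assumes the config stores the device under a "device" key and the metric IDs under a "metrics" key; both key names are assumptions, not confirmed by this reference. The names Metric, MSE, SSIM, METRIC_REGISTRY and MetricProviderSketch are hypothetical stand-ins, not vp_suite's own classes.

```python
from typing import Dict, Type


class Metric:
    # Hypothetical base class: each metric remembers the device it runs on.
    def __init__(self, device: str) -> None:
        self.device = device


class MSE(Metric):
    pass


class SSIM(Metric):
    pass


# Hypothetical registry mapping string IDs to metric classes.
METRIC_REGISTRY: Dict[str, Type[Metric]] = {"mse": MSE, "ssim": SSIM}


class MetricProviderSketch:
    def __init__(self, config: dict) -> None:
        # Assumed config keys: "device" ("cuda" or "cpu") and "metrics" (list of IDs).
        self.device: str = config["device"]
        # Keep only the registry entries whose IDs were requested;
        # an unknown ID raises KeyError here.
        self.available_metrics: Dict[str, Type[Metric]] = {
            mid: METRIC_REGISTRY[mid] for mid in config["metrics"]
        }
        # Instantiate each requested metric on the configured device.
        self.metrics: Dict[str, Metric] = {
            mid: cls(device=self.device)
            for mid, cls in self.available_metrics.items()
        }


if __name__ == "__main__":
    provider = MetricProviderSketch({"device": "cpu", "metrics": ["mse", "ssim"]})
    print(sorted(provider.metrics))  # prints ['mse', 'ssim']
    print(provider.metrics["mse"].device)  # prints cpu
```

Looking IDs up directly in the registry makes a misspelled metric ID fail at construction time rather than being silently skipped; whether vp_suite handles unknown IDs this way is not stated in this reference.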

get_metrics(pred, target, frames=None, all_frame_cnts=False)

Takes tensors of predicted frames and the corresponding ground-truth frames, and calculates the scores of all previously instantiated metrics.

Parameters
  • pred (torch.Tensor) – The predicted frame sequence as a 5D float tensor (batch, frames, c, h, w).

  • target (torch.Tensor) – The ground-truth frame sequence as a 5D float tensor (batch, frames, c, h, w).

  • frames (int, optional) – If specified, only the first ‘frames’ frames of each sequence are considered.

  • all_frame_cnts (bool, optional) – If True, computes metrics for every prediction horizon from 1 up to the maximum number of frames. Otherwise, computes metrics only for the specified number of frames.

Returns

A list of dictionaries, where each dictionary maps the metric IDs to the corresponding result values for one specific number of predicted frames (prediction horizon).
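The following is a minimal, framework-free sketch of the get_metrics() semantics described above, not the vp_suite implementation. To run without PyTorch, it represents each 5D tensor as nested Python lists shaped (batch, frames, c, h, w) and uses two hypothetical metric functions, mse and mae. It also assumes that omitting frames means all frames are used, that a larger frames value is clipped to the sequence length, and that with all_frame_cnts=False the result is a one-element list, consistent with the documented return type.

```python
from typing import Callable, Dict, List, Optional

# Nested Python lists shaped (batch, frames, c, h, w) stand in for 5D tensors.
Seq5D = list
MetricFn = Callable[[Seq5D, Seq5D], float]


def _flatten(x) -> List[float]:
    # Recursively flatten nested lists into one flat list of floats.
    if isinstance(x, list):
        out: List[float] = []
        for item in x:
            out.extend(_flatten(item))
        return out
    return [float(x)]


def mse(pred: Seq5D, target: Seq5D) -> float:
    # Hypothetical metric: mean squared error over all elements.
    p, t = _flatten(pred), _flatten(target)
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)


def mae(pred: Seq5D, target: Seq5D) -> float:
    # Hypothetical metric: mean absolute error over all elements.
    p, t = _flatten(pred), _flatten(target)
    return sum(abs(a - b) for a, b in zip(p, t)) / len(p)


def get_metrics(metrics: Dict[str, MetricFn], pred: Seq5D, target: Seq5D,
                frames: Optional[int] = None,
                all_frame_cnts: bool = False) -> List[Dict[str, float]]:
    if len(pred) != len(target) or len(pred[0]) != len(target[0]):
        raise ValueError("pred and target must have the same batch and frame sizes")
    seq_len = len(pred[0])  # size of the frame dimension (dim 1)
    # Assumption: frames=None means "use every frame"; larger values are clipped.
    frames = seq_len if frames is None else min(frames, seq_len)
    # Horizons to evaluate: every count 1..frames, or just `frames`.
    horizons = range(1, frames + 1) if all_frame_cnts else [frames]
    results: List[Dict[str, float]] = []
    for h in horizons:
        # Keep only the first h frames of every sequence in the batch.
        p = [seq[:h] for seq in pred]
        t = [seq[:h] for seq in target]
        # Every metric is computed on the same (prediction, target) pair.
        results.append({name: fn(p, t) for name, fn in metrics.items()})
    return results


if __name__ == "__main__":
    # batch=1, frames=2, c=1, h=1, w=2
    pred = [[[[[0.0, 1.0]]], [[[2.0, 2.0]]]]]
    target = [[[[[0.0, 0.0]]], [[[0.0, 2.0]]]]]
    print(get_metrics({"mse": mse, "mae": mae}, pred, target, all_frame_cnts=True))
    # prints [{'mse': 0.5, 'mae': 0.5}, {'mse': 1.25, 'mae': 0.75}]
```

With all_frame_cnts=True the list holds one dictionary per horizon (here 1 and 2 frames), which makes it straightforward to plot how each metric degrades as the prediction horizon grows.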