vp_suite.base.base_model

class VPModel(device, **model_kwargs)

Bases: torch.nn.modules.module.Module

The base class for all video prediction models. Each model should provide two forward pass/prediction methods (the default self.forward() method and pred_1(), which predicts a single frame) as well as two utility methods (train_iter() for a single training epoch on a given dataset loader and, analogously, eval_iter() for a single validation epoch).

CAN_HANDLE_ACTIONS = False

Whether the model can handle actions or not.

CODE_REFERENCE = None

The code location of the reference implementation.

MATCHES_REFERENCE: str = None

A comment indicating whether the implementation in this package matches the reference.

MIN_CONTEXT_FRAMES = 1

Minimum number of context frames required for the model to work. By default (a value of 1), models can handle any number of context frames.

NAME = None

The model’s name.

NEEDS_COMPLETE_INPUT = False

Whether the input sequences also need to include the to-be-predicted frames.

NON_CONFIG_VARS = ['functions', 'model_dir', 'dump_patches', 'training']

Variables that are not included in the dict returned by self.config (constants are excluded as well).

PAPER_REFERENCE = None

The publication that first introduced this model.

REQUIRED_ARGS = ['img_shape', 'action_size', 'tensor_value_range']

The attributes that the model creator needs to supply when creating the model.

TRAINABLE = True

Whether the model is trainable or not.

__init__(device, **model_kwargs)

Initializes the model by first setting all model hyperparameters, attributes and the like. Then, the model-specific initialization creates the actual model from the given hyperparameters.

Parameters
  • device (str) – The device identifier for the module.

  • **model_kwargs (Any) – Model arguments such as hyperparameters, input shapes etc.
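The two-phase initialization described above can be sketched as follows. This is a minimal, hypothetical illustration without PyTorch: SketchModel and its _build() hook are invented names, and raising an error for missing REQUIRED_ARGS is an assumption about how required arguments might be enforced, not vp_suite's actual code.

```python
class SketchModel:
    """Hypothetical stand-in for VPModel's init logic (no torch dependency)."""

    REQUIRED_ARGS = ["img_shape", "action_size", "tensor_value_range"]

    # Defaults for common attributes, mirroring the documented class attributes.
    img_shape = None
    action_size = None
    tensor_value_range = None
    action_conditional = False

    def __init__(self, device, **model_kwargs):
        self.device = device
        # Assumption: refuse to build a model when required arguments are absent.
        missing = [arg for arg in self.REQUIRED_ARGS if arg not in model_kwargs]
        if missing:
            raise ValueError(f"missing required model arguments: {missing}")
        # Phase 1: set every supplied hyperparameter/attribute on the instance.
        for key, value in model_kwargs.items():
            setattr(self, key, value)
        # Phase 2: model-specific construction from those hyperparameters.
        self._build()

    def _build(self):
        """Hypothetical hook where a subclass would create its layers."""
        pass
```

For example, SketchModel("cpu", img_shape=(3, 64, 64), action_size=0, tensor_value_range=[0.0, 1.0], hidden=16) stores every keyword argument as an attribute before _build() runs, so the model-specific code can read self.hidden.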

action_conditional = False

True if this model is leveraging input actions for the predictions, False otherwise.

action_size = None

The expected dimensionality of the action inputs.

property config

A dictionary containing the complete model configuration, including common attributes as well as model-specific attributes.

Type: dict
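How such a configuration dictionary could be assembled is sketched below. This is an assumption-based illustration, not vp_suite's implementation: ConfigSketch is a hypothetical class, and treating all-uppercase names as constants, as well as skipping private names and methods, are assumed rules chosen to match the NON_CONFIG_VARS description.

```python
class ConfigSketch:
    """Hypothetical stand-in showing how a config dict could be assembled."""

    NAME = "sketch"  # constant (all caps): excluded from the config
    NON_CONFIG_VARS = ["functions", "model_dir", "dump_patches", "training"]
    img_shape = None  # common attribute with a class-level default
    action_conditional = False

    def __init__(self, **model_kwargs):
        self.model_dir = "out/sketch"  # listed in NON_CONFIG_VARS: excluded
        self.training = True           # listed in NON_CONFIG_VARS: excluded
        for key, value in model_kwargs.items():
            setattr(self, key, value)

    @property
    def config(self):
        cfg = {}
        for name in dir(self):
            # Skip private/dunder names, constants and the property itself
            # (reading "config" here would recurse forever).
            if name.startswith("_") or name.isupper() or name == "config":
                continue
            if name in self.NON_CONFIG_VARS:
                continue
            value = getattr(self, name)
            if callable(value):  # skip methods
                continue
            cfg[name] = value
        return cfg
```

Because dir() covers both class-level defaults and instance attributes, the result contains common attributes (such as img_shape) alongside model-specific ones (such as a hidden size passed at construction).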

eval_iter(config, loader, loss_provider)

Default validation iteration: Loops through the whole data loader once and, for every datapoint, executes the forward pass and loss calculation. Then, aggregates all loss values to assess the prediction quality.

Parameters
  • config (dict) – The configuration dict of the current validation run (combines model, dataset and run config).

  • loader (DataLoader) – Validation data is sampled from this loader.

  • loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.

Returns: A dictionary containing the average value for each loss type specified for usage, as well as the value for the ‘indicator’ loss (the loss used for determining overall model improvement).
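The aggregation step of the validation loop can be sketched with the standard library alone. This is a hypothetical helper, not part of vp_suite: aggregate_losses and its indicator_key parameter are invented names, and storing the indicator value under the key "indicator" is an assumption based on the return description.

```python
from statistics import mean


def aggregate_losses(batch_losses, indicator_key):
    """Average per-batch loss dicts and expose the indicator loss.

    batch_losses: list of dicts mapping loss name -> float, one per batch.
    indicator_key: name of the loss that judges overall model improvement.
    """
    if not batch_losses:
        raise ValueError("no batches to aggregate")
    keys = batch_losses[0].keys()
    # Average each loss type over all validation batches.
    averages = {key: mean(losses[key] for losses in batch_losses) for key in keys}
    # Additionally report the indicator loss under a fixed key.
    averages["indicator"] = averages[indicator_key]
    return averages
```

With two batches {"mse": 1.0, "l1": 0.5} and {"mse": 3.0, "l1": 1.5} and "mse" as the indicator, the result is {"mse": 2.0, "l1": 1.0, "indicator": 2.0}.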

forward(x, pred_frames=1, **kwargs)

Given an input sequence of t frames, predicts pred_frames (p) frames into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • pred_frames (int) – The number of frames to predict into the future.

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A batch of sequences of p predicted frames as a tensor of shape [b, p, c, h, w].
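One common way to relate forward() to pred_1() is an autoregressive rollout: predict one frame, append it to the context, and repeat pred_frames times. The sketch below assumes that design rather than taking it from vp_suite; rollout is a hypothetical function, and plain Python lists stand in for the [b, t, c, h, w] tensor so the example runs without PyTorch (with tensors, the appends would become torch.cat/torch.stack along the time dimension).

```python
def rollout(pred_1, x, pred_frames=1, **kwargs):
    """Autoregressively predict pred_frames frames by calling pred_1 repeatedly.

    x: batch of b sequences, each a list of t frames.
    pred_1: callable mapping such a batch to b single frames.
    Returns b sequences of pred_frames predicted frames each.
    """
    context = [list(seq) for seq in x]  # copy so the caller's input is untouched
    predictions = [[] for _ in x]
    for _ in range(pred_frames):
        next_frames = pred_1(context, **kwargs)  # one frame per sequence
        for seq, pred, frame in zip(context, predictions, next_frames):
            seq.append(frame)  # feed the prediction back in as context
            pred.append(frame)
    return predictions
```

With a toy pred_1 that returns the last frame plus one (frames as integers), rollout(toy, [[0, 1], [10, 11]], pred_frames=3) yields [[2, 3, 4], [12, 13, 14]].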

img_shape = None

model_dir = None

The save location of the model.

pred_1(x, **kwargs)

Given an input sequence of t frames, predicts one single frame into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A batch of single predicted frames as a tensor of shape [b, c, h, w].

tensor_value_range = None

The expected value range of the input tensors.

train_iter(config, loader, optimizer, loss_provider, epoch)

Default training iteration: Loops through the whole data loader once and, for every batch, executes forward pass, loss calculation and backward pass/optimization step.

Parameters
  • config (dict) – The configuration dict of the current training run (combines model, dataset and run config).

  • loader (DataLoader) – Training data is sampled from this loader.

  • optimizer (Optimizer) – The optimizer to use for weight update calculations.

  • loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.

  • epoch (int) – The current epoch.
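The per-batch steps of the training iteration can be sketched as a generic loop. This is a simplified, hypothetical sketch: train_epoch and loss_fn are invented names, loss_fn stands in for the loss provider, and data unpacking via unpack_data() and the config/epoch arguments are omitted. It only relies on the PyTorch conventions optimizer.zero_grad(), loss.backward() and optimizer.step().

```python
def train_epoch(model, loader, optimizer, loss_fn):
    """One pass over loader: forward, loss, backward, optimizer step per batch."""
    for inputs, targets in loader:
        optimizer.zero_grad()                 # clear gradients from the previous batch
        predictions = model(inputs)           # forward pass
        loss = loss_fn(predictions, targets)  # loss calculation
        loss.backward()                       # backward pass
        optimizer.step()                      # weight update
```

Calling zero_grad() before each backward pass matters because PyTorch accumulates gradients by default; without it, every update would mix in gradients from earlier batches.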

unpack_data(data, config, reverse=False, complete=False)

Extracts inputs and targets from a data blob.

Parameters
  • data (VPData) – The given VPData data blob/dictionary containing frames and actions.

  • config (dict) – The current run configuration specifying how to extract the data from the given data blob.

  • reverse (bool) – If True, reverses the input first.

  • complete (bool) – If True, input_frames will also contain the to-be-predicted frames (just like with NEEDS_COMPLETE_INPUT).

Returns: The specified number of input/target frames as well as the actions. All inputs come in the shape the model expects as input later.
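The effect of the reverse and complete flags on the input/target split can be illustrated on a single sequence. This is a hypothetical sketch, not vp_suite's implementation: unpack_frames, context_frames and pred_frames are assumed names, actions and batching are ignored, and truncating the sequence to context_frames + pred_frames before reversing is an assumed ordering.

```python
def unpack_frames(frames, context_frames, pred_frames, reverse=False, complete=False):
    """Split one sequence of frames into model inputs and prediction targets."""
    total = context_frames + pred_frames
    if len(frames) < total:
        raise ValueError("sequence too short")
    frames = list(frames[:total])  # keep only the frames that are needed
    if reverse:
        frames.reverse()  # reverse the sequence before splitting
    targets = frames[context_frames:]
    # With complete=True, the inputs also contain the to-be-predicted frames.
    inputs = frames if complete else frames[:context_frames]
    return inputs, targets
```

For frames [0, 1, 2, 3, 4] with three context frames and two predicted frames, the default split is ([0, 1, 2], [3, 4]); with complete=True the inputs become the full sequence [0, 1, 2, 3, 4] while the targets stay [3, 4].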