vp_suite.models.copy_last_frame

class CopyLastFrame(device=None, **model_kwargs)

Bases: vp_suite.base.base_model.VPModel

A simple, non-trainable baseline model that returns the most recent input frame as the next predicted frame.
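A baseline like this is typically used as a reference point: a trained model whose error is not below that of copy-last-frame has learned little about motion. The sketch below computes that reference error with plain PyTorch indexing. It does not call vp_suite, and the random tensors and the choice of mean squared error as the metric are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical data: b=2 sequences of 6 frames, shape [b, t, c, h, w].
video = torch.rand(2, 6, 3, 16, 16)
context, target = video[:, :5], video[:, 5]  # 5 input frames, 1 frame to predict

# Copy-last-frame prediction: the most recent input frame, shape [b, c, h, w].
baseline_pred = context[:, -1]
baseline_mse = F.mse_loss(baseline_pred, target)
print(f"baseline MSE on random frames: {baseline_mse.item():.4f}")

# On a static clip (every frame identical) the baseline is exact.
static = video[:, :1].repeat(1, 6, 1, 1, 1)
static_mse = F.mse_loss(static[:, :5][:, -1], static[:, 5])
print(static_mse.item())  # 0.0
```

Because `context[:, -1]` is a view, the prediction shares memory with the input; call `.clone()` if it will be modified in place.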

NAME = 'CopyLastFrame'

The model’s name.

REQUIRED_ARGS = []

The attributes that must be supplied when creating the model. This model declares none.

TRAINABLE = False

Whether the model is trainable or not.
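As a rough illustration of these class attributes, the hypothetical class `ParameterFreeBaseline` below subclasses `torch.nn.Module` directly rather than `VPModel`. A model that only indexes its input registers no learnable parameters, which is consistent with `TRAINABLE = False`.

```python
import torch
from torch import nn


class ParameterFreeBaseline(nn.Module):
    """Hypothetical stand-in mirroring CopyLastFrame's class attributes."""

    NAME = "CopyLastFrame"
    REQUIRED_ARGS = []  # nothing extra must be supplied at construction time
    TRAINABLE = False   # there are no weights to optimize

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return the last frame of each sequence: [b, t, c, h, w] -> [b, c, h, w].
        return x[:, -1]


model = ParameterFreeBaseline()
num_params = sum(p.numel() for p in model.parameters())
print(model.NAME, model.TRAINABLE, num_params)  # CopyLastFrame False 0
```

Passing `model.parameters()` of such a module to a PyTorch optimizer such as `torch.optim.SGD` raises a `ValueError` because the parameter list is empty, which illustrates why a gradient-based training loop has nothing to do for this model.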

__init__(device=None, **model_kwargs)

Initializes the model by first setting all model hyperparameters, attributes, and the like. Then, the model-specific initialization creates the model from the given hyperparameters.

Parameters
  • device (str) – The device identifier for the module.

  • **model_kwargs (Any) – Model arguments such as hyperparameters, input shapes etc.
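The two-phase initialization described above can be sketched as follows. This is a hypothetical illustration, not vp_suite's implementation: the base class `SketchBaseModel`, the hook name `_build_model`, the `"cpu"` fallback for `device`, and the `img_shape` keyword are all assumptions. Only the order of the steps (store the hyperparameters first, then build the model from them) comes from the description.

```python
from typing import Any, Optional

from torch import nn


class SketchBaseModel(nn.Module):
    """Hypothetical base class illustrating a two-phase __init__."""

    REQUIRED_ARGS: list = []

    def __init__(self, device: Optional[str] = None, **model_kwargs: Any):
        super().__init__()
        # Phase 1: check required arguments, then store every hyperparameter as an attribute.
        missing = [arg for arg in self.REQUIRED_ARGS if arg not in model_kwargs]
        if missing:
            raise ValueError(f"missing required model arguments: {missing}")
        self.device = device or "cpu"  # assumed fallback, not vp_suite's default
        for key, value in model_kwargs.items():
            setattr(self, key, value)
        # Phase 2: the subclass builds its layers from the stored hyperparameters.
        self._build_model()
        self.to(self.device)

    def _build_model(self) -> None:
        raise NotImplementedError


class SketchCopyLastFrame(SketchBaseModel):
    def _build_model(self) -> None:
        pass  # nothing to build: the baseline has no layers


model = SketchCopyLastFrame(device="cpu", img_shape=(3, 64, 64))
print(model.device, model.img_shape)  # cpu (3, 64, 64)
```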

pred_1(x, **kwargs)

Given an input sequence of t frames, predicts a single frame into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A single frame as a tensor of shape [b, c, h, w].
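The shape contract of `pred_1` can be checked with a small stand-in. The sketch below assumes only PyTorch and mirrors the documented signature `pred_1(x, **kwargs)`; the class name `CopyLastFrameSketch` and the `actions` tensor are illustrative assumptions. A copy-last-frame model has no use for extra inputs, so it accepts and ignores them.

```python
import torch


class CopyLastFrameSketch:
    """Hypothetical stand-in that mirrors the documented pred_1 contract."""

    def pred_1(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # kwargs (e.g. actions) are accepted for API compatibility but ignored.
        b, t, c, h, w = x.shape
        next_frame = x[:, t - 1]  # last frame of every sequence
        assert next_frame.shape == (b, c, h, w)
        return next_frame


x = torch.rand(4, 10, 3, 32, 32)  # b=4 sequences of t=10 frames
actions = torch.zeros(4, 10, 2)   # hypothetical action tensor, ignored
y = CopyLastFrameSketch().pred_1(x, actions=actions)
print(tuple(y.shape))             # (4, 3, 32, 32)
print(torch.equal(y, x[:, -1]))   # True
```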

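training: bool

Whether the module is in training mode (inherited from torch.nn.Module; toggled by train() and eval()).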