vp_suite.models.predrnn_v2

class PredRNN_V2(device, **model_kwargs)

Bases: vp_suite.base.base_model.VPModel

This is a reimplementation of the model “PredRNN-V2”, as introduced in “PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning” by Wang et al. (https://arxiv.org/pdf/2103.09504.pdf). This implementation is based on the official PyTorch implementation at https://github.com/thuml/predrnn-pytorch.

PredRNN-V2 aims at learning partly disentangled spatial/temporal dynamics of the input domain and using them to render more accurate predicted frames. The “Spatio-Temporal LSTM Cell” (ST cell) forms the heart of this model.

Note

This model uses the whole frame sequence as input, including the frames to be predicted. If you do not have ground truth for the frames to be predicted, pad the sequence with t “zero” frames, with t being the number of predicted frames. Also, the original action-conditional implementations are broken: when using the action-conditional variant, both reverse scheduled sampling and ‘conv_on_input’ have to be set to 1/True!
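
The zero-padding described in this note can be sketched as follows. This is a minimal illustration, assuming PyTorch and the [b, t, c, h, w] layout expected by forward(); the helper name pad_with_zero_frames is hypothetical and not part of vp_suite.

```python
import torch


def pad_with_zero_frames(context: torch.Tensor, pred_frames: int) -> torch.Tensor:
    """Append `pred_frames` all-zero frames along the time axis (dim 1).

    Hypothetical helper, not part of vp_suite.
    Expects a tensor of shape [b, t, c, h, w].
    """
    if pred_frames < 0:
        raise ValueError("pred_frames must be non-negative")
    b, _, c, h, w = context.shape
    # new_zeros keeps the dtype and device of the context frames
    padding = context.new_zeros((b, pred_frames, c, h, w))
    return torch.cat([context, padding], dim=1)


# 2 sequences of 10 context frames, padded for 5 frames to be predicted
context = torch.rand(2, 10, 3, 64, 64)
padded = pad_with_zero_frames(context, pred_frames=5)
```

The padded tensor then has t + pred_frames frames per sequence, with the original context frames left unchanged at the front.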

CAN_HANDLE_ACTIONS = False

Whether the model can handle actions or not.

CODE_REFERENCE = 'https://github.com/thuml/predrnn-pytorch'

The code location of the reference implementation.

MATCHES_REFERENCE: str = 'Yes'

A comment indicating whether the implementation in this package matches the reference.

NAME = 'PredRNN++'

The model’s name.

NEEDS_COMPLETE_INPUT = True

Whether the input sequences also need to include the to-be-predicted frames.

PAPER_REFERENCE = 'https://arxiv.org/abs/2103.09504'

The publication where this model was introduced first.

__init__(device, **model_kwargs)

Initializes the model by first setting all model hyperparameters, attributes and the like. Then, the model-specific init creates the model from the given hyperparameters.

Parameters
  • device (str) – The device identifier for the module.

  • **model_kwargs (Any) – Model arguments such as hyperparameters, input shapes etc.

conv_actions_on_input: bool = True

Whether to convolve actions directly on the input

decoupling_loss_scale = 100.0

The scaling factor for the decoupling loss
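
As a rough illustration of where such a scaling factor enters the objective: in the PredRNN-V2 paper, the decoupling loss penalizes the absolute cosine similarity between the increments of the two memory states and is added to the prediction loss with a weight. The sketch below uses plain Python lists; the function names and the simple additive combination are assumptions for illustration, not vp_suite's actual code.

```python
import math


def abs_cosine_similarity(delta_c, delta_m, eps=1e-12):
    """Absolute cosine similarity between two memory increments (flat lists)."""
    dot = sum(a * b for a, b in zip(delta_c, delta_m))
    norm_c = math.sqrt(sum(a * a for a in delta_c))
    norm_m = math.sqrt(sum(b * b for b in delta_m))
    # eps guards against division by zero for all-zero increments
    return abs(dot) / max(norm_c * norm_m, eps)


def total_loss(prediction_loss, decoupling_loss, decoupling_loss_scale=100.0):
    """Assumed combination: prediction loss plus the scaled decoupling loss."""
    return prediction_loss + decoupling_loss_scale * decoupling_loss


# Orthogonal increments are not penalized; (anti-)parallel ones are penalized most
orthogonal = abs_cosine_similarity([1.0, 0.0], [0.0, 2.0])
parallel = abs_cosine_similarity([1.0, 2.0], [-2.0, -4.0])
```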

filter_size = 5

Kernel size for ST cell and action-conditional convs

forward(x, pred_frames=1, **kwargs)

Given an input sequence of t frames, predicts pred_frames (p) frames into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • pred_frames (int) – The number of frames to predict into the future.

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A batch of sequences of p predicted frames as a tensor of shape [b, p, c, h, w].

inflated_action_dim = 3

Dimensionality of the ‘inflated actions’ (actions that have been transformed to tensors)

layer_norm: bool = False

Whether to use layer normalization in the ST cells

num_hidden = [128, 128, 128, 128]

Hidden layer dimensionality per ST cell layer

num_layers = 3

Number of ST cell layers

patch_size = 4

During encoding, the image is sliced into patches of this size (height and width)
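
To make the effect of patch_size concrete, the sketch below computes the shape of a single frame after patching. It assumes the space-to-depth scheme of the reference implementation, in which each patch_size × patch_size block is folded into the channel dimension. The function name is hypothetical, and whether vp_suite arranges the dimensions in exactly this order is an assumption.

```python
def patched_frame_shape(c, h, w, patch_size=4):
    """Return the (channels, height, width) of a frame after patching.

    Hypothetical helper: every patch_size x patch_size block of pixels
    is moved into the channel dimension.
    """
    if h % patch_size or w % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    return (c * patch_size * patch_size, h // patch_size, w // patch_size)


# A 3-channel 64x64 frame with the default patch size of 4
shape = patched_frame_shape(3, 64, 64)
```

Under this assumption, the frame height and width both need to be divisible by patch_size.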

pred_1(x, **kwargs)

Given an input sequence of t frames, predicts one single frame into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A single frame as a tensor of shape [b, c, h, w].

r_exp_alpha: int = 5000

Reverse scheduled sampling rate change regulator factor

r_sampling_step_1: int = 25000

At which iteration to proceed to second reverse scheduled sampling phase

r_sampling_step_2: int = 50000

At which iteration to proceed to third reverse scheduled sampling phase
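
How r_exp_alpha, r_sampling_step_1 and r_sampling_step_2 interact can be sketched as follows, modeled on the exponential reverse scheduled sampling schedule of the reference implementation. The function name is hypothetical, and whether vp_suite computes the probabilities in exactly this way is an assumption.

```python
import math


def reverse_sampling_probs(itr, r_sampling_step_1=25000,
                           r_sampling_step_2=50000, r_exp_alpha=5000):
    """Return (r_eta, eta) for training iteration `itr`.

    r_eta: probability of feeding a true frame within the context part.
    eta:   probability of feeding a true frame within the prediction part.
    """
    if itr < r_sampling_step_1:
        # Phase 1: both probabilities stay fixed
        return 0.5, 0.5
    if itr < r_sampling_step_2:
        # Phase 2: r_eta rises exponentially towards 1, eta decays linearly to 0
        progress = itr - r_sampling_step_1
        r_eta = 1.0 - 0.5 * math.exp(-progress / r_exp_alpha)
        eta = 0.5 - 0.5 * progress / (r_sampling_step_2 - r_sampling_step_1)
        return r_eta, eta
    # Phase 3: true context frames only, no true frames in the prediction part
    return 1.0, 0.0


r_eta_mid, eta_mid = reverse_sampling_probs(37500)
```

A larger r_exp_alpha makes r_eta approach 1 more slowly during the second phase.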

residual_on_action_conv: bool = True

Whether to use residual connections for the direct action convolution

reverse_input: bool = True

Whether to also train on the reversed version of training sequences

reverse_scheduled_sampling: bool = False

Whether to use reverse scheduled sampling

sampling_changing_rate = 2e-05

Per-iteration changing rate for scheduled sampling

sampling_eta: float = None

Sampling rate

sampling_stop_iter: int = 50000

At which iteration to stop the scheduled sampling
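
The interplay of sampling_changing_rate, sampling_eta and sampling_stop_iter can be sketched as a linear decay, modeled on the scheduled sampling of the reference implementation. The starting value of 1.0, the clamp at zero, and the function name are assumptions; vp_suite may differ in detail.

```python
def next_sampling_eta(eta, itr, sampling_stop_iter=50000,
                      sampling_changing_rate=2e-5):
    """Return the sampling rate to use after training iteration `itr`."""
    if itr < sampling_stop_iter:
        # Linear decay; the clamp at 0.0 is an added safeguard (assumption)
        return max(eta - sampling_changing_rate, 0.0)
    # Scheduled sampling is switched off after sampling_stop_iter
    return 0.0


eta = 1.0  # assumed starting value
for itr in range(1000):
    eta = next_sampling_eta(eta, itr)
```

With the defaults shown here, 50000 iterations at a rate of 2e-05 take the sampling rate from 1.0 down to approximately 0, which matches sampling_stop_iter.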

scheduled_sampling: bool = True

Whether to use scheduled sampling during training

stride = 1

Stride for ST cell

train_iter(config, loader, optimizer, loss_provider, epoch)

PredRNN++’s training iteration utilizes reversed input and keeps track of the number of training iterations done so far in order to adjust the sampling schedule. Otherwise, the iteration logic is the same as in the default train_iter() function.

Parameters
  • config (dict) – The configuration dict of the current training run (combines model, dataset and run config)

  • loader (DataLoader) – Training data is sampled from this loader.

  • optimizer (Optimizer) – The optimizer to use for weight update calculations.

  • loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.

  • epoch (int) – The current epoch.

training_iteration: int = None

Current number of training iterations (~how many training inferences were done so far)