vp_suite.models.predrnn_v2
- class PredRNN_V2(device, **model_kwargs)
Bases:
vp_suite.base.base_model.VPModel
This is a reimplementation of the model “PredRNN-V2”, as introduced in “PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning” by Wang et al. (https://arxiv.org/pdf/2103.09504.pdf). This implementation is based on the official PyTorch implementation on https://github.com/thuml/predrnn-pytorch.
PredRNN-V2 aims at learning partly disentangled spatial/temporal dynamics of the input domain and uses this to render more accurate predicted frames. The “Spatio-Temporal LSTM Cell” (ST cell) forms the heart of this model.
Note
This model uses the whole frame sequence as input, including the frames to be predicted. If you do not have ground truth frames for the prediction horizon, pad the sequence with t “zero” frames, with t being the number of predicted frames. Also: the original action-conditional implementation is broken; when using the action-conditional variant, reverse scheduled sampling as well as ‘conv_on_input’ must be set to 1/True!
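The padding step described above can be sketched as follows; the batch size, frame shape and number of predicted frames are made-up values for illustration:

```python
import torch

# Hypothetical shapes: 2 sequences of 10 context frames, 3x64x64 each.
context = torch.randn(2, 10, 3, 64, 64)   # [b, t_context, c, h, w]
pred_frames = 5

# Since the model needs the complete sequence (context + prediction horizon),
# append `pred_frames` zero frames when no ground truth is available.
zeros = torch.zeros(2, pred_frames, 3, 64, 64)
full_input = torch.cat([context, zeros], dim=1)  # [b, t_context + p, c, h, w]
```

The predicted frames returned by the model then replace the zero padding; the zeros only serve to give the input its expected length.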
- CAN_HANDLE_ACTIONS = False
Whether the model can handle actions or not.
- CODE_REFERENCE = 'https://github.com/thuml/predrnn-pytorch'
The code location of the reference implementation.
- MATCHES_REFERENCE: str = 'Yes'
A comment indicating whether the implementation in this package matches the reference.
- NAME = 'PredRNN-V2'
The model’s name.
- NEEDS_COMPLETE_INPUT = True
Whether the input sequences also need to include the to-be-predicted frames.
- PAPER_REFERENCE = 'https://arxiv.org/abs/2103.09504'
The publication where this model was introduced first.
- __init__(device, **model_kwargs)
Initializes the model by first setting all model hyperparameters, attributes and the like. Then, the model-specific init creates the actual model from the given hyperparameters.
- Parameters
device (str) – The device identifier for the module.
**model_kwargs (Any) – Model arguments such as hyperparameters, input shapes etc.
- decoupling_loss_scale = 100.0
The scaling factor for the decoupling loss
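The decoupling loss in the PredRNN-V2 paper penalizes overlap between the increments of the two ST-cell memories; a minimal sketch of that idea (the tensor shapes and the `delta_c`/`delta_m` names are assumptions for illustration, not the package's internals):

```python
import torch
import torch.nn.functional as F

decoupling_loss_scale = 100.0

# Hypothetical per-cell memory increments: delta_c from the temporal memory,
# delta_m from the spatial memory, flattened over spatial positions.
delta_c = torch.randn(2, 64, 16 * 16)
delta_m = torch.randn(2, 64, 16 * 16)

# Mean absolute cosine similarity between the two updates, scaled by
# decoupling_loss_scale: the lower it is, the more decoupled the memories.
cos = F.cosine_similarity(delta_c, delta_m, dim=2)
decoupling_loss = decoupling_loss_scale * torch.mean(torch.abs(cos))
```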
- filter_size = 5
Kernel size for ST cell and action-conditional convs
- forward(x, pred_frames=1, **kwargs)
Given an input sequence of t frames, predicts pred_frames (p) frames into the future.
- Parameters
x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].
pred_frames (int) – The number of frames to predict into the future.
**kwargs (Any) – Optional input parameters such as actions.
Returns: A batch of sequences of p predicted frames as a tensor of shape [b, p, c, h, w].
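The closed-loop pattern behind such a forward pass — teacher forcing during the context frames, then feeding predictions back in — can be sketched with a stand-in single-step predictor (`dummy_step` and `predict` are hypothetical helpers; the real model steps its ST cells instead):

```python
import torch

def dummy_step(frame):
    # Stand-in for one recurrent prediction step (identity here).
    return frame

def predict(x, pred_frames=1):
    # x: [b, t, c, h, w]; the last `pred_frames` frames may be zero padding.
    context = x.shape[1] - pred_frames
    frame = x[:, 0]
    preds = []
    for t in range(x.shape[1] - 1):
        out = dummy_step(frame)
        if t < context - 1:
            frame = x[:, t + 1]      # teacher forcing during the context phase
        else:
            frame = out              # feed the prediction back in
            preds.append(out)
    return torch.stack(preds, dim=1)  # [b, pred_frames, c, h, w]

x = torch.randn(1, 12, 3, 16, 16)
y = predict(x, pred_frames=4)
```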
- inflated_action_dim = 3
Dimensionality of the ‘inflated actions’ (actions that have been transformed to tensors)
- num_hidden
Hidden layer dimensionality per ST cell layer
- num_layers = 3
Number of ST Cell layers
- patch_size = 4
During encoding, the image is sliced into patches of this size (height and width)
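This patch slicing corresponds to a space-to-depth reshape; a sketch under the assumption of non-overlapping patches (the helper name `reshape_patch` mirrors the reference repository, but this is an illustrative reimplementation):

```python
import torch

def reshape_patch(img_tensor, patch_size):
    # Space-to-depth: slice each frame into non-overlapping patch_size x
    # patch_size patches and stack the patch pixels into the channel dim.
    b, t, c, h, w = img_tensor.shape
    p = patch_size
    x = img_tensor.reshape(b, t, c, h // p, p, w // p, p)
    x = x.permute(0, 1, 2, 4, 6, 3, 5)           # gather patch pixels
    return x.reshape(b, t, c * p * p, h // p, w // p)

frames = torch.randn(2, 10, 3, 64, 64)
patched = reshape_patch(frames, 4)  # -> [2, 10, 48, 16, 16]
```

This trades spatial resolution for channels, so the recurrent cells operate on smaller feature maps without discarding any pixels.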
- pred_1(x, **kwargs)
Given an input sequence of t frames, predicts one single frame into the future.
- Parameters
x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].
**kwargs (Any) – Optional input parameters such as actions.
Returns: A single frame as a tensor of shape [b, c, h, w].
- r_sampling_step_1: int = 25000
At which iteration to proceed to second reverse scheduled sampling phase
- r_sampling_step_2: int = 50000
At which iteration to proceed to third reverse scheduled sampling phase
- residual_on_action_conv: bool = True
Whether to use residual connections for the direct action convolution
- sampling_changing_rate = 2e-05
Per-iteration changing rate for scheduled sampling
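One plausible shape for the resulting three-phase reverse scheduled sampling probability — constant, then exponentially saturating, then fully saturated — can be sketched as below; the `exp_alpha` decay constant and the exact phase formulas are assumptions for illustration, not documented attributes:

```python
import math

r_sampling_step_1 = 25000
r_sampling_step_2 = 50000

def reverse_sampling_eta(itr, exp_alpha=2500.0):
    # Phase 1: fixed probability of feeding the true frame.
    if itr < r_sampling_step_1:
        return 0.5
    # Phase 2: probability rises exponentially towards 1.
    if itr < r_sampling_step_2:
        return 1.0 - 0.5 * math.exp(-float(itr - r_sampling_step_1) / exp_alpha)
    # Phase 3: saturated.
    return 1.0
```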
- stride = 1
Stride for ST cell
- train_iter(config, loader, optimizer, loss_provider, epoch)
PredRNN-V2’s training iteration utilizes reversed input and keeps track of the number of training iterations done so far in order to adjust the sampling schedule. Otherwise, the iteration logic is the same as in the default train_iter() function.
- Parameters
config (dict) – The configuration dict of the current training run (combines model, dataset and run config)
loader (DataLoader) – Training data is sampled from this loader.
optimizer (Optimizer) – The optimizer to use for weight update calculations.
loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.
epoch (int) – The current epoch.
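The described pattern — a persistent iteration counter driving the sampling schedule, plus training on the time-reversed copy of each batch — can be sketched as follows (the function and its arguments are hypothetical stand-ins, not the package's API):

```python
import torch

def train_iter_sketch(model_step, loader, optimizer, itr_offset=0):
    # `model_step(x, itr)` stands in for one forward pass returning a loss;
    # `itr` lets the model adjust its (reverse) scheduled sampling.
    itr = itr_offset
    for batch in loader:
        # Train on the batch and on its time-reversed copy ("reversed input").
        for x in (batch, torch.flip(batch, dims=[1])):
            optimizer.zero_grad()
            loss = model_step(x, itr)
            loss.backward()
            optimizer.step()
        itr += 1
    return itr  # persisted across epochs so the schedule keeps advancing
```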