vp_suite.models.st_phy

class STPhy(device, **model_kwargs)

Bases: vp_suite.base.base_model.VPModel

This class implements a hybrid model that aims to combine the advantages of PhyDNet (Le Guen and Thome, https://arxiv.org/abs/2003.01460, https://github.com/vincent-leguen/PhyDNet) and PredRNN++ (Wang et al., https://arxiv.org/abs/2103.09504, https://github.com/thuml/predrnn-pytorch). More specifically, it replaces PhyDNet’s regular ConvLSTM cells with the ST cells from PredRNN++ and integrates PhyDNet’s teacher forcing and PredRNN++’s scheduled sampling techniques into training. (TODO: adjust the model to conform to this description.)

CAN_HANDLE_ACTIONS = True

Whether the model can handle actions or not.

NAME = 'ST-Phy'

The model’s name.

__init__(device, **model_kwargs)

Initializes the model in two steps. First, all model hyperparameters and attributes are set. Then, the model-specific initialization creates the actual model from those hyperparameters.

Parameters
  • device (str) – The device identifier for the module.

  • **model_kwargs (Any) – Model arguments such as hyperparameters, input shapes etc.
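The two-step initialization can be illustrated with a small sketch. It uses a hypothetical stand-in class (`ToyModel`, not part of vp_suite) whose class-level defaults are overridden by matching keyword arguments before a model-specific build step runs. The attribute names mirror those documented on this page; the override mechanism and the `_build` helper are assumptions made for illustration, not a confirmed description of VPModel.

```python
class ToyModel:
    """Hypothetical stand-in; NOT the vp_suite implementation."""

    # Class-level defaults, mirroring attributes documented on this page.
    num_layers = 3
    st_cell_channels = 64
    teacher_forcing_decay = 0.003

    def __init__(self, device, **model_kwargs):
        self.device = device
        # Step 1 (assumed): keyword arguments override the class-level defaults.
        for name, value in model_kwargs.items():
            setattr(self, name, value)
        # Step 2: model-specific construction from the final hyperparameters.
        self.layers = self._build()

    def _build(self):
        # Placeholder for real layer construction: one entry per layer.
        return [f"layer_{i}({self.st_cell_channels})" for i in range(self.num_layers)]


model = ToyModel("cpu", num_layers=2)
print(model.num_layers, model.st_cell_channels, len(model.layers))
```

Only the instance is modified here; the class-level default (`ToyModel.num_layers`) stays unchanged for models created later.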

decoupling_loss_scale = 100.0

The scaling factor for the decoupling loss

forward(x, pred_frames=1, **kwargs)

Given an input sequence of t frames, predicts pred_frames (p) frames into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • pred_frames (int) – The number of frames to predict into the future.

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A batch of sequences of p predicted frames as a tensor of shape [b, p, c, h, w].

inflated_action_dim = 3

Dimensionality of the ‘inflated actions’ (actions that have been transformed to tensors)

moment_loss_scale = 1.0

Scaling factor for the moment loss (for PDE-constrained prediction by the PhyCells)

num_layers = 3

Number of layers (1 PhyCell and 1 ST cell per layer)

phycell_channels = 49

Channel dimensionality for the PhyCells

phycell_kernel_size = (7, 7)

PhyCell kernel size

pred_1(x, **kwargs)

Given an input sequence of t frames, predicts one single frame into the future.

Parameters
  • x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].

  • **kwargs (Any) – Optional input parameters such as actions.

Returns: A single frame as a tensor of shape [b, c, h, w].
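To make the relation between single-frame and multi-frame prediction concrete, here is a minimal sketch of a generic autoregressive rollout: a one-step predictor is called repeatedly, and each prediction is appended to the input for the next step. This is the general pattern only; whether STPhy's forward() works this way internally is not stated here. `toy_pred_1` and `rollout` are hypothetical names, and frames are represented by string labels instead of tensors to keep the example dependency-free.

```python
def toy_pred_1(batch):
    """Hypothetical one-step predictor: repeats each sequence's last frame."""
    return [seq[-1] for seq in batch]


def rollout(batch, pred_frames=1):
    """Predict pred_frames frames by feeding predictions back as input."""
    # Copy the sequences so the caller's input is left untouched.
    context = [list(seq) for seq in batch]
    predictions = [[] for _ in batch]
    for _ in range(pred_frames):
        next_frames = toy_pred_1(context)  # one frame per batch element
        for seq, preds, frame in zip(context, predictions, next_frames):
            seq.append(frame)    # the prediction becomes part of the next input
            preds.append(frame)
    return predictions           # b sequences of p frames each


batch = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]  # b=2, t=3, frames as labels
out = rollout(batch, pred_frames=4)
print(len(out), len(out[0]))
```

The outer list plays the role of the batch dimension b and the inner list that of the predicted-frame dimension p, matching the [b, p, c, h, w] layout documented for forward().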

st_cell_channels = 64

Hidden layer dimensionality for the ST cell layers

teacher_forcing_decay = 0.003

Per-episode decrease of the teacher forcing ratio (starts out at 1.0)

train_iter(config, data_loader, optimizer, loss_provider, epoch)

ST-Phy’s training iteration utilizes a scheduled teacher forcing ratio. Otherwise, the iteration logic is the same as in the default train_iter() function.

Parameters
  • config (dict) – The configuration dict of the current training run (combines model, dataset and run config)

  • data_loader (DataLoader) – Training data is sampled from this loader.

  • optimizer (Optimizer) – The optimizer to use for weight update calculations.

  • loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.

  • epoch (int) – The current epoch.
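A scheduled teacher forcing ratio can be sketched as follows. The sketch assumes a linear decay clamped at zero, driven by the `epoch` argument and the documented defaults (start 1.0, decay 0.003); the exact schedule used by train_iter() is not specified on this page. The helper names `teacher_forcing_ratio` and `use_ground_truth` are hypothetical.

```python
import random


def teacher_forcing_ratio(epoch, decay=0.003, start=1.0):
    """Linearly decayed ratio, clamped at zero (assumed schedule)."""
    return max(0.0, start - epoch * decay)


def use_ground_truth(epoch, rng, decay=0.003):
    """Decide for one step whether to feed the ground-truth frame."""
    return rng.random() < teacher_forcing_ratio(epoch, decay)


rng = random.Random(0)  # seeded so that runs are repeatable
for epoch in (0, 100, 400):
    ratio = teacher_forcing_ratio(epoch)
    forced = sum(use_ground_truth(epoch, rng) for _ in range(1000))
    print(epoch, round(ratio, 3), forced)
```

Under these assumptions the model always sees ground-truth frames at the start of training and relies entirely on its own predictions once the ratio reaches zero, which happens after roughly 334 epochs with the default decay.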

training: bool