vp_suite.base.base_model
- class VPModel(device, **model_kwargs)
Bases:
torch.nn.modules.module.Module
The base class for all video prediction models. Each model ought to provide two forward pass/prediction methods (the default self.forward() method and pred_1(), which predicts a single frame) as well as two utility methods (train_iter() for a single training epoch on a given dataset loader and, analogously, eval_iter() for a single epoch of validation iteration). A minimal example subclass is sketched after the class constants below.
- CAN_HANDLE_ACTIONS = False
Whether the model can handle actions or not.
- CODE_REFERENCE = None
The code location of the reference implementation.
- MATCHES_REFERENCE: str = None
A comment indicating whether the implementation in this package matches the reference.
- MIN_CONTEXT_FRAMES = 1
Minimum number of context frames required for the model to work. By default, models will be able to deal with any number of context frames.
- NAME = None
The model’s name.
- NEEDS_COMPLETE_INPUT = False
Whether the input sequences also need to include the to-be-predicted frames.
- NON_CONFIG_VARS = ['functions', 'model_dir', 'dump_patches', 'training']
Variables that do not get included in the dict returned by self.config() (constants are not included either).
- PAPER_REFERENCE = None
The publication where this model was introduced first.
- REQUIRED_ARGS = ['img_shape', 'action_size', 'tensor_value_range']
The attributes that the model creator needs to supply when creating the model.
- TRAINABLE = True
Whether the model is trainable or not.
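For orientation, here is a minimal sketch of what a concrete subclass might look like; the class name, the trivial copy-last-frame prediction and the choice of overridden constants are illustrative assumptions, not part of the package:

    import torch
    from vp_suite.base.base_model import VPModel


    class CopyLastFrameModel(VPModel):
        # hypothetical example model, not shipped with the package
        NAME = "CopyLastFrame (example)"
        CAN_HANDLE_ACTIONS = False
        TRAINABLE = False
        MIN_CONTEXT_FRAMES = 1

        def pred_1(self, x, **kwargs):
            # predict the next frame by repeating the last context frame
            return x[:, -1]

        def forward(self, x, pred_frames=1, **kwargs):
            preds = []
            for _ in range(pred_frames):
                next_frame = self.pred_1(x, **kwargs)               # [b, c, h, w]
                preds.append(next_frame)
                x = torch.cat([x, next_frame.unsqueeze(1)], dim=1)  # grow context
            return torch.stack(preds, dim=1)                        # [b, p, c, h, w]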
- __init__(device, **model_kwargs)
Initializes the model by first setting all model hyperparameters, attributes and the like. Then, the model-specific init creates the actual model from the given hyperparameters.
- Parameters
device (str) – The device identifier for the module.
**model_kwargs (Any) – Model arguments such as hyperparameters, input shapes etc.
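Continuing the sketch above, construction could then look as follows; the exact formats of the required arguments (e.g. the channel-first img_shape ordering) are assumptions:

    model = CopyLastFrameModel(
        device="cpu",                   # or e.g. "cuda:0"
        img_shape=(3, 64, 64),          # assumed (c, h, w) ordering
        action_size=0,                  # no action-conditioning
        tensor_value_range=(0.0, 1.0),  # expected input value range
    )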
- action_conditional = False
True if this model is leveraging input actions for the predictions, False otherwise.
- action_size = None
The expected dimensionality of the action inputs.
- property config
Returns: A dictionary containing the complete model configuration, including common attributes as well as model-specific attributes.
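One possible use of this property is to persist the configuration alongside a checkpoint, e.g. as sketched below (file name arbitrary; default=str guards against entries that are not JSON-serializable):

    import json

    with open("model_config.json", "w") as f:
        json.dump(model.config, f, indent=2, default=str)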
- eval_iter(config, loader, loss_provider)
Default validation iteration: Loops through the whole data loader once and, for every batch, executes a forward pass and loss calculation. Then, aggregates all loss values to assess the prediction quality.
- Parameters
config (dict) – The configuration dict of the current validation run (combines model, dataset and run config)
loader (DataLoader) – Validation data is sampled from this loader.
loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.
Returns: A dictionary containing the average value for each loss type specified for usage, as well as the value for the ‘indicator’ loss (the loss used for determining overall model improvement).
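A sketch of a single validation epoch; run_config, val_loader and loss_provider are placeholders assumed to have been set up elsewhere, and switching to eval mode/no_grad is done defensively here in case eval_iter does not do so itself:

    import torch

    model.eval()
    with torch.no_grad():
        val_losses = model.eval_iter(run_config, val_loader, loss_provider)
    print(val_losses)  # per-loss averages plus the 'indicator' loss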
- forward(x, pred_frames=1, **kwargs)
Given an input sequence of t frames, predicts pred_frames (p) frames into the future.
- Parameters
x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].
pred_frames (int) – The number of frames to predict into the future.
**kwargs (Any) – Optional input parameters such as actions.
Returns: A batch of sequences of p predicted frames as a tensor of shape [b, p, c, h, w].
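Calling convention, sketched with random input; batch size, frame count and resolution are arbitrary, and the model is the example instance from above:

    import torch

    b, t, c, h, w = 4, 8, 3, 64, 64
    context = torch.rand(b, t, c, h, w)    # values within tensor_value_range
    preds = model(context, pred_frames=5)  # -> tensor of shape [4, 5, 3, 64, 64]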
- img_shape = None
The expected shape of a single input frame.
- model_dir = None
The save location of the model.
- pred_1(x, **kwargs)
Given an input sequence of t frames, predicts one single frame into the future.
- Parameters
x (torch.Tensor) – A batch of b sequences of t input frames as a tensor of shape [b, t, c, h, w].
**kwargs (Any) – Optional input parameters such as actions.
Returns: A single frame as a tensor of shape [b, c, h, w].
- tensor_value_range = None
The expected value range of the input tensors.
- train_iter(config, loader, optimizer, loss_provider, epoch)
Default training iteration: Loops through the whole data loader once and, for every batch, executes forward pass, loss calculation and backward pass/optimization step.
- Parameters
config (dict) – The configuration dict of the current training run (combines model, dataset and run config)
loader (DataLoader) – Training data is sampled from this loader.
optimizer (Optimizer) – The optimizer to use for weight update calculations.
loss_provider (PredictionLossProvider) – An instance of the LossProvider class for flexible loss calculation.
epoch (int) – The current epoch.
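A bare-bones outer training loop around this method might look like the following; it assumes a trainable model instance (the copy-last-frame example above has no parameters), and run_config, train_loader and loss_provider are placeholders as before. Optimizer choice and epoch count are arbitrary:

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(100):
        model.train()
        model.train_iter(run_config, train_loader, optimizer, loss_provider, epoch)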
- unpack_data(data, config, reverse=False, complete=False)
Extracts inputs and targets from a data blob.
- Parameters
data (VPData) – The given VPData data blob/dictionary containing frames and actions.
config (dict) – The current run configuration specifying how to extract the data from the given data blob.
reverse (bool) – If specified, reverses the input sequence first.
complete (bool) – If specified, input_frames will also contain the to-be-predicted frames (just like with NEEDS_COMPLETE_INPUT).
Returns: The specified number of input/target frames as well as the actions. All inputs come in the shape the model expects as input later.
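A sketch of how a custom loop might use this method before the forward pass; the return order (input frames, target frames, actions) is read from the description above, and passing the actions on via a keyword argument is an assumption:

    for data in train_loader:
        input_frames, target_frames, actions = model.unpack_data(data, run_config)
        predictions = model(input_frames,
                            pred_frames=target_frames.shape[1],
                            actions=actions)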