Introduction
Video prediction (‘VP’) is the task of predicting future frames given some context frames.
As in most computer vision sub-domains, scientific contributions in this field vary widely in the following aspects:
Training protocol (dataset usage, when to backprop, value ranges etc.)
Technical details of model implementation (deep learning framework, package dependencies etc.)
Benchmark selection and execution (this includes the choice of dataset, number of context/predicted frames, skipping frames in the observed sequences etc.)
Evaluation protocol (metrics chosen, variations in implementation/reduction modes, different ways of creating visualizations etc.)
Furthermore, while many contributors now share their code, seemingly minor missing parts such as dataloaders make it much harder to assess, compare and improve existing models.
This repo aims at providing a suite that facilitates scientific work in the subfield, providing standardized yet customizable solutions for the aspects mentioned above. This way, validating existing VP models and creating new ones hopefully becomes much less tedious.
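To make the point about differing reduction modes concrete, here is a minimal sketch in plain Python. It is not part of vp-suite, and all function names are illustrative assumptions. It computes a squared-error score for the same toy prediction in two ways; both could plausibly be reported as "MSE", yet they differ by a factor equal to the number of pixels per frame.

```python
# Illustrative only: these helpers are hypothetical and not taken from vp-suite.

def squared_errors(pred, target):
    """Return one flat list of per-pixel squared errors per frame.

    Both inputs are nested lists of shape [frames][height][width].
    """
    return [
        [
            (p - t) ** 2
            for p_row, t_row in zip(p_frame, t_frame)
            for p, t in zip(p_row, t_row)
        ]
        for p_frame, t_frame in zip(pred, target)
    ]


def mse_mean_over_all(pred, target):
    """Average the squared error over every pixel of every frame."""
    errors = [e for frame in squared_errors(pred, target) for e in frame]
    return sum(errors) / len(errors)


def mse_sum_pixels_mean_frames(pred, target):
    """Sum the squared error within each frame, then average over frames."""
    per_frame = [sum(frame) for frame in squared_errors(pred, target)]
    return sum(per_frame) / len(per_frame)


# Toy sequence: 2 frames of 2x2 pixels, every prediction off by exactly 1.
target = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(2)]
pred = [[[1.0, 1.0], [1.0, 1.0]] for _ in range(2)]

print(mse_mean_over_all(pred, target))           # 1.0
print(mse_sum_pixels_mean_frames(pred, target))  # 4.0
```

Numbers from two papers are only comparable if both use the same reduction, which is one reason a shared evaluation protocol helps.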
Installation
Requires ``pip`` and ``python >= 3.6`` (code is tested with version ``3.8``).
From PyPI
```
pip install vp-suite
```
From source
```
pip install git+https://github.com/Flunzmas/vp-suite.git
```
If you want to contribute
```
git clone https://github.com/Flunzmas/vp-suite.git
cd vp-suite
pip install -e .[dev]
```
If you want to build docs
```
git clone https://github.com/Flunzmas/vp-suite.git
cd vp-suite
pip install -e .[doc]
```
Usage
Changing save location
When using this package for the first time, the save location for datasets, models and logs is set to `
Training models
```python
from vp_suite import VPSuite

# 1. Set up the VP Suite.
suite = VPSuite()

# 2. Load one of the provided datasets.
#    They will be downloaded automatically if no downloaded data is found.
suite.load_dataset("MM")  # load moving MNIST dataset from default location

# 3. Create a video prediction model.
suite.create_model('convlstm-shi')  # create a ConvLSTM-Based Prediction Model.

# 4. Run the training loop, optionally providing custom configuration.
suite.train(lr=2e-4, epochs=100)
```
This code snippet will train the model, log training progress to your [Weights & Biases](https://wandb.ai) account, save model checkpoints on improvement and generate and save prediction visualizations.
Evaluating models
```python
from vp_suite import VPSuite

# 1. Set up the VP Suite.
suite = VPSuite()

# 2. Load one of the provided datasets in test mode.
#    They will be downloaded automatically if no downloaded data is found.
suite.load_dataset("MM", split="test")  # load moving MNIST dataset from default location

# 3. Get the filepaths to the models you'd like to test and load the models
model_dirs = ["out/model_foo/", "out/model_bar/"]
for model_dir in model_dirs:
    suite.load_model(model_dir, ckpt_name="best_model.pth")

# 4. Test the loaded models on the loaded test sets.
suite.test(context_frames=5, pred_frames=10)
```
This code will evaluate the loaded models on the loaded dataset (its test portion, if available), creating detailed summaries of prediction performance across a customizable set of metrics. The results as well as prediction visualizations are saved and logged to [Weights & Biases](https://wandb.ai).
_Note 1: If the specified evaluation protocol or the loaded dataset is incompatible with one of the models, this will raise an error with an explanation._
_Note 2: By default, a [CopyLastFrame](https://github.com/AIS-Bonn/vp-suite/blob/main/vp_suite/models/model_copy_last_frame.py) baseline is also loaded and tested with the other models._
Hyperparameter Optimization
This package uses [optuna](https://github.com/optuna/optuna) to provide hyperparameter optimization functionality. The following snippet provides a full example:
```python
import json
from vp_suite import VPSuite
from vp_suite.defaults import SETTINGS

suite = VPSuite()
suite.load_dataset(dataset="KTH")  # select dataset of choice
suite.create_model(model_id="lstm")  # select model of choice

with open(str((SETTINGS.PKG_RESOURCES / "optuna_example_config.json").resolve()), 'r') as cfg_file:
    optuna_cfg = json.load(cfg_file)
# optuna_cfg specifies the parameters' search intervals and scales; modify as you wish.

suite.hyperopt(optuna_cfg, n_trials=30, epochs=10)
```
This example runs 30 training loops (called _trials_ by optuna), producing a trained model for each hyperparameter configuration and writing the hyperparameter configuration of the best performing run to the console.
_Note 1: For hyperopt, visualization, logging and model checkpointing are minimized to reduce IO strain._
_Note 2: Despite optuna's trial pruning capabilities, running a high number of trials might still take a lot of time. In that case, consider e.g. reducing the number of training epochs._
Use `no_wandb=True` if you want to log outputs to the console instead, and `no_vis=True` if you do not want to generate and save visualizations.
Notes:
Use `VPSuite.list_available_models()` and `VPSuite.list_available_datasets()` to get an overview of which models and datasets are currently covered by the framework.
All training, testing and hyperparameter optimization calls can be heavily configured (adjusting training hyperparameters, logging behavior etc.). For a comprehensive list of all adjustable run configuration parameters, see the documentation of the `vp_suite.defaults` package.
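The `context_frames`/`pred_frames` arguments and the CopyLastFrame baseline from the evaluation example can be pictured with the following sketch. It is plain Python and independent of the package: `split_sequence` and `copy_last_frame` are hypothetical helper names, and the actual vp-suite implementation may differ.

```python
import copy

# Illustrative only: hypothetical helpers, not the vp-suite implementation.

def split_sequence(frames, context_frames, pred_frames):
    """Split a frame sequence into observed context and frames to predict."""
    needed = context_frames + pred_frames
    if len(frames) < needed:
        raise ValueError(f"need {needed} frames, got {len(frames)}")
    return frames[:context_frames], frames[context_frames:needed]


def copy_last_frame(context, pred_frames):
    """Naive baseline: repeat the last observed frame for every future step."""
    return [copy.deepcopy(context[-1]) for _ in range(pred_frames)]


# Toy sequence of 15 "frames"; each frame is a 1x1 image holding its own index.
sequence = [[[float(i)]] for i in range(15)]

context, target = split_sequence(sequence, context_frames=5, pred_frames=10)
prediction = copy_last_frame(context, pred_frames=10)

print(len(context), len(target), len(prediction))  # 5 10 10
print(prediction[0], prediction[-1])               # [[4.0]] [[4.0]]
```

A baseline like this gives a reference point: a learned model that cannot outperform it on sequences with noticeable motion is probably not exploiting the temporal dynamics.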
Customization
This package is designed with quick extensibility in mind. See the sections below for how to add new components (models, model blocks, datasets or measures).
New Models
1. Create a file `
New Model Blocks
1. Create a file `
New Datasets
1. Create a file `
New Measures (losses and/or metrics)
1. Create a new file `
Notes:
If you omit the docstring for a particular attribute/method/field, the docstring of the base class is used for documentation.
If implementing components that originate from publications/public repositories, please override the corresponding constants to specify the source! Additionally, if you want to write automated tests checking implementation equality, have a look at how `tests/test_impl_match.py` fetches the tests of `tests/test_impl_match/` and executes these tests.
Basic unit tests for models, datasets and measures are executed on all registered models - you don’t need to write such basic tests for your custom components! The same applies to documentation: the tables that list available components are filled automatically.
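The docstring fallback mentioned in the notes above mirrors standard Python behaviour: `inspect.getdoc` returns the inherited docstring when an overriding method has none. The sketch below demonstrates this with hypothetical class names (`ExampleBaseModel`, `MyModel`) that are not part of vp-suite. How the package's documentation build resolves docstrings internally is an assumption here.

```python
import inspect

# Illustrative only: these classes are hypothetical, not vp-suite base classes.

class ExampleBaseModel:
    """Base class for example models."""

    def forward(self, frames):
        """Predict future frames from the given context frames."""
        raise NotImplementedError


class MyModel(ExampleBaseModel):
    def forward(self, frames):  # no docstring here on purpose
        return frames


model = MyModel()

# The override itself carries no docstring ...
print(model.forward.__doc__)          # None
# ... but inspect.getdoc falls back to the documentation of the base class.
print(inspect.getdoc(model.forward))  # Predict future frames from the given context frames.
```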
Contributing
This project is always open to extension! It grows especially powerful with more models and datasets, so if you’ve made your code work on custom models/datasets/metrics/etc., feel free to submit a merge request!
Other kinds of contributions are also very welcome - just check the open issues on the tracker or open up a new issue there.
Unit Testing
When submitting a merge request, please make sure all tests pass (execute from the root folder):
python -m pytest --runslow --cov=vp_suite -rs
Note: this is the easiest way to run all tests `without import hassles <https://docs.pytest.org/en/latest/explanation/pythonpath.html#invoking-pytest-versus-python-m-pytest>`_. You will need to have ``vp-suite`` installed in development mode, though (`see here <#installation>`_).
API Documentation
The official API documentation is updated automatically upon push to the main branch. If you want to build the documentation locally, make sure you’ve installed the package accordingly and execute the following:
cd docs/
bash assemble_docs.sh
Citing
Please consider citing our work if you find our findings or our repository helpful.
@article{karapetyan_VideoPrediction_2022,
title = {Video Prediction at Multiple Scales with Hierarchical Recurrent Networks},
author = {Karapetyan, Ani and Villar-Corrales, Angel and Boltres, Andreas and Behnke, Sven},
journal={arXiv preprint arXiv:2203.09303},
year={2022}
}
Acknowledgements
Project structure is inspired by segmentation_models.pytorch.
Sphinx-autodoc templates are inspired by the QNET repository.
All other sources are acknowledged in the documentation of the respective point of usage (to the best of our knowledge).
License
This project comes with an MIT License, except for the following components:
Module `vp_suite.measure.fvd.pytorch_i3d` (Apache 2.0 License, taken and modified from here)
Disclaimer
I do not host or distribute any dataset. For all provided dataset functionality, I trust you have the permission to download and use the respective data.