Using torch.compile for speed-ups

Here are a few pointers from my experience using torch.compile to speed up training.

I do not use any special packages for this. Here is an extract of my pyproject.toml:

[tool.pixi.project]
channels = ["nvidia/label/cuda-11.8.0", "nvidia", "pytorch", "conda-forge"]
platforms = ["linux-64"]

[tool.pixi.dependencies]
python = ">=3.11"
cuda = {version = "*", channel="nvidia/label/cuda-11.8.0"}
pytorch = {channel = "pytorch", version = "*"}
torchvision = {channel = "pytorch", version = ">=0.19.1"}
pytorch-cuda = {version = "11.8.*", channel="pytorch"}

First, the compilation process produces a lot of files, which may fill up the default home directory on HPC systems.

You can use the TORCHINDUCTOR_CACHE_DIR environment variable to write those files to a different location:

export TORCHINDUCTOR_CACHE_DIR=$TMPDIR/$SLURM_JOBID/
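If you prefer to set this from the training script rather than the job script, a minimal sketch could look like the following. The helper name inductor_cache_dir and the fallback values are assumptions made for illustration. Setting the variable before importing torch is a precaution on my side, not a documented requirement.

```python
import os
import tempfile
from pathlib import Path


def inductor_cache_dir(env=os.environ):
    """Build the cache path from TMPDIR and SLURM_JOBID (hypothetical helper)."""
    # Fall back to the system temp dir and a fixed name outside of a SLURM job.
    tmp = env.get("TMPDIR") or tempfile.gettempdir()
    job = env.get("SLURM_JOBID") or "local"
    return Path(tmp) / job


cache_dir = inductor_cache_dir()
cache_dir.mkdir(parents=True, exist_ok=True)

# Same variable as in the export line above; set it before `import torch`
# to be on the safe side.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(cache_dir)
```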

You can call torch.compile to compile parts of your model, loss functions, etc. Here is an example where I compile the LPIPS loss:

https://github.com/Evoland-Land-Monitoring-Evolution/tamrfsits/blob/32febf538a101e57327dd3faa1b086bea3552875/src/tamrfsits/core/cca.py#L187

Here I compile a generic torch.nn.Module:

https://github.com/Evoland-Land-Monitoring-Evolution/tamrfsits/blob/32febf538a101e57327dd3faa1b086bea3552875/src/tamrfsits/components/datewise.py#L36

Note that if you compile parts of a model that is being optimized, the keys of the checkpoint dictionary change: you cannot directly load a checkpoint saved from the compiled model into the same model without compilation.
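The key change comes from the wrapper that torch.compile puts around a module: parameters of the wrapped module appear under an extra "_orig_mod." segment in the state dict. A possible workaround, sketched here on a plain dictionary with made-up keys and values, is to strip that segment before loading. The function name strip_compile_prefix is hypothetical.

```python
def strip_compile_prefix(state_dict):
    """Return a copy of state_dict without the '_orig_mod.' key segments."""
    return {key.replace("_orig_mod.", ""): value for key, value in state_dict.items()}


# Illustrative keys only: "encoder" stands for a compiled sub-module,
# "head" for a sub-module that was left uncompiled.
compiled_state = {
    "encoder._orig_mod.weight": 1.0,
    "encoder._orig_mod.bias": 2.0,
    "head.weight": 3.0,
}
plain_state = strip_compile_prefix(compiled_state)

# With a real checkpoint, the cleaned dictionary would then be passed to
# load_state_dict() of the model built without torch.compile.
```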

Finally, you can compile specific functions. For instance, here I compile the preprocessing function:

https://github.com/Evoland-Land-Monitoring-Evolution/tamrfsits/blob/32febf538a101e57327dd3faa1b086bea3552875/src/tamrfsits/tasks/base.py#L625

When using torch.compile, the first step can be rather slow, since this is where all the compilation happens. Wait for a few steps before measuring speed-ups.
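A simple way to exclude the compilation cost from a measurement is to run a few untimed warm-up steps first. The sketch below is generic Python, and the names mean_step_time, step_fn and sync are assumptions for illustration. On a GPU you would pass torch.cuda.synchronize as sync so that pending kernels are finished before the clock is read.

```python
import time


def mean_step_time(step_fn, warmup=3, steps=10, sync=None):
    """Average wall-clock seconds per call of step_fn, ignoring warm-up calls."""
    for _ in range(warmup):
        step_fn()  # compilation happens here and is not timed
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(steps):
        step_fn()
    if sync is not None:
        sync()  # wait for asynchronous work before stopping the clock
    return (time.perf_counter() - start) / steps


# Stand-in for a real training step, used only to show the call pattern.
calls = []


def fake_step():
    calls.append(len(calls))


avg = mean_step_time(fake_step, warmup=2, steps=5)
```

Running the same measurement on the compiled and the uncompiled step gives a comparison that is not dominated by the first, slow iterations.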

I mainly used this documentation: https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html

In my model, I did not compile the transformer modules, as this was not working. Also, compiling makes things static, so any time the input data changes shape, a new compilation occurs (until torch.compile gives up and falls back to eager execution). If we were to compile parts of the perceiver, we should focus on what is relatively stable in terms of shapes.
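One common way to keep shapes stable, which I have not used in the code linked above, is to pad variable-length inputs up to a small set of fixed sizes so that only a few distinct shapes ever reach the compiled code. The sketch below only computes the padded length. The bucket sizes and the name bucket_length are arbitrary choices for illustration.

```python
import bisect


def bucket_length(length, buckets=(32, 64, 128, 256)):
    """Return the smallest bucket size that can hold `length` items."""
    idx = bisect.bisect_left(buckets, length)
    if idx == len(buckets):
        raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")
    return buckets[idx]


# Six different sequence lengths collapse to three padded lengths,
# so at most three shape variants would need to be compiled.
lengths = [17, 40, 33, 64, 120, 5]
padded = [bucket_length(n) for n in lengths]
distinct_shapes = sorted(set(padded))
```

The trade-off is some wasted computation on padding, and the model has to mask the padded positions correctly.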