BUG: LightningModule does not handle NaNs
The `y_target` tensors that I use during training sometimes contain NaN values (in the measures and angles tokens). As a quick check, I regenerated a `cached_dataset` and still see roughly 10% of samples containing at least one NaN value in their input and target tensors.
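For reference, here is a minimal sketch of the check I ran, assuming the cached dataset yields `(input, y_target)` tensor pairs (the names and layout are illustrative, not the project's actual API):

```python
import torch

def count_nan_samples(dataset):
    """Count samples that contain at least one NaN in input or target."""
    n_bad = 0
    for x, y_target in dataset:
        if torch.isnan(x).any() or torch.isnan(y_target).any():
            n_bad += 1
    return n_bad

# e.g.:
# print(f"{count_nan_samples(cached_dataset)} / {len(cached_dataset)} samples contain NaNs")
```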
AFAICT, NaNs are not discarded by the LightningModule. In my trainings, this causes the loss and the model weights to become NaN after a few training steps (when training on the full dataset).
As a first fix, we could drop the faulty samples by returning `None` in `training_step` and `validation_step` (this case is handled by Lightning). The check could be done in `prepare_data_from_sample`. We could also refine this strategy and only drop the faulty tokens (instead of the whole sample), but that is probably not worth it for now. In any case, we will need to keep track of the number of skipped samples/tokens. A minimal sketch of this approach is shown below.
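This sketch checks at the batch level (the check could instead live in `prepare_data_from_sample` to filter individual samples). The module, layer sizes, and `compute_loss` are placeholders, not the project's real code; returning `None` from `training_step` is the documented Lightning way to skip a batch under automatic optimization. Adjust the import if the project uses `lightning.pytorch` instead of the standalone `pytorch_lightning` package.

```python
import torch
import pytorch_lightning as pl

class NaNSafeModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(16, 16)
        self.nan_batches_skipped = 0

    def compute_loss(self, x, y_target):
        return torch.nn.functional.mse_loss(self.model(x), y_target)

    def training_step(self, batch, batch_idx):
        x, y_target = batch
        if torch.isnan(x).any() or torch.isnan(y_target).any():
            self.nan_batches_skipped += 1
            return None  # Lightning skips this batch
        loss = self.compute_loss(x, y_target)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_target = batch
        if torch.isnan(x).any() or torch.isnan(y_target).any():
            return None
        self.log("val_loss", self.compute_loss(x, y_target))

    def on_train_epoch_end(self):
        # Report how many batches were dropped during the epoch.
        self.log("nan_batches_skipped", float(self.nan_batches_skipped))
        self.nan_batches_skipped = 0

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```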
Also, there already exists a variable named `is_valid`, but it is part of the model config (not the LightningModule) and applies a mask to the data. I think this is a separate concern.
I thus propose:
- Adding a failing test to make sure the loss does not become NaN (see the test sketch after this list)
- Discarding samples that contain NaNs
- Logging the number of discarded samples per epoch
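A possible shape for the failing test, reusing the hypothetical `NaNSafeModule` and batch layout from the sketch above (the real test would target the project's actual module):

```python
import torch

def test_training_step_handles_nan_batches():
    model = NaNSafeModule()
    x = torch.randn(4, 16)
    y_target = torch.randn(4, 16)
    y_target[0, 0] = float("nan")
    out = model.training_step((x, y_target), batch_idx=0)
    # The step should either skip the batch (None) or return a finite loss.
    assert out is None or torch.isfinite(out).all()
    assert model.nan_batches_skipped == 1
```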
Has anyone encountered the same problem? Is there already a solution to this issue, or should we fix it?