Axon.LossScale (Axon v0.7.0)
Implementations of loss-scalers for use in mixed precision training.
Loss scaling keeps small gradient values from underflowing to zero when a model is trained with mixed precision. Each loss-scale implementation here returns a 3-tuple of functions:
{init_fn, scale_fn, unscale_fn} = Axon.LossScale.static(init_scale: Nx.pow(2, 15))
You can use these to initialize the loss-scale state, scale the loss, unscale the gradients, and adjust the loss-scale state.
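To make the scale/unscale cycle concrete, here is a minimal pure-Elixir sketch that does not depend on Axon or Nx. It assumes plain floats stand in for tensors and a list of floats stands in for the gradients; the `LossScaleSketch` module and its `static/1` function are hypothetical and only loosely mirror the init/scale/unscale shape described above.

```elixir
defmodule LossScaleSketch do
  # Hypothetical static loss-scaler working on plain floats (not Axon's API).
  # Returns {init_fn, scale_fn, unscale_fn}.
  def static(init_scale \\ 32768.0) do
    # State holds the fixed scale factor.
    init_fn = fn -> %{loss_scale: init_scale} end

    # Multiply a value (the loss, or anything derived from it) by the scale.
    scale_fn = fn value, %{loss_scale: s} -> value * s end

    # Divide each gradient by the scale; a static scaler never changes
    # its state, so the state is returned as-is.
    unscale_fn = fn grads, %{loss_scale: s} = state ->
      {Enum.map(grads, &(&1 / s)), state}
    end

    {init_fn, scale_fn, unscale_fn}
  end
end

{init_fn, scale_fn, unscale_fn} = LossScaleSketch.static()
state = init_fn.()

# 0.5 * 2^15
scaled_loss = scale_fn.(0.5, state)

# Toy gradients of the unscaled loss. By the chain rule, gradients of the
# scaled loss are the same values multiplied by the scale.
true_grads = [1.0, -0.5]
scaled_grads = Enum.map(true_grads, &scale_fn.(&1, state))

{grads, _state} = unscale_fn.(scaled_grads, state)
IO.inspect({scaled_loss, grads})
# prints {16384.0, [1.0, -0.5]}
```

Because the gradient of `s * loss` is `s` times the gradient of `loss`, dividing by the same `s` recovers the original gradients; in a 16-bit float type the scaled values stay above the underflow threshold in the meantime.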
Axon.Loop.trainer/3 builds in loss scaling by default. You can reference the Axon.Loop.train_step/3 implementation to see how loss scaling is applied in practice.
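The order of operations inside a training step can be sketched on a toy one-parameter model. This is a hypothetical illustration, not the code of `Axon.Loop.train_step/3`: the `TrainStepSketch` module, the hand-derived gradient, and the plain SGD update are all assumptions. It uses 64-bit floats, so the scaling is numerically a no-op here; with 16-bit floats it is what keeps small gradients from flushing to zero.

```elixir
defmodule TrainStepSketch do
  # Hypothetical single training step on a toy model:
  # prediction = w * x, loss = (w * x - y)^2.
  # Not Axon.Loop.train_step/3; it only shows where scaling happens.
  def step(w, {x, y}, loss_scale, lr) do
    residual = w * x - y
    # 1. Forward pass and unscaled loss.
    loss = residual * residual
    # 2. Differentiate the *scaled* loss, s * loss, with respect to w.
    #    By hand: d(s * (w*x - y)^2)/dw = s * 2 * (w*x - y) * x
    scaled_grad = loss_scale * 2 * residual * x
    # 3. Unscale before the optimizer sees the gradient.
    grad = scaled_grad / loss_scale
    # 4. Plain SGD update with the recovered gradient.
    {w - lr * grad, loss}
  end
end

{w, loss} = TrainStepSketch.step(1.0, {2.0, 3.0}, 1024.0, 0.125)
IO.inspect({w, loss})
# prints {1.5, 1.0}
```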
Functions
dynamic
Implements dynamic loss-scale.
identity
Implements identity loss-scale.
static
Implements static loss-scale.
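Dynamic loss scaling is commonly implemented as a small state machine: shrink the scale when the gradients overflow, and grow it again after a run of steps with finite gradients. The sketch below assumes that scheme; the module name, the option names (`:init_scale`, `:period`, `:factor`, `:min_scale`), and their defaults are illustrative assumptions and not a description of Axon's internals. Because BEAM floats cannot represent infinity or NaN, the overflow check is passed in as a boolean.

```elixir
defmodule DynamicLossScaleSketch do
  # Hypothetical dynamic loss-scale state machine (the common
  # "back off on overflow, grow after a run of good steps" scheme).
  # Option names and defaults are illustrative assumptions.
  def init(opts \\ []) do
    %{
      loss_scale: Keyword.get(opts, :init_scale, 32768.0),
      good_steps: 0,
      period: Keyword.get(opts, :period, 2000),
      factor: Keyword.get(opts, :factor, 2.0),
      min_scale: Keyword.get(opts, :min_scale, 1.0)
    }
  end

  # Gradients were not finite: shrink the scale (never below min_scale)
  # and reset the counter. The caller should skip this optimizer update.
  def adjust(state, false) do
    new_scale = max(state.loss_scale / state.factor, state.min_scale)
    %{state | loss_scale: new_scale, good_steps: 0}
  end

  # Gradients were finite: count the step, and grow the scale once
  # `period` consecutive finite steps have been seen.
  def adjust(state, true) do
    good_steps = state.good_steps + 1

    if good_steps >= state.period do
      %{state | loss_scale: state.loss_scale * state.factor, good_steps: 0}
    else
      %{state | good_steps: good_steps}
    end
  end
end

state = DynamicLossScaleSketch.init(init_scale: 1024.0, period: 3)
# Three finite steps trigger growth: 1024.0 -> 2048.0
state = Enum.reduce([true, true, true], state, &DynamicLossScaleSketch.adjust(&2, &1))
IO.inspect(state.loss_scale)
# prints 2048.0
# An overflow halves it again: 2048.0 -> 1024.0
state = DynamicLossScaleSketch.adjust(state, false)
IO.inspect(state.loss_scale)
# prints 1024.0
```

Growing only after many clean steps keeps the scale from oscillating, while backing off immediately on overflow limits the number of skipped updates.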