Axon.MixedPrecision (Axon v0.7.0)
Utilities for creating mixed precision policies.
Mixed precision is useful for increasing model throughput at the possible
price of a small dip in accuracy. When creating a mixed precision policy,
you define the policy for `params`, `compute`, and `output`.
The `params` policy dictates the type in which parameters are stored
during training. The `compute` policy dictates the type used for
intermediate computations in the model's forward pass. The `output`
policy dictates the type of the model's output.
Here's an example of creating a mixed precision policy and applying it to a model:
```elixir
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(64, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

policy = Axon.MixedPrecision.create_policy(
  params: {:f, 32},
  compute: {:f, 16},
  output: {:f, 32}
)

mp_model =
  model
  |> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])
```
The example above applies the mixed precision policy to every layer in
the model except batch normalization layers. Parameters are stored as `{:f, 32}`,
and the policy casts parameters and inputs to `{:f, 16}` for intermediate
computations in the model's forward pass before casting the output back to `{:f, 32}`.
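To make the three stages concrete, the sketch below traces them by hand with `cast/3` for a single dense-style multiplication. It assumes Axon `~> 0.7.0` can be fetched from Hex via `Mix.install`, and the parameter map, layer name `"dense"`, and input values are hypothetical. It only illustrates the resulting types; it is not how the library applies a policy internally.

```elixir
# Assumes network access so Mix.install can fetch Axon (and Nx) from Hex.
Mix.install([{:axon, "~> 0.7.0"}])

# Same policy as in the model example above.
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:f, 16},
    output: {:f, 32}
  )

# Hypothetical parameter map for a single dense layer.
params = %{"dense" => %{"kernel" => Nx.iota({2, 2}, type: {:f, 32})}}

# Stored parameters stay in f32 under the :params policy.
stored = Axon.MixedPrecision.cast(policy, params, :params)

# For the forward pass, both the kernel and the input use the compute type.
kernel = Axon.MixedPrecision.cast(policy, stored["dense"]["kernel"], :compute)
input = Axon.MixedPrecision.cast(policy, Nx.tensor([[1.0, 2.0]]), :compute)
hidden = Nx.dot(input, kernel)

# The final result is cast back to the output type.
output = Axon.MixedPrecision.cast(policy, hidden, :output)

IO.inspect(Nx.type(stored["dense"]["kernel"]), label: "params")
IO.inspect(Nx.type(hidden), label: "compute")
IO.inspect(Nx.type(output), label: "output")
```

The intermediate `hidden` tensor is half precision, while both the stored parameters and the returned value stay in single precision.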
Summary
Functions
cast/3: Casts the given container according to the given policy and type.
create_policy/1: Creates a mixed precision policy with the given options.
Functions
cast/3
Casts the given tensor or container according to the given policy and
variable type (`:params`, `:compute`, or `:output`).
Examples
```elixir
iex> policy = Axon.MixedPrecision.create_policy(params: {:f, 16})
iex> params = %{"dense" => %{"kernel" => Nx.tensor([1.0, 2.0, 3.0])}}
iex> params = Axon.MixedPrecision.cast(policy, params, :params)
iex> Nx.type(params["dense"]["kernel"])
{:f, 16}

iex> policy = Axon.MixedPrecision.create_policy(compute: {:bf, 16})
iex> value = Nx.tensor([1.0, 2.0, 3.0])
iex> value = Axon.MixedPrecision.cast(policy, value, :compute)
iex> Nx.type(value)
{:bf, 16}

iex> policy = Axon.MixedPrecision.create_policy(output: {:bf, 16})
iex> value = Nx.tensor([1.0, 2.0, 3.0])
iex> value = Axon.MixedPrecision.cast(policy, value, :output)
iex> Nx.type(value)
{:bf, 16}
```
Note that integers are never promoted to floats:
```elixir
iex> policy = Axon.MixedPrecision.create_policy(output: {:f, 16})
iex> value = Nx.tensor([1, 2, 3], type: :s64)
iex> value = Axon.MixedPrecision.cast(policy, value, :output)
iex> Nx.type(value)
{:s, 64}
```
create_policy/1
Creates a mixed precision policy with the given options.
Every option defaults to `nil`, which means no casting is done for that
variable type.
Options
* `:params` - parameter precision policy. Defaults to `nil`.
* `:compute` - compute precision policy. Defaults to `nil`.
* `:output` - output precision policy. Defaults to `nil`.
Examples
```elixir
iex> Axon.MixedPrecision.create_policy(params: {:f, 16}, output: {:f, 16})
#Axon.MixedPrecision.Policy<p=f16 o=f16>

iex> Axon.MixedPrecision.create_policy(compute: {:bf, 16})
#Axon.MixedPrecision.Policy<c=bf16>

iex> Axon.MixedPrecision.create_policy()
#Axon.MixedPrecision.Policy<>
```
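Because unset options stay `nil`, a partially specified policy only affects the variable types it names. The sketch below checks this with `cast/3` on a single `f32` tensor. It assumes Axon `~> 0.7.0` can be fetched from Hex via `Mix.install`; the variable names are illustrative only.

```elixir
# Assumes network access so Mix.install can fetch Axon (and Nx) from Hex.
Mix.install([{:axon, "~> 0.7.0"}])

value = Nx.tensor([1.0, 2.0, 3.0], type: {:f, 32})

# All options default to nil, so cast/3 returns the value unchanged.
default_policy = Axon.MixedPrecision.create_policy()

for variable_type <- [:params, :compute, :output] do
  result = Axon.MixedPrecision.cast(default_policy, value, variable_type)
  IO.inspect(Nx.type(result), label: "default #{variable_type}")
end

# Only :compute is set, so only compute casts change the type.
compute_only = Axon.MixedPrecision.create_policy(compute: {:bf, 16})

for variable_type <- [:params, :compute, :output] do
  result = Axon.MixedPrecision.cast(compute_only, value, variable_type)
  IO.inspect(Nx.type(result), label: "compute_only #{variable_type}")
end
```

With the default policy every cast reports `{:f, 32}`. With `compute_only`, only the `:compute` cast reports `{:bf, 16}`.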