Axon.ModelState (Axon v0.8.1)

Model State Data Structure.

This data structure represents all the state needed for a model to perform inference.

Summary

Functions

Returns an empty model state.

Freezes parameters and state in the given model state using the given mask.

Returns the frozen parameters in the given model state.

Returns the frozen state in the given model state.

Merges two model states using the given function.

Returns a new model state struct from the given parameter map.

Ties a parameter to another parameter, enabling weight sharing.

Returns the trainable parameters in the given model state.

Returns the trainable state in the given model state.

Unfreezes parameters and state in the given model state using the given mask.

Updates the given model state.

Functions

empty()

Returns an empty model state.

freeze(model_state, mask \\ fn _ -> true end)

Freezes parameters and state in the given model state using the given mask.

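The mask is an arity-1 function that receives the access path to a leaf parameter (a list of strings, such as ["dense_0", "kernel"]) and returns true if the parameter should be frozen, or false otherwise. This lets you construct flexible masking policies. For example, the following mask freezes the kernels of dense_0 through dense_2 and does not freeze anything else: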

fn
  ["dense_" <> n, "kernel"] -> String.to_integer(n) < 3
  _ -> false
end

The default mask returns true for all paths, and is equivalent to freezing the entire model.
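
To make the mask semantics concrete, here is a minimal pure-Elixir sketch that walks a nested parameter map, builds each leaf's access path, and splits the paths into those the mask accepts (frozen) and those it rejects, using the same mask as above. MaskSketch, leaf_paths/2, and partition/2 are hypothetical helpers, not part of Axon, and the leaves are atoms standing in for tensors. Real Nx.Tensor values are structs, which are also maps, so an actual traversal would have to stop at them instead of recursing.

```elixir
defmodule MaskSketch do
  # Hypothetical helper (not part of Axon): returns the access path of
  # every leaf in a nested map. Leaves are assumed to be non-map values.
  def leaf_paths(params, prefix \\ []) do
    Enum.flat_map(params, fn {key, value} ->
      path = prefix ++ [key]
      if is_map(value), do: leaf_paths(value, path), else: [path]
    end)
  end

  # Splits leaf paths into {paths the mask accepts, paths it rejects}.
  def partition(params, mask), do: params |> leaf_paths() |> Enum.split_with(mask)
end

params = %{
  "dense_0" => %{"kernel" => :k0, "bias" => :b0},
  "dense_1" => %{"kernel" => :k1, "bias" => :b1},
  "dense_3" => %{"kernel" => :k3, "bias" => :b3}
}

# Same policy as above: freeze the kernels of dense_0 through dense_2.
mask = fn
  ["dense_" <> n, "kernel"] -> String.to_integer(n) < 3
  _ -> false
end

{frozen, trainable} = MaskSketch.partition(params, mask)
```

Here frozen holds the dense_0 and dense_1 kernel paths, while every bias and the dense_3 kernel end up in trainable, because the mask matches only kernels whose layer index is below 3.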

frozen_parameters(model_state)

Returns the frozen parameters in the given model state.

frozen_state(model_state)

Returns the frozen state in the given model state.

merge(lhs, model_state, fun)

Merges two model states using the given function.
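
The description leaves the exact role of fun open. As a rough illustration of merging two nested parameter maps with a function, the pure-Elixir sketch below recurses into shared sub-maps and calls the function only where both sides have a leaf at the same key, here to average two checkpoints. MergeSketch and deep_merge/3 are hypothetical; in particular, the two-argument resolver is an assumption of this sketch, and Axon's merge/3 may pass fun different arguments.

```elixir
defmodule MergeSketch do
  # Hypothetical illustration, not Axon's merge/3: recursively merges two
  # nested parameter maps. Where both sides have a leaf under the same key,
  # fun.(lhs_leaf, rhs_leaf) decides the result; keys present on only one
  # side are kept unchanged (Map.merge/3 only calls back on conflicts).
  def deep_merge(lhs, rhs, fun) do
    Map.merge(lhs, rhs, fn _key, l, r ->
      if is_map(l) and is_map(r), do: deep_merge(l, r, fun), else: fun.(l, r)
    end)
  end
end

a = %{"dense_0" => %{"kernel" => 1.0, "bias" => 0.0}}
b = %{"dense_0" => %{"kernel" => 3.0, "bias" => 2.0}, "dense_1" => %{"kernel" => 5.0}}

# Average overlapping leaves, e.g. when combining two checkpoints.
merged = MergeSketch.deep_merge(a, b, fn l, r -> (l + r) / 2 end)
```

Overlapping leaves under dense_0 are averaged, while dense_1, which exists only on the right-hand side, is carried over as-is.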

new(data)

Returns a new model state struct from the given parameter map.

tie(model_state, destination, source, opts \\ [])

Ties a parameter to another parameter, enabling weight sharing.

The destination parameter will reference the source parameter's tensor, optionally applying a transformation. Both destination and source are access paths (lists of strings) into the model state data.

Options

  • :transform - a function to transform the source tensor before use at the destination. For example, &Nx.transpose/1 for tying an embedding layer to an output projection.

Examples

# Tie output projection to embedding weights (transposed)
model_state = Axon.ModelState.tie(
  model_state,
  ["output", "kernel"],
  ["embed", "kernel"],
  transform: &Nx.transpose/1
)
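
One way to read "reference" is that the destination holds no copy of its own: it is resolved from the source on every read, so later updates to the source are visible at the destination. The pure-Elixir sketch below models that reading, using list-of-lists matrices in place of tensors and a hand-written transpose standing in for &Nx.transpose/1. TieSketch and its tie/4, fetch/2, and transpose/1 functions are hypothetical and only mirror the shape of Axon.ModelState.tie/4; they are not its implementation.

```elixir
defmodule TieSketch do
  # Hypothetical model of weight tying (not Axon's implementation): a tie is
  # recorded as destination path => {source path, transform} and resolved on
  # every read, so the destination never holds its own copy.
  def tie(state, destination, source, opts \\ []) do
    transform = Keyword.get(opts, :transform, & &1)
    ties = Map.get(state, :ties, %{})
    Map.put(state, :ties, Map.put(ties, destination, {source, transform}))
  end

  # Reads a leaf, following a tie if one is registered for the path.
  def fetch(state, path) do
    case Map.get(Map.get(state, :ties, %{}), path) do
      nil -> get_in(state.data, path)
      {source, transform} -> transform.(get_in(state.data, source))
    end
  end

  # List-of-lists transpose, standing in for Nx.transpose/1.
  def transpose(rows), do: rows |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
end

# A 2x3 "embedding" kernel; the tied output kernel is its 3x2 transpose.
state = %{data: %{"embed" => %{"kernel" => [[1, 2, 3], [4, 5, 6]]}}}

state =
  TieSketch.tie(state, ["output", "kernel"], ["embed", "kernel"],
    transform: &TieSketch.transpose/1
  )

tied = TieSketch.fetch(state, ["output", "kernel"])

# Updating the source is visible through the tie.
state = put_in(state, [:data, "embed", "kernel"], [[0, 0, 0], [1, 1, 1]])
tied_after_update = TieSketch.fetch(state, ["output", "kernel"])
```

Resolving on read keeps a single source of truth for the shared weights; the trade-off is that the transform runs on every access rather than once.
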
trainable_parameters(model_state)

Returns the trainable parameters in the given model state.

trainable_state(model_state)

Returns the trainable state in the given model state.

unfreeze(model_state, mask \\ fn _ -> true end)

Unfreezes parameters and state in the given model state using the given mask.

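The mask is an arity-1 function that receives the access path to a leaf parameter (a list of strings, such as ["dense_0", "kernel"]) and returns true if the parameter should be unfrozen, or false otherwise. This lets you construct flexible masking policies. For example, the following mask unfreezes the kernels of dense_0 through dense_2 and leaves everything else as it is:

fn
  ["dense_" <> n, "kernel"] -> String.to_integer(n) < 3
  _ -> false
end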

The default mask returns true for all paths, and is equivalent to unfreezing the entire model.
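
Because freeze/2 and unfreeze/2 take the same kind of mask, a common fine-tuning pattern is to freeze the whole model and then unfreeze a few layers. The sketch below models that pattern with a plain MapSet of frozen access paths. FreezeSketch and its functions are hypothetical, and the assumption that unfreezing simply removes paths from a frozen set belongs to this sketch, not to Axon's internals. With Axon itself, the same policy would presumably be expressed as Axon.ModelState.freeze/2 followed by Axon.ModelState.unfreeze/2 with a narrower mask.

```elixir
defmodule FreezeSketch do
  # Hypothetical bookkeeping model (not Axon's internals): the frozen leaves
  # are a MapSet of access paths. freeze/3 adds every path the mask accepts;
  # unfreeze/3 removes every path the mask accepts.
  def freeze(frozen, paths, mask \\ fn _ -> true end) do
    paths |> Enum.filter(mask) |> MapSet.new() |> MapSet.union(frozen)
  end

  def unfreeze(frozen, paths, mask \\ fn _ -> true end) do
    MapSet.difference(frozen, paths |> Enum.filter(mask) |> MapSet.new())
  end
end

paths = for i <- 0..3, do: ["dense_#{i}", "kernel"]

# Freeze everything (default mask), then unfreeze only the last layer.
frozen =
  MapSet.new()
  |> FreezeSketch.freeze(paths)
  |> FreezeSketch.unfreeze(paths, fn
    ["dense_" <> n, "kernel"] -> String.to_integer(n) >= 3
    _ -> false
  end)
```

After both calls, only the dense_0 through dense_2 kernels remain frozen, so training would update dense_3 alone.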

update(model_state, updated_parameters, updated_state \\ %{})

Updates the given model state with updated_parameters and, optionally, updated_state (which defaults to an empty map).