Axon (Axon v0.6.1)

A high-level interface for creating neural network models.

Axon is built entirely on top of Nx numerical definitions, so every neural network can be JIT or AOT compiled using any Nx compiler, or even transformed into high-level neural network formats like TensorFlow Lite and ONNX.

For a more in-depth overview of Axon, refer to the Guides.

Model Creation

All Axon models start with an input layer, optionally specifying the expected shape of the input data:

input = Axon.input("input", shape: {nil, 784})

Notice you can specify some dimensions as nil, indicating that the dimension size will be filled in at model runtime. You can then compose inputs with other layers:

model =
  input
  |> Axon.dense(128, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.8)
  |> Axon.dense(64)
  |> Axon.tanh()
  |> Axon.dense(10)
  |> Axon.activation(:softmax)

You can inspect the model for a nice summary:

IO.inspect(model)

#Axon<
  inputs: %{"input" => {nil, 784}}
  outputs: "softmax_0"
  nodes: 9
>

Or use the Axon.Display module to see more in-depth summaries:

Axon.Display.as_table(model, Nx.template({1, 784}, :f32)) |> IO.puts

+----------------------------------------------------------------------------------------------------------------+
|                                                     Model                                                      |
+=======================================+=============+==============+===================+=======================+
| Layer                                 | Input Shape | Output Shape | Options           | Parameters            |
+=======================================+=============+==============+===================+=======================+
| input ( input )                       | []          | {1, 784}     | shape: {nil, 784} |                       |
|                                       |             |              | optional: false   |                       |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_0 ( dense["input"] )            | [{1, 784}]  | {1, 128}     |                   | kernel: f32[784][128] |
|                                       |             |              |                   | bias: f32[128]        |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| relu_0 ( relu["dense_0"] )            | [{1, 128}]  | {1, 128}     |                   |                       |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| batch_norm_0 ( batch_norm["relu_0"] ) | [{1, 128}]  | {1, 128}     | epsilon: 1.0e-5   | gamma: f32[128]       |
|                                       |             |              | channel_index: 1  | beta: f32[128]        |
|                                       |             |              | momentum: 0.1     | mean: f32[128]        |
|                                       |             |              |                   | var: f32[128]         |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dropout_0 ( dropout["batch_norm_0"] ) | [{1, 128}]  | {1, 128}     | rate: 0.8         |                       |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_1 ( dense["dropout_0"] )        | [{1, 128}]  | {1, 64}      |                   | kernel: f32[128][64]  |
|                                       |             |              |                   | bias: f32[64]         |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| tanh_0 ( tanh["dense_1"] )            | [{1, 64}]   | {1, 64}      |                   |                       |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_2 ( dense["tanh_0"] )           | [{1, 64}]   | {1, 10}      |                   | kernel: f32[64][10]   |
|                                       |             |              |                   | bias: f32[10]         |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+
| softmax_0 ( softmax["dense_2"] )      | [{1, 10}]   | {1, 10}      |                   |                       |
+---------------------------------------+-------------+--------------+-------------------+-----------------------+

Multiple Inputs

Creating a model with multiple inputs is as easy as declaring an additional input in your Axon graph. Every input layer present in the final Axon graph will be required to be passed as input at the time of model execution.

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})

# Both inputs will be used
model1 = Axon.add(inp1, inp2)

# Only inp2 will be used
model2 = Axon.add(inp2, inp2)

Axon graphs are immutable, which means composing and manipulating an Axon graph creates an entirely new graph. Additionally, layer names are lazily generated at model execution time. To avoid non-deterministic input orderings and names, Axon requires each input to have a unique binary identifier. You can then reference inputs by name when passing to models at execution time:

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})

model1 = Axon.add(inp1, inp2)

{init_fn, predict_fn} = Axon.build(model1)

params1 = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
# Inputs are referenced by name
predict_fn.(params1, %{"input_0" => x, "input_1" => y})
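The snippet above leaves x and y undefined. The following sketch fills them in with concrete tensors so it runs as a standalone script. The Mix.install version constraint and the input values are assumptions chosen for illustration. It also passes one template per named input to init_fn, which keeps the input names explicit at initialization time as well.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.add(inp1, inp2)

{init_fn, predict_fn} = Axon.build(model)

# One template per named input; the keys must match the input names.
templates = %{
  "input_0" => Nx.template({1, 1}, :f32),
  "input_1" => Nx.template({1, 1}, :f32)
}

# Axon.add has no trainable parameters of its own.
params = init_fn.(templates, %{})

# Inputs are looked up by name, so the order of map keys does not matter.
result =
  predict_fn.(params, %{
    "input_1" => Nx.tensor([[2.0]]),
    "input_0" => Nx.tensor([[1.0]])
  })

IO.inspect(result)
```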

Multiple Outputs

Nx offers robust container support which is extended to Axon. Axon allows you to wrap any valid Nx container in a layer. Containers are most commonly used to structure outputs:

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.container(%{foo: inp1, bar: inp2})

Containers can be arbitrarily nested:

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.container({%{foo: {inp1, %{bar: inp2}}}})

You can even use custom structs which implement the container protocol:

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})
model = Axon.container(%MyStruct{foo: inp1, bar: inp2})
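As a sketch of how a nested container comes back from execution (assuming Axon 0.6.1 installed via Mix.install, with illustrative input values), the prediction has the same structure as the container, so it can be destructured with a pattern match:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

inp1 = Axon.input("input_0", shape: {nil, 1})
inp2 = Axon.input("input_1", shape: {nil, 1})

# A tuple holding a map holding a tuple and another map.
model = Axon.container({%{foo: {inp1, %{bar: inp2}}}})

# The output mirrors the container structure exactly.
{%{foo: {a, %{bar: b}}}} =
  Axon.predict(model, %{}, %{
    "input_0" => Nx.tensor([[1.0]]),
    "input_1" => Nx.tensor([[2.0]])
  })

IO.inspect({a, b})
```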

Custom Layers

If you find that Axon's built-in layers are insufficient for your needs, you can create your own using the custom layer API. All of Axon's built-in layers (aside from special ones such as input, constant, and container) make use of this same API.

Axon layers are really just placeholders for Nx computations with trainable parameters and possibly state. To define a custom layer, you just need to define a defn implementation:

defn my_layer(x, weight, _opts \\ []) do
  Nx.atan2(x, weight)
end

Notice the only stipulation is that your custom layer implementation must accept at least 1 input and a list of options. At execution time, every layer will be passed a :mode option which can be used to control behavior at training and inference time.

Inputs to your custom layer can be either Axon graph inputs or trainable parameters. You can pass Axon graph inputs as-is to a custom layer. To declare trainable parameters, use Axon.param/3:

weight = Axon.param("weight", param_shape)

To create a custom layer, you "wrap" your implementation and inputs into a layer using Axon.layer. You'll notice the API mirrors Elixir's apply:

def atan2_layer(%Axon{} = input) do
  weight = Axon.param("weight", param_shape)
  Axon.layer(&my_layer/3, [input, weight])
end
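Putting these pieces together, here is a runnable sketch of the atan2 layer. The module name Atan2Layer is hypothetical, the weight shape is computed from the input shape with a function (one weight per feature), and the :ones initializer is an illustrative choice that makes the output predictable, since atan2(1, 1) is pi / 4.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

defmodule Atan2Layer do
  import Nx.Defn

  # Layer implementation: element-wise atan2 of the input and the weight.
  defn my_layer(x, weight, _opts \\ []) do
    Nx.atan2(x, weight)
  end

  # Wraps the implementation in an Axon layer with one trainable weight.
  # The shape function receives the input shape and returns {features}.
  def atan2_layer(%Axon{} = input) do
    weight =
      Axon.param("weight", fn shape -> {elem(shape, tuple_size(shape) - 1)} end,
        initializer: :ones
      )

    Axon.layer(&my_layer/3, [input, weight])
  end
end

model = Axon.input("x", shape: {nil, 4}) |> Atan2Layer.atan2_layer()

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({2, 4}, :f32), %{})

# With the weight initialized to ones, every output is atan2(1.0, 1.0).
out = predict_fn.(params, Nx.broadcast(1.0, {2, 4}))
IO.inspect(out)
```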

Model Execution

Under the hood, Axon models are represented as Elixir structs. You can initialize and apply models by building or compiling them with Axon.build/2 or Axon.compile/4 and then calling the produced initialization and predict functions:

{init_fn, predict_fn} = Axon.build(model)

params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
predict_fn.(params, inputs)

You may either set the default JIT compiler or backend globally, or pass a specific compiler to Axon.build/2:

EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])

{init_fn, predict_fn} = Axon.build(model, compiler: EXLA, mode: :train)

params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{})
predict_fn.(params, inputs)

predict_fn by default runs in inference mode, which performs certain optimizations and removes layers such as dropout layers. If constructing a training step using Axon.predict/4 or Axon.build/2, be sure to specify mode: :train.
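A small sketch of the difference between the two modes, assuming a model with a single dropout layer and an input of ones. In this sketch the default (inference) build returns the input unchanged, while the :train build returns a map whose :prediction entry holds the dropout output, with each entry either zeroed or scaled by 1 / (1 - rate). Treat the exact return structure as something to verify against your Axon version.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("x", shape: {nil, 8})
  |> Axon.dropout(rate: 0.5)

# Default build: inference mode.
{init_fn, infer_fn} = Axon.build(model)
# Training build: dropout is active.
{_init_fn, train_fn} = Axon.build(model, mode: :train)

params = init_fn.(Nx.template({1, 8}, :f32), %{})
x = Nx.broadcast(1.0, {1, 8})

# Inference mode: dropout does nothing, so the ones come back unchanged.
infer_out = infer_fn.(params, x)

# Training mode: the result is a map; kept entries are scaled by 1 / 0.5.
%{prediction: train_out} = train_fn.(params, x)

IO.inspect({infer_out, train_out})
```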

Model Training

Combining the Axon model creation API with the optimization and training APIs, you can create and train neural networks with ease:

model =
  Axon.input("input_0", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.layer_norm()
  |> Axon.dropout()
  |> Axon.dense(10, activation: :softmax)

IO.inspect model

model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005))
  |> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)

See Polaris.Updates and Axon.Loop for a more in-depth treatment of model optimization and model training.
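For a self-contained illustration that runs without EXLA, the sketch below trains a one-unit dense layer on synthetic data. The data (y = 2x on random inputs), the :mean_squared_error loss, the Adam learning rate, and the epoch count are arbitrary assumptions. The point is the shape of the pipeline: an enumerable of {input, target} batches, Axon.Loop.trainer/3, and Axon.Loop.run/4 with an initial state map.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

# Synthetic data (illustrative): 10 batches of 8 scalars, target y = 2x.
{xs, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 8, 1})

train_data =
  for i <- 0..9 do
    x = xs[i]
    {x, Nx.multiply(x, 2)}
  end

model = Axon.input("x", shape: {nil, 1}) |> Axon.dense(1)

# run(loop, data, init_state, opts): the empty map is the initial state.
model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 0.05))
  |> Axon.Loop.run(train_data, %{}, epochs: 5)

IO.inspect(model_state)
```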

Using with Nx.Serving

When deploying an Axon model to production, you usually want to batch multiple prediction requests and run the inference for all of them at once. Conveniently, Nx already has an abstraction for this task in the form of Nx.Serving. Here's how you could define a serving for an Axon model:

def build_serving() do
  # Configuration
  batch_size = 4
  defn_options = [compiler: EXLA]

  Nx.Serving.new(
    # This function runs on the serving startup
    fn ->
      # Build the Axon model and load params (usually from file)
      model = build_model()
      params = load_params()

      # Build the prediction defn function
      {_init_fun, predict_fun} = Axon.build(model)

      inputs_template = %{"pixel_values" => Nx.template({batch_size, 224, 224, 3}, :f32)}
      template_args = [Nx.to_template(params), inputs_template]

      # Compile the prediction function upfront for the configured batch_size
      predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options)

      # The returned function is called for every accumulated batch
      fn inputs ->
        inputs = Nx.Batch.pad(inputs, batch_size - inputs.size)
        predict_fun.(params, inputs)
      end
    end,
    batch_size: batch_size
  )
end

Then you would start the serving server as part of your application's supervision tree:

children = [
  ...,
  {Nx.Serving, serving: build_serving(), name: MyApp.Serving, batch_timeout: 100}
]

With that in place, you can ask the serving for predictions anywhere in your application (controllers, live views, async jobs, etc.). Given a tensor input, you would do:

inputs = %{"pixel_values" => ...}
batch = Nx.Batch.concatenate([inputs])
result = Nx.Serving.batched_run(MyApp.Serving, batch)

Usually you also want to do pre/post-processing of the model input/output. You could make those preparations directly before/after Nx.Serving.batched_run/2, however you can also make use of Nx.Serving.client_preprocessing/2 and Nx.Serving.client_postprocessing/2 to encapsulate that logic as part of the serving.

Summary

Layers: Special

Returns a function which represents a self-contained re-usable block of operations in a neural network. All parameters in the block are shared between every usage of the block.

Adds a constant layer to the network.

Adds a container layer to the network.

Adds an input layer to the network.

Custom Axon layer with given inputs.

Wraps an Axon model into a namespace.

Applies the given Nx expression to the input.

Wraps an Axon model in an optional node.

Trainable Axon parameter used to create custom layers.

Adds a stack columns layer to the network.

Layers: Activation

Adds an activation layer to the network.

Adds a Continuously-differentiable exponential linear unit activation layer to the network.

Adds an Exponential linear unit activation layer to the network.

Adds an Exponential activation layer to the network.

Adds a Gaussian error linear unit activation layer to the network.

Adds a Hard sigmoid activation layer to the network.

Adds a Hard sigmoid weighted linear unit activation layer to the network.

Adds a Hard hyperbolic tangent activation layer to the network.

Adds a Leaky rectified linear unit activation layer to the network.

Adds a Linear activation layer to the network.

Adds a Log-sigmoid activation layer to the network.

Adds a Log-softmax activation layer to the network.

Adds a Log-sumexp activation layer to the network.

Adds a Mish activation layer to the network.

Adds a Rectified linear unit 6 activation layer to the network.

Adds a Rectified linear unit activation layer to the network.

Adds a Scaled exponential linear unit activation layer to the network.

Adds a Sigmoid activation layer to the network.

Adds a Sigmoid weighted linear unit activation layer to the network.

Adds a Softmax activation layer to the network.

Adds a Softplus activation layer to the network.

Adds a Softsign activation layer to the network.

Adds a Hyperbolic tangent activation layer to the network.

Layers: Linear

Adds a bias layer to the network.

Adds a bilinear layer to the network.

Adds a dense layer to the network.

Adds an embedding layer to the network.

Layers: Convolution

Adds a convolution layer to the network.

Adds a transposed convolution layer to the network.

Adds a depthwise convolution layer to the network.

Adds a depthwise separable 2-dimensional convolution to the network.

Adds a depthwise separable 3-dimensional convolution to the network.

Layers: Dropout

Adds an Alpha dropout layer to the network.

Adds a Dropout layer to the network.

Adds a Feature alpha dropout layer to the network.

Adds a Spatial dropout layer to the network.

Layers: Pooling

Adds an Adaptive average pool layer to the network.

Adds an Adaptive power average pool layer to the network.

Adds an Adaptive max pool layer to the network.

Adds an Average pool layer to the network.

Adds a Global average pool layer to the network.

Adds a Global LP pool layer to the network.

Adds a Global max pool layer to the network.

Adds a Power average pool layer to the network.

Adds a Max pool layer to the network.

Layers: Normalization

Adds a Batch normalization layer to the network.

Adds a group normalization layer to the network.

Adds an Instance normalization layer to the network.

Adds a Layer normalization layer to the network.

Layers: Recurrent

Adds a convolutional long short-term memory (LSTM) layer to the network with a random initial hidden state.

Adds a convolutional long short-term memory (LSTM) layer to the network with the given initial hidden state.

Adds a gated recurrent unit (GRU) layer to the network with a random initial hidden state.

Adds a gated recurrent unit (GRU) layer to the network with the given initial hidden state.

Adds a long short-term memory (LSTM) layer to the network with a random initial hidden state.

Adds a long short-term memory (LSTM) layer to the network with the given initial hidden state.

Computes a sequence mask according to the given EOS token.

Layers: Combinators

Adds an add layer to the network.

Adds a concatenate layer to the network.

Adds a conditional layer which conditionally executes true_graph or false_graph based on the condition cond_fn at runtime.

Adds a multiply layer to the network.

Splits input graph into a container of n input graphs along the given axis.

Adds a subtract layer to the network.

Layers: Shape

Adds a flatten layer to the network.

Adds a pad layer to the network.

Adds a reshape layer to the network.

Adds a resize layer to the network.

Adds a transpose layer to the network.

Model

Builds the given model to {init_fn, predict_fn}.

Compiles the given model to {init_fn, predict_fn}.

Deserializes serialized model and parameters into a {model, params} tuple.

Freezes parameters returned from the given function or predicate.

Builds and runs the given Axon model with params and input.

Serializes a model and its parameters for persisting models to disk or elsewhere.

Unfreezes parameters returned from the given function or predicate.

Model: Manipulation

Returns information about a model's inputs.

Returns a map of model op counts for each unique operation in a model by their given :op_name.

Returns a node's immediate input options.

Returns a model's output shape from the given input template.

Returns a node's immediate parameters.

Traverses graph nodes in order, applying fun to each node exactly once to return a transformed node in its place(s) in the graph.

Pops the top node off of the graph.

Traverses graph nodes in order, applying fun to each node exactly once to return a transformed node in its place(s) in the graph.

Sets a node's immediate options to the given input options.

Sets a node's immediate parameters to the given parameters.

Model: Debugging

Attaches a hook to the given Axon model.

Compiles and returns the given model's backward function expression with respect to the given loss function.

Compiles and returns the given model's forward function expression with the given options.

Compiles and returns the given model's init function expression with the given options.

Functions

Applies the given forward function bidirectionally and merges the results with the given merge function.

Adds a blur pooling layer to the network.

Layers: Special

block(fun, opts \\ [])

Returns a function which represents a self-contained re-usable block of operations in a neural network. All parameters in the block are shared between every usage of the block.

This returns an arity-1 function which accepts a list of inputs which are forwarded to fun. This is most often used in situations where you wish to re-use parameters in a block:

reused_dense = Axon.block(&Axon.dense(&1, 32))

Every time reused_dense is invoked, it re-uses the same parameters:

input = Axon.input("features")
# unique parameters
x1 = Axon.dense(input, 32)
# unique parameters
x2 = reused_dense.(x1)
# parameters shared
x3 = reused_dense.(x2)

Subgraphs in blocks can be arbitrarily complex:

reused_block = Axon.block(fn x ->
  x
  |> Axon.dense(32)
  |> Axon.dense(64)
  |> Axon.dense(32)
end)

Blocks can also have multiple inputs. To invoke a block with multiple inputs, pass a list of arguments:

reused_block = Axon.block(fn x, y, z ->
  x = Axon.dense(x, 32)
  y = Axon.dense(y, 32)
  z = Axon.dense(z, 32)

  Axon.add([x, y, z])
end)

# invoke with a list
reused_block.([x, y, z])

Blocks prefix subgraph parameters with their name and a dot. As with other Axon layers, if a name is not explicitly provided, one will be dynamically generated.
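To see the sharing in practice, the sketch below builds the same two-layer stack twice, once with a block and once with two independent dense layers, and counts the tensors in each initialized parameter map. Counting leaves avoids depending on the exact parameter key names, which this sketch does not assume. The ParamCount helper module is hypothetical.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

defmodule ParamCount do
  # Counts the tensors in a (possibly nested) parameter map.
  # The %Nx.Tensor{} clause comes first because tensors are maps too.
  def leaves(%Nx.Tensor{}), do: 1

  def leaves(%{} = map) do
    map |> Map.values() |> Enum.map(&leaves/1) |> Enum.sum()
  end
end

input = Axon.input("features", shape: {nil, 4})
template = Nx.template({1, 4}, :f32)

# The same block applied twice: one kernel and one bias, shared by both calls.
reused_dense = Axon.block(&Axon.dense(&1, 4))
shared = input |> reused_dense.() |> reused_dense.()

# Two independent dense layers: each has its own kernel and bias.
unshared = input |> Axon.dense(4) |> Axon.dense(4)

{shared_init, _} = Axon.build(shared)
{unshared_init, _} = Axon.build(unshared)

shared_count = ParamCount.leaves(shared_init.(template, %{}))
unshared_count = ParamCount.leaves(unshared_init.(template, %{}))

IO.puts("shared: #{shared_count} tensors, unshared: #{unshared_count} tensors")
```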

constant(tensor, opts \\ [])

Adds a constant layer to the network.

Constant layers encapsulate Nx tensors in an Axon layer for ease of use with other Axon layers. They can be used interchangeably with other Axon layers:

inp = Axon.input("input", shape: {nil, 32})
my_constant = Axon.constant(Nx.iota({1, 32}))
model = Axon.add(inp, my_constant)

Constant layers will be cast according to the mixed precision policy. If it's important for your constant to retain its type during the computation, you will need to set the mixed precision policy to ignore constant layers.

Options

  • :name - layer name.
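A runnable sketch of the constant example with concrete values (the shapes and numbers are illustrative assumptions). Because the model has no trainable parameters, an empty parameter map is enough here:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

inp = Axon.input("input", shape: {nil, 3})

# The constant behaves like any other layer in the graph.
offset = Axon.constant(Nx.tensor([[10.0, 20.0, 30.0]]))
model = Axon.add(inp, offset)

out = Axon.predict(model, %{}, Nx.tensor([[1.0, 2.0, 3.0]]))
IO.inspect(out)
```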

container(container, opts \\ [])

Adds a container layer to the network.

In certain cases you may want your model to have multiple outputs. In order to make this work, you must "join" the outputs into an Axon layer using this function for use in initialization and inference later on.

The given container can be any valid Axon Nx container.

Options

  • :name - layer name.

Examples

iex> inp1 = Axon.input("input_0", shape: {nil, 1})
iex> inp2 = Axon.input("input_1", shape: {nil, 2})
iex> model = Axon.container(%{a: inp1, b: inp2})
iex> %{a: a, b: b} = Axon.predict(model, %{}, %{
...>    "input_0" => Nx.tensor([[1.0]]),
...>    "input_1" => Nx.tensor([[1.0, 2.0]])
...> })
iex> a
#Nx.Tensor<
  f32[1][1]
  [
    [1.0]
  ]
>
iex> b
#Nx.Tensor<
  f32[1][2]
  [
    [1.0, 2.0]
  ]
>

input(name, opts \\ [])

Adds an input layer to the network.

Input layers specify a model's inputs. Input layers are always the root layers of the neural network.

You must specify the input layer's name, which will be used to uniquely identify it in the case of multiple inputs.

Options

  • :shape - the expected input shape, use nil for dimensions of a dynamic size.

  • :optional - if true, the input may be omitted when using the model. This needs to be handled in one of the subsequent layers. See optional/2 for more details.

layer(op, inputs, opts \\ [])

Custom Axon layer with given inputs.

Inputs may be other Axon layers or trainable parameters created with Axon.param. At inference time, op will be applied with inputs in specified order and an additional opts parameter which specifies inference options. All options passed to layer are forwarded to inference function except:

  • :name - layer name.

  • :op_name - layer operation for inspection and building parameter map.

  • :mode - if the layer should run only on :inference or :train. Defaults to :both.

  • :global_options - a list of global option names that this layer supports. Global options passed to build/2 will be forwarded to the layer, as long as they are declared.

Note this means your layer should not use these as input options, as they will always be dropped during inference compilation.

Axon's compiler will additionally forward the following options to every layer at inference time:

  • :mode - :inference or :train. To control layer behavior based on inference or train time.

op is a function of the form:

fun = fn input, weight, bias, _opts ->
  input |> Nx.multiply(weight) |> Nx.add(bias)
end
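As a concrete sketch of an op with parameters, the following layer scales and shifts each feature. The layer name "scale_shift", the per-feature parameter shapes, and the :ones and :zeros initializers are illustrative assumptions; with those initializers the layer starts out as the identity.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

input = Axon.input("x", shape: {nil, 3})

# One weight and one bias per feature, computed from the input shape.
per_feature = fn shape -> {elem(shape, 1)} end
weight = Axon.param("weight", per_feature, initializer: :ones)
bias = Axon.param("bias", per_feature, initializer: :zeros)

# The op receives the inputs in order, followed by the options list.
scale_shift = fn x, w, b, _opts -> x |> Nx.multiply(w) |> Nx.add(b) end

model = Axon.layer(scale_shift, [input, weight, bias], name: "scale_shift")

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3}, :f32), %{})

out = predict_fn.(params, Nx.tensor([[1.0, 2.0, 3.0]]))
IO.inspect(out)
```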

namespace(axon, name)

Wraps an Axon model into a namespace.

A namespace is a part of an Axon model which is meant to be a self-contained collection of Axon layers. Namespaces are guaranteed to always generate with the same internal layer names and can be re-used universally across models.

Namespaces are most useful for containing large collections of layers and offering a straightforward means for accessing the parameters of individual model components. A common application of namespaces is to use them with a pre-trained model for fine-tuning:

{base, resnet_params} = resnet()
base = base |> Axon.namespace("resnet")

model = base |> Axon.dense(1)
{init_fn, predict_fn} = Axon.build(model)

init_fn.(Nx.template({1, 3, 224, 224}, {:f, 32}), %{"resnet" => resnet_params})

Notice you can use init_fn in conjunction with namespaces to specify which portion of a model you'd like to initialize from a fixed starting point.

Namespaces have fixed names, which means it's easy to run into namespace collisions. Re-using namespaces, re-using inner parts of a namespace, and attempting to share layers between namespaces are still sharp edges in namespace usage.
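A minimal sketch of initializing a namespaced component from fixed parameters, with a tiny dense "backbone" standing in for the hypothetical resnet(). It assumes the layer inside the namespace is named dense_0, since names inside a namespace are generated independently of the outer model, and the identity-matrix kernel plays the role of pre-trained weights.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

# A small backbone wrapped in a namespace.
base =
  Axon.input("x", shape: {nil, 2})
  |> Axon.dense(2)
  |> Axon.namespace("base")

model = Axon.dense(base, 1)
{init_fn, _predict_fn} = Axon.build(model)

# Hypothetical "pre-trained" parameters for the namespaced part.
pretrained = %{
  "dense_0" => %{
    "kernel" => Nx.eye(2, type: :f32),
    "bias" => Nx.tensor([0.0, 0.0])
  }
}

# Parameters under "base" come from `pretrained`; the rest are initialized.
params = init_fn.(Nx.template({1, 2}, :f32), %{"base" => pretrained})
IO.inspect(params)
```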

nx(input, fun, opts \\ [])

Applies the given Nx expression to the input.

Nx layers are meant for quick applications of functions without trainable parameters. For example, they are useful for applying functions which apply accessors to containers:

model = Axon.container({foo, bar})
Axon.nx(model, &elem(&1, 0))

Options

  • :name - layer name.
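A sketch of both uses, accessing a container element and applying an arbitrary parameter-free Nx function (the input names and values are illustrative):

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

foo = Axon.input("foo", shape: {nil, 1})
bar = Axon.input("bar", shape: {nil, 1})

# The function receives the container output, here a tuple of tensors.
difference =
  Axon.container({foo, bar})
  |> Axon.nx(fn {a, b} -> Nx.subtract(a, b) end)

# An accessor: keep only the first element of the tuple.
first_only =
  Axon.container({foo, bar})
  |> Axon.nx(&elem(&1, 0))

inputs = %{"foo" => Nx.tensor([[5.0]]), "bar" => Nx.tensor([[2.0]])}
diff = Axon.predict(difference, %{}, inputs)
first = Axon.predict(first_only, %{}, inputs)

IO.inspect({diff, first})
```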

optional(x, opts \\ [])

Wraps an Axon model in an optional node.

By default, when an optional input is missing, all subsequent layers are nullified. For example, consider this model:

values = Axon.input("values")
mask = Axon.input("mask", optional: true)

model =
  values
  |> Axon.dense(10)
  |> Axon.multiply(mask)
  |> Axon.dense(1)
  |> Axon.sigmoid()

If the mask is not provided, the input node resolves to %Axon.None{}, and so do all the layers that depend on it. By using optional/2, a layer may opt in to receive %Axon.None{}. To fix our example, we could define a custom layer that applies the mask only when present:

def apply_optional_mask(%Axon{} = x, %Axon{} = mask) do
  Axon.layer(
    fn x, mask, _opts ->
      case mask do
        %Axon.None{} -> x
        mask -> Nx.multiply(x, mask)
      end
    end,
    [x, Axon.optional(mask)]
  )
end

# ...

model =
  values
  |> Axon.dense(10)
  |> apply_optional_mask(mask)
  |> Axon.dense(1)
  |> Axon.sigmoid()

Options

  • :name - layer name.
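A runnable sketch of the optional-mask pattern, trimmed to just the mask layer so the outputs are easy to check. The module name MaskLayers and the input values are hypothetical. The mask is applied when provided and skipped when omitted:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

defmodule MaskLayers do
  # Multiplies x by the mask when present; passes x through otherwise.
  def apply_optional_mask(%Axon{} = x, %Axon{} = mask) do
    Axon.layer(
      fn x, mask, _opts ->
        case mask do
          %Axon.None{} -> x
          mask -> Nx.multiply(x, mask)
        end
      end,
      [x, Axon.optional(mask)]
    )
  end
end

values = Axon.input("values", shape: {nil, 2})
mask = Axon.input("mask", shape: {nil, 2}, optional: true)
model = MaskLayers.apply_optional_mask(values, mask)

x = Nx.tensor([[1.0, 2.0]])

masked = Axon.predict(model, %{}, %{"values" => x, "mask" => Nx.tensor([[1.0, 0.0]])})
unmasked = Axon.predict(model, %{}, %{"values" => x})

IO.inspect({masked, unmasked})
```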

param(name, shape, opts \\ [])

Trainable Axon parameter used to create custom layers.

Parameters are specified in usages of Axon.layer and will be automatically initialized and used in subsequent applications of Axon models.

You may specify the parameter shape as either a static shape or as a function of the inputs to the given layer. If you specify the parameter shape as a function, it will be given the shapes of the layer's inputs and must return the parameter shape.

Options

  • :initializer - parameter initializer. Defaults to :glorot_uniform.

stack_columns(x, opts \\ [])

Adds a stack columns layer to the network.

A stack columns layer is designed to be used with Nx.LazyContainer data structures like Explorer DataFrames. Given an input which is a DataFrame, stack_columns/2 will stack the columns in each row to create a single vector.

You may optionally specify :ignore to ignore certain columns in the container.

Options

  • :name - layer name.

  • :ignore - keys to ignore when stacking.

Layers: Activation

activation(x, activation, opts \\ [])

Adds an activation layer to the network.

Activation layers are element-wise functions typically called after the output of another layer.

Options

  • :name - layer name.
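A sketch showing that the generic activation layer and the named shortcut compute the same thing (the input values are illustrative):

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

x = Axon.input("x", shape: {nil, 3})

# Generic form, selecting the activation by atom.
via_activation = Axon.activation(x, :relu)
# Named shortcut for the same activation.
via_shortcut = Axon.relu(x)

input = Nx.tensor([[-1.0, 0.0, 2.0]])
a = Axon.predict(via_activation, %{}, input)
b = Axon.predict(via_shortcut, %{}, input)

IO.inspect({a, b})
```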

celu(x, opts \\ [])

Adds a Continuously-differentiable exponential linear unit activation layer to the network.

See Axon.Activations.celu/1 for more details.

Options

  • :name - layer name.

elu(x, opts \\ [])

Adds an Exponential linear unit activation layer to the network.

See Axon.Activations.elu/1 for more details.

Options

  • :name - layer name.

exp(x, opts \\ [])

Adds an Exponential activation layer to the network.

See Axon.Activations.exp/1 for more details.

Options

  • :name - layer name.

gelu(x, opts \\ [])

Adds a Gaussian error linear unit activation layer to the network.

See Axon.Activations.gelu/1 for more details.

Options

  • :name - layer name.

hard_sigmoid(x, opts \\ [])

Adds a Hard sigmoid activation layer to the network.

See Axon.Activations.hard_sigmoid/1 for more details.

Options

  • :name - layer name.

hard_silu(x, opts \\ [])

Adds a Hard sigmoid weighted linear unit activation layer to the network.

See Axon.Activations.hard_silu/1 for more details.

Options

  • :name - layer name.

hard_tanh(x, opts \\ [])

Adds a Hard hyperbolic tangent activation layer to the network.

See Axon.Activations.hard_tanh/1 for more details.

Options

  • :name - layer name.

leaky_relu(x, opts \\ [])

Adds a Leaky rectified linear unit activation layer to the network.

See Axon.Activations.leaky_relu/1 for more details.

Options

  • :name - layer name.

linear(x, opts \\ [])

Adds a Linear activation layer to the network.

See Axon.Activations.linear/1 for more details.

Options

  • :name - layer name.

log_sigmoid(x, opts \\ [])

Adds a Log-sigmoid activation layer to the network.

See Axon.Activations.log_sigmoid/1 for more details.

Options

  • :name - layer name.

log_softmax(x, opts \\ [])

Adds a Log-softmax activation layer to the network.

See Axon.Activations.log_softmax/1 for more details.

Options

  • :name - layer name.

log_sumexp(x, opts \\ [])

Adds a Log-sumexp activation layer to the network.

See Axon.Activations.log_sumexp/1 for more details.

Options

  • :name - layer name.

mish(x, opts \\ [])

Adds a Mish activation layer to the network.

See Axon.Activations.mish/1 for more details.

Options

  • :name - layer name.

relu6(x, opts \\ [])

Adds a Rectified linear unit 6 activation layer to the network.

See Axon.Activations.relu6/1 for more details.

Options

  • :name - layer name.

relu(x, opts \\ [])

Adds a Rectified linear unit activation layer to the network.

See Axon.Activations.relu/1 for more details.

Options

  • :name - layer name.

selu(x, opts \\ [])

Adds a Scaled exponential linear unit activation layer to the network.

See Axon.Activations.selu/1 for more details.

Options

  • :name - layer name.

sigmoid(x, opts \\ [])

Adds a Sigmoid activation layer to the network.

See Axon.Activations.sigmoid/1 for more details.

Options

  • :name - layer name.

silu(x, opts \\ [])

Adds a Sigmoid weighted linear unit activation layer to the network.

See Axon.Activations.silu/1 for more details.

Options

  • :name - layer name.

softmax(x, opts \\ [])

Adds a Softmax activation layer to the network.

See Axon.Activations.softmax/1 for more details.

Options

  • :name - layer name.

softplus(x, opts \\ [])

Adds a Softplus activation layer to the network.

See Axon.Activations.softplus/1 for more details.

Options

  • :name - layer name.

softsign(x, opts \\ [])

Adds a Softsign activation layer to the network.

See Axon.Activations.softsign/1 for more details.

Options

  • :name - layer name.

tanh(x, opts \\ [])

Adds a Hyperbolic tangent activation layer to the network.

See Axon.Activations.tanh/1 for more details.

Options

  • :name - layer name.

Layers: Linear

bias(x, opts \\ [])

Adds a bias layer to the network.

A bias layer simply adds a trainable bias to an input.

Options

  • :name - layer name.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

bilinear(input1, input2, units, opts \\ [])

Adds a bilinear layer to the network.

The bilinear layer implements:

output = activation(dot(dot(input1, kernel), input2) + bias)

where activation is given by the :activation option and both kernel and bias are layer parameters. units specifies the number of output units.

All dimensions but the last of input1 and input2 must match. The batch sizes of both inputs must also match or at least one must be nil. Inferred output batch size coerces to the strictest input batch size.

Compiles to Axon.Layers.bilinear/5.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.
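The batch-size and last-dimension rules can be checked without training anything. This sketch uses Axon.get_output_shape/2 with one template per named input; the feature sizes 3 and 5 and the unit count 2 are arbitrary assumptions.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

a = Axon.input("a", shape: {nil, 3})
b = Axon.input("b", shape: {nil, 5})

# The last dimensions differ (3 vs 5); the batch dimension must agree.
model = Axon.bilinear(a, b, 2)

templates = %{
  "a" => Nx.template({4, 3}, :f32),
  "b" => Nx.template({4, 5}, :f32)
}

shape = Axon.get_output_shape(model, templates)
IO.inspect(shape)
```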

dense(x, units, opts \\ [])

Adds a dense layer to the network.

The dense layer implements:

output = activation(dot(input, kernel) + bias)

where activation is given by the :activation option and both kernel and bias are layer parameters. units specifies the number of output units.

Compiles to Axon.Layers.dense/4.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.
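A worked instance of the formula: with the kernel initialized to ones and the bias to zeros (illustrative choices, as is the layer name "proj"), each output unit is the sum of the input features, so [1, 2, 3] maps to [6, 6], which relu leaves unchanged.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("x", shape: {nil, 3})
  |> Axon.dense(2,
    name: "proj",
    activation: :relu,
    kernel_initializer: :ones,
    bias_initializer: :zeros
  )

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3}, :f32), %{})

# dot([1, 2, 3], ones({3, 2})) + zeros({2}) = [6, 6]
out = predict_fn.(params, Nx.tensor([[1.0, 2.0, 3.0]]))
IO.inspect(out)
```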

embedding(x, vocab_size, embedding_size, opts \\ [])

Adds an embedding layer to the network.

An embedding layer initializes a kernel of shape {vocab_size, embedding_size} which acts as a lookup table for sequences of discrete tokens (e.g. sentences). Embeddings are typically used to obtain a dense representation of a sparse input space.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :uniform.
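A sketch showing the lookup-table shapes: token ids of shape {batch, seq} become vectors of shape {batch, seq, embedding_size}. The vocabulary size 100, the embedding size 8, and the name "embed" are arbitrary assumptions.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("tokens", shape: {nil, 5})
  |> Axon.embedding(100, 8, name: "embed")

{init_fn, predict_fn} = Axon.build(model)

# Token ids are integers, so the template uses an integer type.
params = init_fn.(Nx.template({2, 5}, :s64), %{})

tokens = Nx.tensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
out = predict_fn.(params, tokens)

IO.inspect(Nx.shape(out))
```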

Layers: Convolution

conv(x, units, opts \\ [])

Adds a convolution layer to the network.

The convolution layer implements a general N-dimensional convolution, which convolves a kernel over the input to produce an output.

Compiles to Axon.Layers.conv/4.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride during convolution. Defaults to 1.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :input_dilation - dilation to apply to input. Defaults to 1.

  • :kernel_dilation - dilation to apply to kernel. Defaults to 1.

  • :feature_group_size - feature group size for convolution. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.
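
The effect of :padding on the output shape can be checked with get_output_shape/3 without initializing any parameters. A sketch assuming Axon 0.6.1 via Mix.install and channels-last 28x28 grayscale input (sizes are illustrative). With a kernel size of 3 and stride 1, :same keeps the 28x28 spatial size while the default :valid shrinks it to 26x26 (28 - 3 + 1):

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

# 28x28 single-channel images with channels last (the default).
input = Axon.input("images", shape: {nil, 28, 28, 1})

# 32 output channels with a 3x3 kernel and stride 1.
same = Axon.conv(input, 32, kernel_size: 3, padding: :same, activation: :relu)
valid = Axon.conv(input, 32, kernel_size: 3)

template = Nx.template({1, 28, 28, 1}, :f32)
IO.inspect(Axon.get_output_shape(same, template))   # {1, 28, 28, 32}
IO.inspect(Axon.get_output_shape(valid, template))  # {1, 26, 26, 32}
```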

conv_transpose(x, units, opts \\ [])

Adds a transposed convolution layer to the network.

The transposed convolution layer is sometimes referred to as a fractionally strided convolution or (incorrectly) as a deconvolution.

Compiles to Axon.Layers.conv_transpose/4.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride during convolution. Defaults to 1.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :kernel_dilation - dilation to apply to kernel. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.

depthwise_conv(x, channel_multiplier, opts \\ [])

Adds a depthwise convolution layer to the network.

The depthwise convolution layer implements a general N-dimensional depthwise convolution, which is a convolution where the feature group size is equal to the number of input channels.

The channel multiplier grows the input channels by the given factor. A channel_multiplier of 1 means the output channels are the same as the input channels.

Compiles to Axon.Layers.depthwise_conv/4.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride during convolution. Defaults to 1.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :input_dilation - dilation to apply to input. Defaults to 1.

  • :kernel_dilation - dilation to apply to kernel. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.
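
A shape-only sketch of the channel multiplier, assuming Axon 0.6.1 via Mix.install; the 3-channel 8x8 input is illustrative. With channel_multiplier set to 2, the 3 input channels become 3 * 2 = 6 output channels:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

input = Axon.input("images", shape: {nil, 8, 8, 3})

# Each of the 3 input channels is convolved separately and expanded 2x.
model = Axon.depthwise_conv(input, 2, kernel_size: 3, padding: :same)

shape = Axon.get_output_shape(model, Nx.template({1, 8, 8, 3}, :f32))
IO.inspect(shape)  # {1, 8, 8, 6}
```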

separable_conv2d(x, channel_multiplier, opts \\ [])

Adds a depthwise separable 2-dimensional convolution to the network.

Depthwise separable convolutions break the kernel into kernels for each dimension of the input and perform a depthwise convolution over the input with each kernel.

Compiles to Axon.Layers.separable_conv2d/6.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride during convolution. Defaults to 1.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :input_dilation - dilation to apply to input. Defaults to 1.

  • :kernel_dilation - dilation to apply to kernel. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.

separable_conv3d(x, channel_multiplier, opts \\ [])

Adds a depthwise separable 3-dimensional convolution to the network.

Depthwise separable convolutions break the kernel into kernels for each dimension of the input and perform a depthwise convolution over the input with each kernel.

Compiles to Axon.Layers.separable_conv3d/8.

Options

  • :name - layer name.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :activation - element-wise activation function.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride during convolution. Defaults to 1.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :input_dilation - dilation to apply to input. Defaults to 1.

  • :kernel_dilation - dilation to apply to kernel. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.

Layers: Dropout

alpha_dropout(x, opts \\ [])

Adds an Alpha dropout layer to the network.

See Axon.Layers.alpha_dropout/2 for more details.

Options

  • :name - layer name.

  • :rate - dropout rate. Defaults to 0.5. Needs to be equal or greater than zero and less than one.

dropout(x, opts \\ [])

Adds a Dropout layer to the network.

See Axon.Layers.dropout/2 for more details.

Options

  • :name - layer name.

  • :rate - dropout rate. Defaults to 0.5. Needs to be equal or greater than zero and less than one.
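
Dropout only drops activations when the model is built in training mode; in the default :inference mode the layer passes its input through unchanged. A sketch checking that, assuming Axon 0.6.1 via Mix.install (the input name "x" is illustrative):

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("x", shape: {nil, 4})
  |> Axon.dropout(rate: 0.5)

# build/2 defaults to mode: :inference.
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 4}, :f32), %{})

x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
out = predict_fn.(params, x)

# 1 if every element passed through unchanged.
unchanged = Nx.equal(out, x) |> Nx.all() |> Nx.to_number()
IO.inspect(unchanged)
```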

feature_alpha_dropout(x, opts \\ [])

Adds a Feature alpha dropout layer to the network.

See Axon.Layers.feature_alpha_dropout/2 for more details.

Options

  • :name - layer name.

  • :rate - dropout rate. Defaults to 0.5. Needs to be equal or greater than zero and less than one.

spatial_dropout(x, opts \\ [])

Adds a Spatial dropout layer to the network.

See Axon.Layers.spatial_dropout/2 for more details.

Options

  • :name - layer name.

  • :rate - dropout rate. Defaults to 0.5. Needs to be equal or greater than zero and less than one.

Layers: Pooling

adaptive_avg_pool(x, opts \\ [])

Adds an Adaptive average pool layer to the network.

See Axon.Layers.adaptive_avg_pool/2 for more details.

Options

  • :name - layer name.

  • :output_size - layer output size.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

adaptive_lp_pool(x, opts \\ [])

Adds an Adaptive power average pool layer to the network.

See Axon.Layers.adaptive_lp_pool/2 for more details.

Options

  • :name - layer name.

  • :output_size - layer output size.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

adaptive_max_pool(x, opts \\ [])

Adds an Adaptive max pool layer to the network.

See Axon.Layers.adaptive_max_pool/2 for more details.

Options

  • :name - layer name.

  • :output_size - layer output size.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

avg_pool(x, opts \\ [])

Adds an Average pool layer to the network.

See Axon.Layers.avg_pool/2 for more details.

Options

  • :name - layer name.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride of the pooling window. Defaults to the kernel size.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :dilations - window dilations. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.

global_avg_pool(x, opts \\ [])

Adds a Global average pool layer to the network.

See Axon.Layers.global_avg_pool/2 for more details.

Typically used to connect feature extractors such as those in convolutional neural networks to fully-connected models by reducing inputs along spatial dimensions to only feature and batch dimensions.

Options

  • :name - layer name.

  • :keep_axes - option to keep reduced axes. If true, keeps reduced axes with a dimension size of 1.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.
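
A shape-only sketch, assuming Axon 0.6.1 via Mix.install and an illustrative channels-last 7x7x64 feature map. The spatial axes are averaged away, leaving only batch and feature dimensions unless :keep_axes is set:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

features = Axon.input("features", shape: {nil, 7, 7, 64})

pooled = Axon.global_avg_pool(features)
pooled_kept = Axon.global_avg_pool(features, keep_axes: true)

template = Nx.template({1, 7, 7, 64}, :f32)
IO.inspect(Axon.get_output_shape(pooled, template))       # {1, 64}
IO.inspect(Axon.get_output_shape(pooled_kept, template))  # {1, 1, 1, 64}
```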

global_lp_pool(x, opts \\ [])

Adds a Global LP pool layer to the network.

See Axon.Layers.global_lp_pool/2 for more details.

Typically used to connect feature extractors such as those in convolutional neural networks to fully-connected models by reducing inputs along spatial dimensions to only feature and batch dimensions.

Options

  • :name - layer name.

  • :keep_axes - option to keep reduced axes. If true, keeps reduced axes with a dimension size of 1.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

global_max_pool(x, opts \\ [])

Adds a Global max pool layer to the network.

See Axon.Layers.global_max_pool/2 for more details.

Typically used to connect feature extractors such as those in convolutional neural networks to fully-connected models by reducing inputs along spatial dimensions to only feature and batch dimensions.

Options

  • :name - layer name.

  • :keep_axes - option to keep reduced axes. If true, keeps reduced axes with a dimension size of 1.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.

lp_pool(x, opts \\ [])

Adds a Power average pool layer to the network.

See Axon.Layers.lp_pool/2 for more details.

Options

  • :name - layer name.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride of the pooling window. Defaults to the kernel size.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :dilations - window dilations. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.

max_pool(x, opts \\ [])

Adds a Max pool layer to the network.

See Axon.Layers.max_pool/2 for more details.

Options

  • :name - layer name.

  • :kernel_size - size of the kernel spatial dimensions. Defaults to 1.

  • :strides - stride of the pooling window. Defaults to the kernel size.

  • :padding - padding to the spatial dimensions of the input. Defaults to :valid.

  • :dilations - window dilations. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.
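
A small numeric sketch, assuming Axon 0.6.1 via Mix.install. With kernel_size: 2 and the default strides (equal to the kernel size), a 4x4 single-channel input is split into four non-overlapping 2x2 windows, and each output value is the maximum of its window:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("x", shape: {nil, 4, 4, 1})
  |> Axon.max_pool(kernel_size: 2)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 4, 4, 1}, :f32), %{})

# Values 0.0..15.0 laid out row by row in a 4x4 grid.
x = Nx.iota({1, 4, 4, 1}, type: :f32)
out = predict_fn.(params, x)

# Output shape is {1, 2, 2, 1}; drop the batch and channel axes to read it.
IO.inspect(out |> Nx.reshape({2, 2}) |> Nx.to_list())  # [[5.0, 7.0], [13.0, 15.0]]
```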

Layers: Normalization

batch_norm(x, opts \\ [])

Adds a Batch normalization layer to the network.

See Axon.Layers.batch_norm/6 for more details.

Options

  • :name - layer name.

  • :gamma_initializer - gamma parameter initializer. Defaults to :glorot_uniform.

  • :beta_initializer - beta parameter initializer. Defaults to :zeros.

  • :channel_index - input feature index used for calculating mean and variance. Defaults to -1.

  • :epsilon - numerical stability term. Defaults to 1.0e-5.

group_norm(x, num_groups, opts \\ [])

Adds a group normalization layer to the network.

See Axon.Layers.group_norm/4 for more details.

Options

  • :name - layer name.

  • :gamma_initializer - gamma parameter initializer. Defaults to :glorot_uniform.

  • :beta_initializer - beta parameter initializer. Defaults to :zeros.

  • :channel_index - input feature index used for calculating mean and variance. Defaults to -1.

  • :epsilon - numerical stability term.

instance_norm(x, opts \\ [])

Adds an Instance normalization layer to the network.

See Axon.Layers.instance_norm/6 for more details.

Options

  • :name - layer name.

  • :gamma_initializer - gamma parameter initializer. Defaults to :glorot_uniform.

  • :beta_initializer - beta parameter initializer. Defaults to :zeros.

  • :channel_index - input feature index used for calculating mean and variance. Defaults to -1.

  • :epsilon - numerical stability term. Defaults to 1.0e-5.

layer_norm(x, opts \\ [])

Adds a Layer normalization layer to the network.

See Axon.Layers.layer_norm/4 for more details.

Options

  • :name - layer name.

  • :gamma_initializer - gamma parameter initializer. Defaults to :glorot_uniform.

  • :beta_initializer - beta parameter initializer. Defaults to :zeros.

  • :channel_index - input feature index used for calculating mean and variance. Defaults to -1.

  • :epsilon - numerical stability term.
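
A numeric sketch of what layer normalization does, assuming Axon 0.6.1 via Mix.install. Here :gamma_initializer is set to :ones (and beta keeps its :zeros default) so the learned scale and shift start as the identity; each row of the output should then have a mean close to 0 and a variance close to 1:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

# Identity scale (gamma = 1) and zero shift (beta = 0, the default).
model =
  Axon.input("x", shape: {nil, 4})
  |> Axon.layer_norm(gamma_initializer: :ones)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 4}, :f32), %{})

x = Nx.tensor([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
out = predict_fn.(params, x)

# Each row is normalized independently along the last axis.
row_means = Nx.mean(out, axes: [-1])
row_vars = Nx.variance(out, axes: [-1])
IO.inspect(row_means)
IO.inspect(row_vars)
```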

Layers: Recurrent

conv_lstm(x, units)

See conv_lstm/3.

conv_lstm(x, units, opts)

Adds a convolutional long short-term memory (LSTM) layer to the network with a random initial hidden state.

See conv_lstm/4 for more details.

Additional options

  • :recurrent_initializer - initializer for hidden state. Defaults to :orthogonal.

conv_lstm(x, hidden_state, units, opts)

Adds a convolutional long short-term memory (LSTM) layer to the network with the given initial hidden state.

ConvLSTMs apply Axon.Layers.conv_lstm_cell/5 over an entire input sequence and return:

{{new_cell, new_hidden}, output_sequence}

You can use the output state as the hidden state of another ConvLSTM layer.

Options

  • :name - layer name.

  • :padding - convolutional padding. Defaults to :same.

  • :kernel_size - convolutional kernel size. Defaults to 1.

  • :strides - convolutional strides. Defaults to 1.

  • :unroll - :dynamic (loop preserving) or :static (compiled) unrolling of the RNN.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

gru(x, units)

See gru/3.

gru(x, units, opts)

Adds a gated recurrent unit (GRU) layer to the network with a random initial hidden state.

See gru/4 for more details.

Additional options

  • :recurrent_initializer - initializer for hidden state. Defaults to :orthogonal.

gru(x, hidden_state, units, opts)

Adds a gated recurrent unit (GRU) layer to the network with the given initial hidden state.

GRUs apply Axon.Layers.gru_cell/7 over an entire input sequence and return:

{{new_hidden}, output_sequence}

You can use the output state as the hidden state of another GRU layer.

Options

  • :name - layer name.

  • :activation - recurrent activation. Defaults to :tanh.

  • :gate - recurrent gate function. Defaults to :sigmoid.

  • :unroll - :dynamic (loop preserving) or :static (compiled) unrolling of the RNN.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.

lstm(x, units)

See lstm/3.

lstm(x, units, opts)

Adds a long short-term memory (LSTM) layer to the network with a random initial hidden state.

See lstm/4 for more details.

Additional options

  • :recurrent_initializer - initializer for hidden state. Defaults to :orthogonal.

lstm(x, hidden_state, units, opts \\ [])

Adds a long short-term memory (LSTM) layer to the network with the given initial hidden state.

LSTMs apply Axon.Layers.lstm_cell/7 over an entire input sequence and return:

{output_sequence, {new_cell, new_hidden}}

You can use the output state as the hidden state of another LSTM layer.

Options

  • :name - layer name.

  • :activation - recurrent activation. Defaults to :tanh.

  • :gate - recurrent gate function. Defaults to :sigmoid.

  • :unroll - :dynamic (loop preserving) or :static (compiled) unrolling of the RNN.

  • :kernel_initializer - initializer for kernel weights. Defaults to :glorot_uniform.

  • :bias_initializer - initializer for bias weights. Defaults to :zeros.

  • :use_bias - whether the layer should add bias to the output. Defaults to true.
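
A shape sketch of chaining two LSTMs, assuming Axon 0.6.1 via Mix.install and the {output_sequence, {new_cell, new_hidden}} return value described above; the sequence length, feature size, and units are illustrative. The final state of the first LSTM seeds the second, which is why both use the same number of units:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

# Sequences of 10 steps with 8 features per step.
input = Axon.input("sequence", shape: {nil, 10, 8})

# The first LSTM starts from a random initial hidden state.
{output_sequence, {cell, hidden}} = Axon.lstm(input, 16)

# The second LSTM is seeded with the first one's final state; opts are
# passed explicitly so the call goes to the 4-arity version.
{second_sequence, _state} = Axon.lstm(output_sequence, {cell, hidden}, 16, [])

template = Nx.template({1, 10, 8}, :f32)
IO.inspect(Axon.get_output_shape(output_sequence, template))  # {1, 10, 16}
IO.inspect(Axon.get_output_shape(second_sequence, template))  # {1, 10, 16}
```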

mask(input, eos_token, opts \\ [])

Computes a sequence mask according to the given EOS token.

Masks can be propagated to recurrent layers or custom layers to indicate that a given token should be ignored in processing. This is useful when you have sequences of variable length.

Most commonly, eos_token is 0.

Options

  • :name - layer name.

Layers: Combinators

Adds an add layer to the network.

This layer performs an element-wise add operation on input layers. All input layers must be capable of being broadcast together.

If one shape has a static batch size, all other shapes must have a static batch size as well.

Options

  • :name - layer name.
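
A sketch with two named inputs, assuming Axon 0.6.1 via Mix.install; the input names "a" and "b" are illustrative. Multi-input models take a map from input name to template or tensor, both for initialization and for prediction:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

a = Axon.input("a", shape: {nil, 3})
b = Axon.input("b", shape: {nil, 3})
model = Axon.add(a, b)

{init_fn, predict_fn} = Axon.build(model)

templates = %{"a" => Nx.template({1, 3}, :f32), "b" => Nx.template({1, 3}, :f32)}
params = init_fn.(templates, %{})

inputs = %{"a" => Nx.tensor([[1.0, 2.0, 3.0]]), "b" => Nx.tensor([[10.0, 20.0, 30.0]])}
sum = predict_fn.(params, inputs)
IO.inspect(Nx.to_list(sum))  # [[11.0, 22.0, 33.0]]
```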

Adds a concatenate layer to the network.

This layer concatenates inputs along the last axis unless a different :axis is given.

Options

  • :name - layer name.

  • :axis - concatenate axis. Defaults to -1.
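
A shape-only sketch, assuming Axon 0.6.1 via Mix.install with illustrative input names. Concatenating along the default last axis adds the feature sizes, 3 + 2 = 5:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

a = Axon.input("a", shape: {nil, 3})
b = Axon.input("b", shape: {nil, 2})
model = Axon.concatenate(a, b)

templates = %{"a" => Nx.template({1, 3}, :f32), "b" => Nx.template({1, 2}, :f32)}
IO.inspect(Axon.get_output_shape(model, templates))  # {1, 5}
```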

cond(parent, cond_fn, true_graph, false_graph, opts \\ [])

Adds a conditional layer which conditionally executes true_graph or false_graph based on the condition cond_fn at runtime.

cond_fn is an arity-1 function executed on the output of the parent graph. It must return a boolean scalar tensor (e.g. 1 or 0).

The shapes of true_graph and false_graph must be equal.

Adds a multiply layer to the network.

This layer performs an element-wise multiply operation on input layers. All input layers must be capable of being broadcast together.

If one shape has a static batch size, all other shapes must have a static batch size as well.

Options

  • :name - layer name.

split(parent, splits, opts \\ [])

Splits input graph into a container of n input graphs along the given axis.

Options

  • :name - layer name.

  • :axis - split axis. Defaults to -1.

Adds a subtract layer to the network.

This layer performs an element-wise subtract operation on input layers. All input layers must be capable of being broadcast together.

If one shape has a static batch size, all other shapes must have a static batch size as well.

Options

  • :name - layer name.

Layers: Shape

flatten(x, opts \\ [])

Adds a flatten layer to the network.

This layer flattens all but the batch dimension of the input into a single dimension. It is typically used to flatten the output of a convolution for use with a dense layer.

Options

  • :name - layer name.
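
A sketch of the typical convolution-to-dense hand-off, assuming Axon 0.6.1 via Mix.install and an illustrative 7x7x64 feature map. Flattening keeps the batch dimension and merges the rest into 7 * 7 * 64 = 3136 features:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

features = Axon.input("features", shape: {nil, 7, 7, 64})
flat = Axon.flatten(features)
classifier = Axon.dense(flat, 10, activation: :softmax)

template = Nx.template({1, 7, 7, 64}, :f32)
IO.inspect(Axon.get_output_shape(flat, template))        # {1, 3136}
IO.inspect(Axon.get_output_shape(classifier, template))  # {1, 10}
```
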

pad(x, config, value \\ 0.0, opts \\ [])

Adds a pad layer to the network.

This layer will pad the spatial dimensions of the input. Padding configuration is a list of tuples for each spatial dimension.

Options

  • :name - layer name.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.
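
A shape-only sketch, assuming Axon 0.6.1 via Mix.install and assuming each tuple in the configuration is a {before, after} pair of padding sizes for one spatial dimension (height first, then width, for channels-last input). The sizes are illustrative:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

input = Axon.input("images", shape: {nil, 4, 4, 3})

# Assumed layout: one {before, after} pair per spatial dimension.
# Height gets 1 row on each side, width gets 2 columns on each side.
padded = Axon.pad(input, [{1, 1}, {2, 2}])

shape = Axon.get_output_shape(padded, Nx.template({1, 4, 4, 3}, :f32))
IO.inspect(shape)
```

Under that assumption the height grows from 4 to 6 and the width from 4 to 8, while batch and channels are untouched.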

reshape(x, new_shape, opts \\ [])

Adds a reshape layer to the network.

This layer implements a special case of Nx.reshape which accounts for possible batch dimensions in the input tensor. You may pass the magic dimension :batch as a placeholder for dynamic batch sizes. You can use :batch seamlessly with :auto dimension sizes.

If the input is an Axon constant, the reshape behavior matches that of Nx.reshape/2.

Options

  • :name - layer name.
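
A shape-only sketch, assuming Axon 0.6.1 via Mix.install; the sizes are illustrative. :batch is replaced with the runtime batch size (3 in this template), and :auto lets Nx infer the remaining dimension:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

input = Axon.input("x", shape: {nil, 8})

# :batch stands in for the dynamic batch size; :auto is inferred by Nx.
explicit = Axon.reshape(input, {:batch, 2, 4})
inferred = Axon.reshape(input, {:batch, :auto, 4})

template = Nx.template({3, 8}, :f32)
IO.inspect(Axon.get_output_shape(explicit, template))
IO.inspect(Axon.get_output_shape(inferred, template))
```
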

resize(x, resize_shape, opts \\ [])

Adds a resize layer to the network.

Resizing can be used for interpolation or upsampling input values in a neural network. For example, you can use this layer as an upsampling layer within a GAN.

Resize shape must be a tuple representing the resized spatial dimensions of the input tensor.

Compiles to Axon.Layers.resize/2.

Options

  • :name - layer name.

  • :method - resize method. Defaults to :nearest.

  • :antialias - whether an anti-aliasing filter should be used when downsampling. Defaults to true.

  • :channels - channel configuration. One of :first or :last. Defaults to :last.
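
A shape-only upsampling sketch, assuming Axon 0.6.1 via Mix.install and an illustrative 28x28 RGB input. The resize shape lists only the spatial dimensions, so batch and channels are preserved:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

input = Axon.input("images", shape: {nil, 28, 28, 3})

# Double both spatial dimensions using the default :nearest method.
upsampled = Axon.resize(input, {56, 56})

shape = Axon.get_output_shape(upsampled, Nx.template({1, 28, 28, 3}, :f32))
IO.inspect(shape)  # {1, 56, 56, 3}
```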

transpose(x, permutation \\ nil, opts \\ [])

Adds a transpose layer to the network.

Options

  • :name - layer name.

Model

build(model, opts \\ [])

Builds the given model to {init_fn, predict_fn}.

The returned functions can either be passed as arguments to Nx.Defn functions or be invoked directly to perform just-in-time compilation and execution. To compile the model ahead of time (instead of just-in-time) for a predefined input shape, see compile/4.

init_fn

The init_fn receives two arguments, the input template and an optional map with initial parameters for layers or namespaces:

{init_fn, predict_fn} = Axon.build(model)
init_fn.(Nx.template({1, 1}, {:f, 32}), %{"dense_0" => dense_params})

predict_fn

The predict_fn receives two arguments, the trained parameters and the actual inputs:

{_init_fn, predict_fn} = Axon.build(model, opts)
predict_fn.(params, input)

Options

  • :compiler - the underlying Nx.Defn compiler to perform JIT compilation when the functions are invoked. If none is passed, it uses the default compiler configured in Nx.Defn.

  • :debug - if true, will log graph traversal and generation metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to false.

  • :mode - one of :inference or :train. Forwarded to layers to control differences in compilation at training or inference time. Defaults to :inference.

  • :global_layer_options - a keyword list of options passed to layers that accept them.

All other options are forwarded to the underlying JIT compiler.

compile(model, template, init_params \\ %{}, opts \\ [])

Compiles the given model to {init_fn, predict_fn}.

This function will compile a model specialized to the given input shapes and types. This is useful for avoiding the overhead of long compilations at program runtime. You must provide template inputs which match the expected shapes and types of inputs at execution time.

This function makes use of the built-in Nx.Defn.compile/3. Note that passing inputs which differ in shape or type from the templates provided to this function will result in a crash.

Options

It accepts the same options as build/2.

deserialize(serialized, opts \\ [])

Deserializes serialized model and parameters into a {model, params} tuple.

It is the opposite of Axon.serialize/3.

Examples

iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
iex> {init_fn, _} = Axon.build(model)
iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
iex> serialized = Axon.serialize(model, params)
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
iex> {_, predict_fn} = Axon.build(saved_model)
iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
#Nx.Tensor<
  f32[1][1]
  [
    [0.0]
  ]
>

freeze(model, fun_or_predicate \\ :all)

Freezes parameters returned from the given function or predicate.

fun_or_predicate can be :all, up: n, or down: n. :all freezes all parameters in the model, up: n freezes the first n layers up (starting from the output), and down: n freezes the first n layers down (starting from the input).

fun_or_predicate may also be a predicate function which takes a parameter and returns true if the parameter should be frozen or false otherwise.

Freezing parameters is useful in transfer learning, where features learned on one problem are reused for a new problem. For example, it's common to combine the convolutional base of a larger model trained on ImageNet with a fresh fully-connected classifier. The combined model is then trained on new data with the convolutional base frozen, so that the features it has already learned are not lost. In code:

cnn_base = get_pretrained_cnn_base()
model =
  cnn_base
  |> Axon.freeze()
  |> Axon.flatten()
  |> Axon.dense(1024, activation: :relu)
  |> Axon.dropout()
  |> Axon.dense(1000, activation: :softmax)

model
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.005))
|> Axon.Loop.run(data, %{}, epochs: 10)

When compiled, frozen parameters are wrapped in Nx.Defn.Kernel.stop_grad/1, which zeros out the gradient with respect to the frozen parameter. Gradients of frozen parameters will return 0.0, meaning they won't be changed during the update process.

predict(model, params, input, opts \\ [])

Builds and runs the given Axon model with params and input.

This is equivalent to calling build/2 and then invoking the predict function.

Options

  • :mode - one of :inference or :train. Forwarded to layers to control differences in compilation at training or inference time. Defaults to :inference.

  • :debug - if true, will log graph traversal and generation metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to false.

All other options are forwarded to the default JIT compiler or backend.
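
A sketch comparing predict/4 with calling the function returned by build/2, assuming Axon 0.6.1 via Mix.install. The kernel is initialized to ones (and the bias keeps its :zeros default) so the expected output is easy to check by hand: 1.0 + 2.0 = 3.0.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("input", shape: {nil, 2})
  |> Axon.dense(1, kernel_initializer: :ones)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})
input = Nx.tensor([[1.0, 2.0]])

# One call that builds and runs the model...
via_predict = Axon.predict(model, params, input)
# ...versus invoking the predict function returned by build/2.
via_build = predict_fn.(params, input)

IO.inspect(Nx.to_list(via_predict))  # [[3.0]]
IO.inspect(Nx.to_list(via_build))    # [[3.0]]
```

predict/4 is convenient for one-off calls; when predicting repeatedly, building once and reusing predict_fn avoids rebuilding the model each time.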

serialize(axon, params, opts \\ [])

Serializes a model and its parameters for persisting models to disk or elsewhere.

Model and parameters are serialized as a tuple, where the model is converted to a recursive map to ensure compatibility with future Axon versions and the parameters are serialized using Nx.serialize/2. There is some additional metadata included such as current serialization version for compatibility.

Serialization opts are forwarded to Nx.serialize/2 and :erlang.term_to_binary/2 for controlling compression options.

Examples

iex> model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, kernel_initializer: :zeros, activation: :relu)
iex> {init_fn, _} = Axon.build(model)
iex> params = init_fn.(Nx.template({1, 2}, :f32), %{})
iex> serialized = Axon.serialize(model, params)
iex> {saved_model, saved_params} = Axon.deserialize(serialized)
iex> {_, predict_fn} = Axon.build(saved_model)
iex> predict_fn.(saved_params, Nx.tensor([[1.0, 1.0]]))
#Nx.Tensor<
  f32[1][1]
  [
    [0.0]
  ]
>
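
To persist a model to disk, the serialized binary can be written with File.write!/2 and read back with File.read!/1. A sketch assuming Axon 0.6.1 via Mix.install; the file name is hypothetical and placed in the system temporary directory:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1)
{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 2}, :f32), %{})

# Hypothetical file location; any writable path works.
path = Path.join(System.tmp_dir!(), "model.axon")
File.write!(path, Axon.serialize(model, params))

{loaded_model, loaded_params} = path |> File.read!() |> Axon.deserialize()

# The restored model should reproduce the original predictions.
input = Nx.tensor([[1.0, 2.0]])
original = Axon.predict(model, params, input)
restored = Axon.predict(loaded_model, loaded_params, input)
IO.inspect(Nx.equal(original, restored) |> Nx.all() |> Nx.to_number())
```
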

unfreeze(model, fun_or_predicate \\ :all)

Unfreezes parameters returned from the given function or predicate.

fun_or_predicate can be :all, up: n, or down: n. :all unfreezes all parameters in the model, up: n unfreezes the first n layers up (starting from the output), and down: n unfreezes the first n layers down (starting from the input).

fun_or_predicate may also be a predicate function which takes a parameter and returns true if the parameter should be unfrozen or false otherwise.

Unfreezing parameters is useful when fine-tuning a model which you have previously frozen and performed transfer learning on. You may want to unfreeze some of the later frozen layers in a model and fine-tune them specifically for your application:

# frozen_model is a previously frozen model, such as the one built in the freeze/2 example
model =
  frozen_model
  |> Axon.unfreeze(up: 25)

model
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.0005))
|> Axon.Loop.run(data, %{}, epochs: 10)

When compiled, frozen parameters are wrapped in Nx.Defn.Kernel.stop_grad/1, which zeros out the gradient with respect to the frozen parameter. Gradients of frozen parameters will return 0.0, meaning they won't be changed during the update process.

Model: Manipulation

get_inputs(axon)

Returns information about a model's inputs.

get_op_counts(axon)

Returns a map of model op counts for each unique operation in a model by their given :op_name.

Examples

iex> model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
iex> Axon.get_op_counts(model)
%{input: 1, dense: 1}

iex> model = Axon.input("input", shape: {nil, 1}) |> Axon.tanh() |> Axon.tanh()
iex> Axon.get_op_counts(model)
%{input: 1, tanh: 2}

get_options(axon)

Returns a node's immediate input options.

Note that this does not take into account options of parent layers, only the options which belong to the immediate layer.

get_output_shape(axon, inputs, opts \\ [])

Returns a model's output shape from the given input template.
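
A sketch assuming Axon 0.6.1 via Mix.install, using an illustrative MNIST-style classifier. No parameters are needed, and the template's batch size flows through to the output shape:

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

# Only the input template's shape and type are needed.
IO.inspect(Axon.get_output_shape(model, Nx.template({1, 784}, :f32)))   # {1, 10}
IO.inspect(Axon.get_output_shape(model, Nx.template({32, 784}, :f32)))  # {32, 10}
```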

get_parameters(axon)

Returns a node's immediate parameters.

Note this does not take into account parameters of parent layers - only the parameters which belong to the immediate layer.

map_nodes(axon, fun)

Traverses graph nodes in order, applying fun to each node exactly once to return a transformed node in its place(s) in the graph.

This function maintains an internal cache which ensures each node is only visited and transformed exactly once.

fun must accept an Axon node and return an Axon node.

Please note that modifying node lineage (e.g. altering a node's parent) will result in disconnected graphs.

Examples

One common use of this function is to add instrumentation between layers without building a new, explicitly instrumented version of a model. For example, you can use it to visualize the intermediate activations of every convolutional layer in a model:

instrumented_model = Axon.map_nodes(model, fn
  %Axon.Node{op: :conv} = axon_node ->
    Axon.attach_hook(axon_node, &visualize_activations/1)

  axon_node ->
    axon_node
end)

Another use case is to replace an entire class of layers with another. For example, you may want to replace all relu layers with tanh layers:

new_model = Axon.map_nodes(model, fn
  %Axon{op: :relu} = graph ->
    # Get the node's immediate parent
    parent = Axon.get_parent(graph)
    # Replace node with a tanh
    Axon.tanh(parent)

  graph ->
    graph
end)

pop_node(axon)

Pops the top node off of the graph.

This returns the popped node and the updated graph:

{_node, model} = Axon.pop_node(model)

reduce_nodes(axon, acc, fun)

Traverses graph nodes in order, applying fun to each node exactly once to build up an accumulated value.

This function maintains an internal cache which ensures each node is only visited exactly once.

fun must accept an Axon node and accumulator and return an updated accumulator.

Examples

Internally this function is used in several places to accumulate graph metadata. For example, you can use it to count the number of a certain type of operation in the graph:

Axon.reduce_nodes(model, 0, fn
  %Axon.Node{op: :relu}, acc -> acc + 1
  _, acc -> acc
end)

set_options(axon, new_opts)

Sets a node's immediate options to the given input options.

Note that this does not take into account options of parent layers, only the options which belong to the immediate layer.

New options must be compatible with the given layer op. Adding unsupported options to an Axon layer will result in an error at graph execution time.

set_parameters(axon, new_params)

Sets a node's immediate parameters to the given parameters.

Note this does not take into account parameters of parent layers - only the parameters which belong to the immediate layer.

The new parameters must be compatible with the layer's old parameters.

Model: Debugging

attach_hook(x, fun, opts \\ [])

Attaches a hook to the given Axon model.

Hooks compile down to Nx.Defn.Kernel.hook/3 and provide the same functionality for adding side-effecting operations to a compiled model. For example, you can use hooks to inspect intermediate activations, send data to an external service, and more.

Hooks can be configured to be invoked on the following events:

  • :initialize - on model initialization.
  • :pre_forward - before layer forward pass is invoked.
  • :forward - after layer forward pass is invoked.
  • :backward - after layer backward pass is invoked.

To invoke a hook on every event, pass :all as the :on option.

Axon.input("input", shape: {nil, 1}) |> Axon.attach_hook(&IO.inspect/1, on: :all)

The default event is :forward, which invokes the hook after the layer's forward pass.

You may configure a hook to run only in training mode or only in inference mode using the :mode option (:train or :inference). The default mode is :both, which invokes the hook in both training and inference mode.

Axon.input("input", shape: {nil, 1}) |> Axon.attach_hook(&IO.inspect/1, on: :forward, mode: :train)

You can also attach multiple hooks to a single layer. Hooks are invoked in the order in which they are declared. If order is important, you should attach hooks in the order you want them to be executed:

Axon.input("input", shape: {nil, 1})
# I will be executed first
|> Axon.attach_hook(&IO.inspect/1)
# I will be executed second
|> Axon.attach_hook(fn _ -> IO.write("HERE") end)

Hooks are executed at their point of attachment. You must insert hooks at each point you want a hook to execute during model execution.

Axon.input("input", shape: {nil, 1})
|> Axon.attach_hook(&IO.inspect/1)
|> Axon.relu()
|> Axon.attach_hook(&IO.inspect/1)
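To see a hook fire, the model must actually be executed. The sketch below attaches a labelled inspection hook to a dense layer for inference mode only. It then builds the model with Axon.build/1 and runs one forward pass, which should print the dense activations at the hook's point of attachment. The input shape, the "dense output" label and the Axon ~> 0.6.1 pin are illustrative choices.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("input", shape: {nil, 4})
  |> Axon.dense(8)
  # Print the dense layer's output on each inference-mode forward pass
  |> Axon.attach_hook(fn x -> IO.inspect(x, label: "dense output") end,
    on: :forward,
    mode: :inference
  )
  |> Axon.relu()

# Axon.build/1 defaults to inference mode, so the hook above is active
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 4}, :f32), %{})

# Running the forward pass triggers the hook before relu is applied
output = predict_fn.(params, Nx.iota({1, 4}, type: :f32))
```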

trace_backward(model, inputs, params, loss, opts \\ [])

Compiles and returns the given model's backward function expression with respect to the given loss function.

The returned expression is an Nx expression which can be traversed and lowered to an IR or inspected for debugging purposes.

The given loss function must be a scalar loss function which expects inputs and targets with the same shapes as the model's output shapes as determined by the model's signature.

Options

  • :debug - if true, will log graph traversal and generation metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to false
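The following is a minimal sketch of tracing a backward pass, under stated assumptions. It uses mean squared error from Axon.Losses with :mean reduction so that the loss is a scalar. It assumes trace_backward calls the loss with (targets, predictions) and returns a gradient expression shaped like the parameter map. Treat both the argument order and the exact return structure as unverified. It pins Axon ~> 0.6.1 through Mix.install.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model = Axon.input("input", shape: {nil, 4}) |> Axon.dense(1)

{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 4}, :f32), %{})
inputs = Nx.iota({1, 4}, type: :f32)

# A scalar loss: :mean reduction collapses the per-example error to one value
loss = &Axon.Losses.mean_squared_error(&1, &2, reduction: :mean)

# Nx expression for the gradients of the loss with respect to params
grad_expr = Axon.trace_backward(model, inputs, params, loss)
IO.inspect(grad_expr)
```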

trace_forward(model, inputs, params, opts \\ [])

Compiles and returns the given model's forward function expression with the given options.

The returned expression is an Nx expression which can be traversed and lowered to an IR or inspected for debugging purposes.

Options

  • :mode - one of :inference or :train. Forwarded to layers to control differences in compilation at training or inference time. Defaults to :inference

  • :debug - if true, will log graph traversal and generation metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to false
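As a sketch, the next example traces the inference-mode forward pass of a one-layer model and prints the resulting expression instead of computing a numeric result. It assumes that in inference mode the traced output is a single tensor expression with the model's output shape. The Axon ~> 0.6.1 pin is an illustrative choice.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model =
  Axon.input("input", shape: {nil, 4})
  |> Axon.dense(2, activation: :relu)

{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 4}, :f32), %{})

# An Nx expression describing the forward computation, useful for debugging
expr = Axon.trace_forward(model, Nx.iota({1, 4}, type: :f32), params, mode: :inference)
IO.inspect(expr)
```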

trace_init(model, template, params \\ %{}, opts \\ [])

Compiles and returns the given model's init function expression with the given options.

The returned expression is an Nx expression which can be traversed and lowered to an IR or inspected for debugging purposes.

You may optionally specify initial parameters for some layers or namespaces by passing a partial parameter map:

Axon.trace_init(model, Nx.template({1, 784}, :f32), %{"dense_0" => dense_params})

The parameter map will be merged with the initialized model parameters.

Options

  • :debug - if true, will log graph traversal and generation metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to false
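In the simplest case, only an input template is needed. The sketch below traces the init expression for a one-layer model without any partial parameters. It assumes that, as in Axon 0.6, the result is a map keyed by layer name, with "dense_0" as the default name of the first dense layer. The Axon ~> 0.6.1 pin is an illustrative choice.

```elixir
Mix.install([{:axon, "~> 0.6.1"}])

model = Axon.input("input", shape: {nil, 4}) |> Axon.dense(2)

# Trace parameter initialization from an input template alone
init_expr = Axon.trace_init(model, Nx.template({1, 4}, :f32))
IO.inspect(init_expr)
```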

Types

@type t() :: %Axon{nodes: term(), output: term()}

Functions

bidirectional(input, forward_fun, merge_fun, opts \\ [])

Applies the given forward function bidirectionally and merges the results with the given merge function.

This is most commonly used with RNNs to capture the dependencies of a sequence in both directions.

Options

  • :axis - axis to reverse.

blur_pool(x, opts \\ [])

Adds a blur pooling layer to the network.

See Axon.Layers.blur_pool/2 for more details.

Options

  • :name - layer name.

  • :strides - stride during convolution. Defaults to 1.

  • :channels - channels location. One of :first or :last. Defaults to :last.