Layers

ALiBi()

Implements Attention with Linear Biases (ALiBi).

BatchNorm(num_features[, eps, momentum, ...])

Applies Batch Normalization over a 2D or 3D input.

Conv1d(in_channels, out_channels, kernel_size)

Applies a 1-dimensional convolution over the multi-channel input sequence.

Conv2d(in_channels, out_channels, kernel_size)

Applies a 2-dimensional convolution over the multi-channel input image.

Dropout([p])

Randomly zeroes a portion of the elements during training.

Dropout2d([p])

Applies 2D channel-wise dropout during training.

Dropout3d([p])

Applies 3D channel-wise dropout during training.

Embedding(num_embeddings, dims)

Implements a simple lookup table that maps each input integer to a high-dimensional vector.

GELU([approx])

Applies the Gaussian Error Linear Units (GELU) function.

GroupNorm(num_groups, dims[, eps, affine, ...])

Applies Group Normalization to the inputs.

InstanceNorm(dims[, eps, affine])

Applies instance normalization on the inputs.

LayerNorm(dims[, eps, affine])

Applies layer normalization on the inputs.

Linear(input_dims, output_dims[, bias])

Applies an affine transformation to the input.

Mish()

Applies the Mish function, element-wise.

MultiHeadAttention(dims, num_heads[, ...])

Implements the scaled dot product attention with multiple heads.

PReLU([num_parameters, init])

Applies the element-wise parametric ReLU.

QuantizedLinear(input_dims, output_dims[, ...])

Applies an affine transformation to the input using a quantized weight matrix.

RMSNorm(dims[, eps])

Applies Root Mean Square normalization to the inputs.

ReLU()

Applies the Rectified Linear Unit.

RoPE(dims[, traditional, base, scale])

Implements the rotary positional encoding.

SELU()

Applies the Scaled Exponential Linear Unit.

Sequential(*modules)

A layer that calls the passed callables in order.

SiLU()

Applies the Sigmoid Linear Unit.

SinusoidalPositionalEncoding(dims[, ...])

Implements sinusoidal positional encoding.

Softshrink([lambd])

Applies the Softshrink function.

Step([threshold])

Applies the Step Activation Function.

Transformer(dims, num_heads, ...)

Implements a standard Transformer model.