mlx.nn.glu
mlx.nn.glu(x: array, axis: int = -1) -> array
Applies the gated linear unit function.
This function splits the axis dimension of the input into two halves (\(a\) and \(b\)) and applies \(a * \sigma(b)\).

\[\textrm{GLU}(x) = a * \sigma(b)\]

Parameters:
- axis (int) – The dimension to split along. Default: -1
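
The computation described above can be sketched in plain Python for a one-dimensional input. This is an illustrative reference only, not the MLX implementation: the helper names sigmoid and glu_last_axis are hypothetical, and the sketch assumes that \(a\) is the first half of the split and \(b\) the second, and that the split dimension has even length.

```python
import math


def sigmoid(v: float) -> float:
    # Logistic function: sigma(v) = 1 / (1 + exp(-v))
    return 1.0 / (1.0 + math.exp(-v))


def glu_last_axis(x: list[float]) -> list[float]:
    # Hypothetical helper: GLU over a flat list, standing in for the split axis.
    if len(x) % 2 != 0:
        raise ValueError("the split dimension must have even length")
    half = len(x) // 2
    # Assumption: a is the first half, b is the second half.
    a, b = x[:half], x[half:]
    # Gate each element of a by the sigmoid of the matching element of b.
    return [ai * sigmoid(bi) for ai, bi in zip(a, b)]


out = glu_last_axis([1.0, 2.0, 0.0, 0.0])
# sigmoid(0.0) is 0.5, so out is [0.5, 1.0]
```

Because the input is split into two halves and only one gated half is returned, the output is half the size of the input along the split dimension.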