Adds mx.fast.layer_norm (#870)

This commit is contained in:
Angelos Katharopoulos
2024-03-21 13:55:51 -07:00
committed by GitHub
parent 105d236889
commit 2225374060
11 changed files with 600 additions and 8 deletions

View File

@@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) {
array: The output array.
)pbdoc");
// Python binding for `layer_norm`, registered on `m`. The lambda is a thin
// adapter that forwards every argument unchanged to fast::layer_norm; the
// user-facing contract is described in the pbdoc string below.
m.def(
"layer_norm",
[](const array& x,
const std::optional<array>& weight,
const std::optional<array>& bias,
float eps,
const StreamOrDevice& s /* = {} */) {
return fast::layer_norm(x, weight, bias, eps, s);
},
// Argument annotations are matched to the lambda parameters by position, so
// they must stay in the same order as the parameter list above.
"x"_a,
// .none() lets Python callers pass None for the optional arrays.
"weight"_a.none(),
"bias"_a.none(),
// No default is supplied here, so `eps` has to be passed explicitly.
"eps"_a,
// Everything after this marker is keyword-only (the `*` in the signature).
nb::kw_only(),
"stream"_a = nb::none(),
// Explicit signature text exposed to Python; keep it in sync with the
// annotations above whenever an argument is added or changed.
nb::sig(
"def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
// Docstring attached to the Python function.
R"pbdoc(
Layer normalization.
The normalization is with respect to the last axis of the input ``x``.
Args:
x (array): Input array.
weight (array, optional): A multiplicative weight to scale the result by.
The ``weight`` should be one-dimensional with the same size
as the last axis of ``x``. If set to ``None`` then no scaling happens.
bias (array, optional): An additive offset to be added to the result.
The ``bias`` should be one-dimensional with the same size
as the last axis of ``x``. If set to ``None`` then no translation happens.
eps (float): A small additive constant for numerical stability.
Returns:
array: The output array.
)pbdoc");
m.def(
"rope",
[](const array& a,