mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Adds mx.fast.layer_norm (#870)
This commit is contained in:

committed by
GitHub

parent
105d236889
commit
2225374060
@@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) {
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"layer_norm",
|
||||
[](const array& x,
|
||||
const std::optional<array>& weight,
|
||||
const std::optional<array>& bias,
|
||||
float eps,
|
||||
const StreamOrDevice& s /* = {} */) {
|
||||
return fast::layer_norm(x, weight, bias, eps, s);
|
||||
},
|
||||
"x"_a,
|
||||
"weight"_a.none(),
|
||||
"bias"_a.none(),
|
||||
"eps"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Layer normalization.
|
||||
|
||||
The normalization is with respect to the last axis of the input ``x``.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
weight (array, optional): A multiplicative weight to scale the result by.
|
||||
The ``weight`` should be one-dimensional with the same size
|
||||
as the last axis of ``x``. If set to ``None`` then no scaling happens.
|
||||
bias (array, optional): An additive offset to be added to the result.
|
||||
The ``bias`` should be one-dimensional with the same size
|
||||
as the last axis of ``x``. If set to ``None`` then no translation happens.
|
||||
eps (float): A small additive constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"rope",
|
||||
[](const array& a,
|
||||
|
Reference in New Issue
Block a user