Fast RMS Norm (#862)

* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Do the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun
2024-03-21 07:20:54 -07:00
committed by GitHub
parent 4650d94d98
commit a54f06b16f
17 changed files with 493 additions and 41 deletions
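
The "squaring in float32 to avoid underflow" item in the commit message can be illustrated with a small NumPy sketch. This is not the kernel from this commit; the input values and variable names are hypothetical, chosen only to show why squaring small half-precision values loses information.

```python
import numpy as np

# Hypothetical input: small half-precision activations.
x = np.full(8, 1e-4, dtype=np.float16)

# Squaring in float16: 1e-8 is below the smallest float16 subnormal (~6e-8),
# so every square underflows to zero.
mean_sq_fp16 = np.mean(x * x)

# Upcasting to float32 before squaring keeps the squares representable.
x32 = x.astype(np.float32)
mean_sq_fp32 = np.mean(x32 * x32)

print(mean_sq_fp16, mean_sq_fp32)
```

With the float16 path the mean of squares collapses to zero, so the normalization would depend on `eps` alone; the float32 path retains the actual magnitude of the input.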


@@ -15,6 +15,37 @@ void init_fast(nb::module_& parent_module) {
  auto m =
      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");

  m.def(
      "rms_norm",
      [](const array& x,
         const array& weight,
         float eps,
         const StreamOrDevice& s /* = {} */) {
        return fast::rms_norm(x, weight, eps, s);
      },
      "x"_a,
      "weight"_a,
      "eps"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Root Mean Square normalization (RMS norm).

        The normalization is with respect to the last axis of the input ``x``.

        Args:
            x (array): Input array.
            weight (array): A multiplicative weight to scale the result by.
              The ``weight`` should be one-dimensional with the same size
              as the last axis of ``x``.
            eps (float): A small additive constant for numerical stability.

        Returns:
            array: The output array.
      )pbdoc");
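
The docstring above describes the operation only in words. A minimal NumPy reference may help when checking results, assuming the common RMS norm definition in which `eps` is added to the mean of squares inside the square root (the docstring does not spell this out). The function name `rms_norm_reference` is hypothetical and is not part of MLX.

```python
import numpy as np


def rms_norm_reference(x, weight, eps):
    """Hypothetical NumPy reference for RMS norm over the last axis.

    Assumes out = x / sqrt(mean(x**2, axis=-1) + eps) * weight.
    """
    if weight.ndim != 1 or weight.shape[0] != x.shape[-1]:
        raise ValueError("weight must be 1-D and match the last axis of x")
    # Accumulate in float32 so small half-precision squares do not underflow.
    x32 = x.astype(np.float32)
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    out = x32 / np.sqrt(mean_sq + eps) * weight.astype(np.float32)
    # Return the result in the input dtype.
    return out.astype(x.dtype)


x = np.array([[3.0, 4.0], [1.0, 1.0]], dtype=np.float32)
weight = np.ones(2, dtype=np.float32)
y = rms_norm_reference(x, weight, eps=1e-5)
```

With a unit `weight`, each row of `y` has a root mean square close to 1. Going by the signature bound above, the equivalent MLX call would take the form `mlx.core.fast.rms_norm(x, weight, eps)`, with `stream` as an optional keyword-only argument.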
  m.def(
      "rope",
      [](const array& a,