mlx/python/src/fast.cpp
Awni Hannun 1e16331d9c
post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
2024-03-24 15:03:27 -07:00

143 lines
4.8 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include "mlx/fast.h"
#include "mlx/ops.h"
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
// Registers the `mlx.core.fast` submodule on `parent_module` and binds the
// operations declared in "mlx/fast.h": rms_norm, layer_norm, rope and
// scaled_dot_product_attention.
//
// Every binding passes an explicit `nb::sig(...)`. That string replaces the
// signature nanobind would otherwise generate in docstrings and type stubs,
// so each keyword name in it must match the corresponding `"..."_a` argument
// name exactly. Otherwise the published stubs disagree with the real call
// signature.
void init_fast(nb::module_& parent_module) {
  auto m =
      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");

  m.def(
      "rms_norm",
      &fast::rms_norm,
      "x"_a,
      "weight"_a,
      "eps"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Root Mean Square normalization (RMS norm).

        The normalization is with respect to the last axis of the input ``x``.

        Args:
            x (array): Input array.
            weight (array): A multiplicative weight to scale the result by.
              The ``weight`` should be one-dimensional with the same size
              as the last axis of ``x``.
            eps (float): A small additive constant for numerical stability.

        Returns:
            array: The output array.
      )pbdoc");

  m.def(
      "layer_norm",
      &fast::layer_norm,
      "x"_a,
      // `.none()` lets Python callers pass None for these two arguments,
      // matching the Optional[array] types in the signature below.
      "weight"_a.none(),
      "bias"_a.none(),
      "eps"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Layer normalization.

        The normalization is with respect to the last axis of the input ``x``.

        Args:
            x (array): Input array.
            weight (array, optional): A multiplicative weight to scale the result by.
              The ``weight`` should be one-dimensional with the same size
              as the last axis of ``x``. If set to ``None`` then no scaling happens.
            bias (array, optional): An additive offset to be added to the result.
              The ``bias`` should be one-dimensional with the same size
              as the last axis of ``x``. If set to ``None`` then no translation happens.
            eps (float): A small additive constant for numerical stability.

        Returns:
            array: The output array.
      )pbdoc");

  m.def(
      "rope",
      &fast::rope,
      "a"_a,
      "dims"_a,
      nb::kw_only(),
      "traditional"_a,
      "base"_a,
      "scale"_a,
      "offset"_a,
      "stream"_a = nb::none(),
      // Keyword names here must match the "..."_a names above exactly
      // (e.g. `traditional`), or the generated stubs advertise a keyword
      // argument that the binding does not accept.
      nb::sig(
          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Apply rotary positional encoding to the input.

        Args:
            a (array): Input array.
            dims (int): The feature dimensions to be rotated. If the input feature
              dimension is larger than ``dims`` then the rest is left unchanged.
            traditional (bool): If set to ``True`` choose the traditional
              implementation which rotates consecutive dimensions.
            base (float): The base used to compute angular frequency for
              each dimension in the positional encodings.
            scale (float): The scale used to scale the positions.
            offset (int): The position offset to start at.

        Returns:
            array: The output array.
      )pbdoc");

  m.def(
      "scaled_dot_product_attention",
      &fast::scaled_dot_product_attention,
      "q"_a,
      "k"_a,
      "v"_a,
      nb::kw_only(),
      "scale"_a,
      "mask"_a = nb::none(),
      "stream"_a = nb::none(),
      nb::sig(
          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        A fast implementation of multi-head attention: ``O = softmax(scale * Q @ K.T, dim=-1) @ V``.

        Supports:

        * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
        * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
        * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_

        Note: The softmax operation is performed in ``float32`` regardless of
        the input precision.

        Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
        and ``v`` inputs should not be pre-tiled to match ``q``.

        Args:
            q (array): Input query array.
            k (array): Input keys array.
            v (array): Input values array.
            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape[-1])``).
            mask (array, optional): An additive mask to apply to the query-key scores.

        Returns:
            array: The output array.
      )pbdoc");
}