mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 03:18:12 +08:00
RoPE with frequencies as optional input (#1337)
* start rope with freq input * rope with frequencies * nits * fix bug * fix bug + test * cleanup * optional base
This commit is contained in:
@@ -79,26 +79,29 @@ void init_fast(nb::module_& parent_module) {
|
||||
"dims"_a,
|
||||
nb::kw_only(),
|
||||
"traditional"_a,
|
||||
"base"_a,
|
||||
"base"_a.none(),
|
||||
"scale"_a,
|
||||
"offset"_a,
|
||||
"freqs"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Apply rotary positional encoding to the input.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
dims (int): The feature dimensions to be rotated. If the input feature
|
||||
is larger than dims then the rest is left unchanged.
|
||||
is larger than dims then the rest is left unchanged.
|
||||
traditional (bool): If set to ``True`` choose the traditional
|
||||
implementation which rotates consecutive dimensions.
|
||||
base (float): The base used to compute angular frequency for
|
||||
each dimension in the positional encodings.
|
||||
implementation which rotates consecutive dimensions.
|
||||
base (float, optional): The base used to compute angular frequency for
|
||||
each dimension in the positional encodings. Exactly one of ``base`` and
|
||||
``freqs`` must be ``None``.
|
||||
scale (float): The scale used to scale the positions.
|
||||
offset (int): The position offset to start at.
|
||||
|
||||
freqs (array, optional): Optional frequencies to use with RoPE.
|
||||
If set, the ``base`` parameter must be ``None``. ``Default: None``.
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
@@ -115,7 +118,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
"memory_efficient_threshold"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user