Custom primitive + RoPE fat op (#676)

* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun 2024-02-14 14:04:25 -08:00 committed by GitHub
parent 1a48713d32
commit ccf1645995
18 changed files with 624 additions and 70 deletions
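At a glance, the change adds an mx.fast submodule whose ops pair a fused GPU kernel with a pure-mlx fallback that is reused for function transformations. A minimal sketch of the new Python surface (shapes and parameter values here are illustrative, not taken from the diff):

# Hypothetical usage sketch of the op this commit introduces.
import mlx.core as mx

x = mx.random.uniform(shape=(1, 4, 8))  # (batch, sequence, dims)
# On a GPU stream, with dims covering the full feature axis, the fused
# Metal kernel runs; otherwise the composed-ops fallback is evaluated.
y = mx.fast.rope(x, 8, traditional=False, base=10000.0, scale=1.0, offset=0)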

benchmarks/python/rope_bench.py (new file)

@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(4096)

    # vec
    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()

mlx/CMakeLists.txt

@@ -3,9 +3,10 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp

mlx/backend/common/CMakeLists.txt

@@ -11,6 +11,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp

mlx/backend/common/rope.cpp (new file)

@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast

mlx/backend/metal/CMakeLists.txt

@@ -32,6 +32,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp

mlx/backend/metal/kernels/CMakeLists.txt

@@ -23,6 +23,7 @@ set(
   "quantized"
   "random"
   "reduce"
+  "rope"
   "scan"
   "softmax"
   "sort"

mlx/backend/metal/kernels/rope.metal (new file)

@@ -0,0 +1,68 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
[[kernel]] void rope(
    const device T *in [[buffer(0)]],
    device T * out [[buffer(1)]],
    constant const size_t strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute the input and output indices
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + 1;
    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + grid.x;
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }

  // Figure out L and d.
  float L = scale * static_cast<float>(pos.y + offset);
  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);

  // Compute costheta, sintheta
  float theta = L * metal::exp2(-d * base);
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1 = x1 * costheta - x2 * sintheta;
  float rx2 = x1 * sintheta + x2 * costheta;
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional) \
  template [[host_name("rope_" #name)]] \
  [[kernel]] void rope<type, traditional>( \
      const device type* in [[buffer(0)]], \
      device type* out [[buffer(1)]], \
      constant const size_t strides[3], \
      constant const int& offset, \
      constant const float& base, \
      constant const float& scale, \
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
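The two template branches encode the two RoPE layouts: with traditional set, consecutive pairs (x[..., 2i], x[..., 2i+1]) rotate together, while the default pairs the split halves (x[..., i], x[..., i + dims/2]). A small sketch of just the pairing, mirroring rope_orig in the tests below (the helper name is ours, not from the diff):

import mlx.core as mx

def rope_pairs(x, traditional):
    # traditional: interleave consecutive dims; default: split halves.
    if traditional:
        return x[..., ::2], x[..., 1::2]
    half = x.shape[-1] // 2
    return x[..., :half], x[..., half:]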

mlx/backend/metal/rope.cpp (new file)

@@ -0,0 +1,55 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  if (in.ndim() != 3) {
    throw std::runtime_error(
        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
  }
  if (dims_ != in.shape(-1)) {
    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
  }

  if (in.flags().row_contiguous && in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  std::ostringstream kname;
  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  bool donated = in.data_shared_ptr() == nullptr;
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, donated ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&offset_, sizeof(int), 3);
  compute_encoder->setBytes(&base, sizeof(float), 4);
  compute_encoder->setBytes(&scale_, sizeof(float), 5);

  int dim0 = in.shape(2) / 2;
  int dim1 = in.shape(1);
  int dim2 = in.shape(0);
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
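Note that the host passes log2(base) rather than base itself: since base^(-d) = exp2(-d * log2(base)), the kernel can use the cheap metal::exp2 instead of a generic pow. A quick numeric check of that identity:

import math

base, d = 10000.0, 0.37  # illustrative values
assert math.isclose(base ** -d, 2.0 ** (-d * math.log2(base)))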

mlx/backend/no_metal/primitives.cpp

@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/primitives.h"
+#include "mlx/fast.h"

 #define NO_GPU_MULTI(func) \
   void func::eval_gpu( \
@@ -95,4 +96,8 @@ NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)

+namespace fast {
+NO_GPU_MULTI(RoPE)
+} // namespace fast
+
 } // namespace mlx::core

mlx/fast.cpp (new file, 128 lines)

@@ -0,0 +1,128 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/transforms.h"

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
  std::vector<array> vjp_outs;
  // Keep only the vjps for the requested argnums.
  for (int i = 0, j = 0; i < vjps.size() && j < argnums.size(); ++i) {
    if (i == argnums[j]) {
      vjp_outs.push_back(vjps[i]);
      j++;
    }
  }
  return vjp_outs;
}

std::vector<array> Custom::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
  std::vector<array> jvp_outs;
  for (int i = 0, j = 0; i < jvps.size() && j < argnums.size(); ++i) {
    if (i == argnums[j]) {
      jvp_outs.push_back(jvps[i]);
      j++;
    }
  }
  return jvp_outs;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
  auto out_axes = std::vector<int>(outputs.size(), 0);
  return {outputs, out_axes};
}
array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (traditional && x.shape(-1) != dims) {
    throw std::invalid_argument(
        "[rope] Does not support partial traditional application.");
  }

  auto fallback = [dims, traditional, base, scale, offset, s](
                      const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto t = x.dtype();
    auto N = x.shape(1) + offset;

    // Compute sines and cosines
    auto half_dims = dims / 2;
    auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
    auto freqs = negative(arange(0, half_dims, t, s), s);
    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
    auto theta =
        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    if (traditional) {
      auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
      auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
    } else {
      auto out_s = x.shape();
      out_s.back() = half_dims;
      auto x1 = slice(x, {0, 0, 0}, out_s, s);
      out_s.back() = dims;
      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);

      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      if (dims < x.shape(-1)) {
        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
      }
      return std::vector<array>{concatenate(outs, 2, s)};
    }
  };

  // TODO change to condition for using custom prim
  auto stream = to_stream(s);
  if (stream.device == Device::gpu && x.shape(-1) == dims) {
    return array(
        x.shape(),
        x.dtype(),
        std::make_unique<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset),
        {x});
  }
  return fallback({x})[0];
}

bool RoPE::is_equivalent(const Primitive& other) const {
  const RoPE& a_other = static_cast<const RoPE&>(other);
  return (
      dims_ == a_other.dims_ && base_ == a_other.base_ &&
      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
      offset_ == a_other.offset_);
}

} // namespace mlx::core::fast
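Because the Custom transforms simply differentiate the fallback closure, gradients of the fast op are taken as usual; a minimal sketch (values illustrative), matching what test_fast.py below verifies:

import mlx.core as mx

f = lambda x: mx.fast.rope(x, 8, traditional=False, base=10000.0, scale=1.0, offset=0)
x = mx.random.uniform(shape=(2, 4, 8))
# vjp routes through the composed-ops fallback graph, not the Metal kernel.
_, vjp_out = mx.vjp(f, (x,), (mx.ones_like(x),))
dx = vjp_out[0]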

mlx/fast.h (new file, 82 lines)

@@ -0,0 +1,82 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/ops.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may
// override the default behavior.
class Custom : public Primitive {
 public:
  explicit Custom(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback)
      : Primitive(stream), fallback_(fallback){};

  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  virtual std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
};

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */);

class RoPE : public Custom {
 public:
  RoPE(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int dims,
      bool traditional,
      float base,
      float scale,
      int offset)
      : Custom(stream, fallback),
        dims_(dims),
        traditional_(traditional),
        base_(base),
        scale_(scale),
        offset_(offset){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(RoPE)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int dims_;
  bool traditional_;
  float base_;
  float scale_;
  int offset_;
};

} // namespace mlx::core::fast
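The RoPE subclass only carries the op parameters; dispatch lives in fast::rope, which builds the primitive on GPU streams and otherwise evaluates the fallback eagerly. A hedged sketch of what that means at the Python level (the GPU path requires a Metal device; values illustrative):

import mlx.core as mx

x = mx.random.uniform(shape=(1, 4, 8))
kwargs = dict(traditional=False, base=10000.0, scale=1.0, offset=0)
y_gpu = mx.fast.rope(x, 8, stream=mx.gpu, **kwargs)  # fused Metal kernel
y_cpu = mx.fast.rope(x, 8, stream=mx.cpu, **kwargs)  # composed-ops fallback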

mlx/mlx.h

@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
+#include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
 #include "mlx/linalg.h"

python/mlx/nn/layers/positional_encoding.py

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
 from typing import Optional
@@ -20,20 +20,13 @@ class RoPE(Module):
     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
             is larger than dims then the rest is left unchanged.
-        traditional (bool, optional): If set to True choose the traditional
+        traditional (bool, optional): If set to ``True`` choose the traditional
             implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
             each dimension in the positional encodings. Default: ``10000``.
         scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
-
-    Attributes:
-        _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
-        _cos_sin_theta_value (tuple): Cached cosine and sine values.
     """

-    _cos_sin_theta_key = None
-    _cos_sin_theta_value = None
-
     def __init__(
         self,
         dims: int,
@@ -50,69 +43,18 @@ class RoPE(Module):
     def _extra_repr(self):
         return f"{self.dims}, traditional={self.traditional}"

-    def _compute_rope(self, costheta, sintheta, x):
-        x1 = x[..., : self.dims // 2]
-        x2 = x[..., self.dims // 2 : self.dims]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
-        else:
-            rx = mx.concatenate([rx1, rx2], axis=-1)
-
-        return rx
-
-    def _compute_traditional_rope(self, costheta, sintheta, x):
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            raise NotImplementedError(
-                "RoPE doesn't implement partial traditional application"
-            )
-
-        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
-
-        return rx
-
     def __call__(self, x, offset: int = 0):
         shape = x.shape
         x = mx.reshape(x, (-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
+        x = mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=self.base,
+            scale=self.scale,
+            offset=offset,
         )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return mx.reshape(rx, shape)
-
-    @classmethod
-    def create_cos_sin_theta(
-        cls,
-        N: int,
-        D: int,
-        offset: int = 0,
-        base: float = 10000,
-        scale: float = 1.0,
-        dtype=mx.float32,
-    ):
-        if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
-            half_D = D // 2
-            positions = mx.arange(offset, N, dtype=dtype) * scale
-            freqs = mx.exp(
-                -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
-            )
-            theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
-            cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
-            cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
-        return cls._cos_sin_theta_value
+        return mx.reshape(x, shape)


 class SinusoidalPositionalEncoding(Module):
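With the layer now delegating to mx.fast.rope (and the class-level cos/sin cache gone), usage is unchanged; a brief sketch, with illustrative shapes:

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(64)
x = mx.random.uniform(shape=(2, 16, 64))
y = rope(x)                  # rotate all 64 feature dims from position 0
y_step = rope(x, offset=16)  # continue from position 16, e.g. cached decoding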

python/src/CMakeLists.txt

@@ -3,6 +3,7 @@ pybind11_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp

python/src/fast.cpp (new file, 59 lines)

@@ -0,0 +1,59 @@
// Copyright © 2023-2024 Apple Inc.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mlx/fast.h"
#include "mlx/ops.h"
#include "python/src/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

void init_extensions(py::module_& parent_module) {
  py::options options;
  options.disable_function_signatures();

  auto m =
      parent_module.def_submodule("fast", "mlx.core.fast: fast operations");

  m.def(
      "rope",
      [](const array& a,
         int dims,
         bool traditional,
         float base,
         float scale,
         int offset,
         const StreamOrDevice& s /* = {} */) {
        return fast::rope(a, dims, traditional, base, scale, offset, s);
      },
      "a"_a,
      "dims"_a,
      py::kw_only(),
      "traditional"_a,
      "base"_a,
      "scale"_a,
      "offset"_a,
      "stream"_a = none,
      R"pbdoc(
        rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array

        Apply rotary positional encoding to the input.

        Args:
            a (array): Input array.
            dims (int): The feature dimensions to be rotated. If the input feature
                is larger than dims then the rest is left unchanged.
            traditional (bool): If set to ``True`` choose the traditional
                implementation which rotates consecutive dimensions.
            base (float): The base used to compute angular frequency for
                each dimension in the positional encodings.
            scale (float): The scale used to scale the positions.
            offset (int): The position offset to start at.

        Returns:
            array: The output array.
      )pbdoc");
}
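Everything after dims is keyword-only (py::kw_only), so a call reads, for instance:

import mlx.core as mx

x = mx.random.uniform(shape=(1, 8, 32))  # (batch, sequence, dims)
y = mx.fast.rope(x, 32, traditional=False, base=10000.0, scale=1.0, offset=0)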

python/src/mlx.cpp

@@ -17,6 +17,7 @@ void init_random(py::module_&);
 void init_fft(py::module_&);
 void init_linalg(py::module_&);
 void init_constants(py::module_&);
+void init_extensions(py::module_&);

 PYBIND11_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -33,5 +34,6 @@ PYBIND11_MODULE(core, m) {
   init_fft(m);
   init_linalg(m);
   init_constants(m);
+  init_extensions(m);
   m.attr("__version__") = TOSTRING(_VERSION_);
 }

python/src/random.cpp

@@ -133,7 +133,7 @@ void init_random(py::module_& parent_module) {
         low (scalar or array, optional): Lower bound of the distribution. Default is ``0``.
         high (scalar or array, optional): Upper bound of the distribution. Default is ``1``.
         shape (list(int), optional): Shape of the output. Default is ``()``.
-        key (array, optional): A PRNG key. Default: None.
+        key (array, optional): A PRNG key. Default: ``None``.
         dtype (Dtype, optional): Type of the output. Default is ``float32``.

     Returns:

python/tests/test_fast.py (new file, 158 lines)

@@ -0,0 +1,158 @@
# Copyright © 2023-2024 Apple Inc.

import math
import unittest

import mlx.core as mx
import mlx_tests


def rope_orig(x, dims, traditional, base, scale, offset):
    N = x.shape[1] + offset
    dtype = x.dtype
    half_D = dims // 2
    positions = mx.arange(offset, N, dtype=dtype) * scale
    freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    costheta, sintheta = mx.cos(theta), mx.sin(theta)
    if traditional:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        return mx.reshape(rx, x.shape)
    else:
        x1 = x[..., : dims // 2]
        x2 = x[..., dims // 2 : dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        if dims < x.shape[-1]:
            rx = mx.concatenate([rx1, rx2, x[..., dims:]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)
        return rx


class TestFast(mlx_tests.MLXTestCase):
    def test_rope(self):
        T = 4

        # Defaults: dims, dtype, base, scale, offset, traditional
        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)

        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

        # Test cases:
        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        bases = [10000.0, 1000000.0]
        scales = [1.0, 2.0]
        offsets = [0, 3]
        traditional = [True, False]

        for traditional in [True, False]:
            dims, dtype, _, scale, offset, _ = defaults
            for base in bases:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, _, base, scale, offset, _ = defaults
            for dtype in dtypes:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                ry = rope_orig(
                    x.astype(mx.float32), dims, traditional, base, scale, offset
                )
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                if dtype != mx.float32:
                    self.assertLessEqual(
                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
                    )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, dtype, base, scale, _, _ = defaults
            for offset in offsets:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            dims, dtype, base, _, offset, _ = defaults
            for scale in scales:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_fast_transforms(self):
        x = mx.random.uniform(shape=(2, 2, 8))

        defaults = (8, False, 10000.0, 1.0, 0)
        dims, traditional, base, scale, offset = defaults

        # VJP
        _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, vjp_fast_out = mx.vjp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))

        # JVP
        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, jvp_fast_out = mx.jvp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))

        # VMAP
        x = mx.random.uniform(shape=(2, 2, 2, 8))
        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
        vmap_fast_out = mx.vmap(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            )
        )(x)
        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))


if __name__ == "__main__":
    unittest.main()