From 79b527f45f7dfdd76a78b33a5ca5d5ee8e122095 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 21 Apr 2025 13:04:39 -0700 Subject: [PATCH 001/156] conv vmap (#2102) --- mlx/primitives.cpp | 55 +++++++++++++++++++++++++++++++++++++++ mlx/primitives.h | 1 + python/tests/test_vmap.py | 51 ++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 590af60f6..3d36f0881 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1275,6 +1275,61 @@ std::vector Convolution::vjp( return grads; } +std::pair, std::vector> Convolution::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto do_conv = [&](const array& in, const array& w, int groups) { + return conv_general( + in, + w, + kernel_strides_, + padding_, + kernel_dilation_, + input_dilation_, + groups, + flip_, + stream()); + }; + bool in_vmap = axes[0] >= 0; + bool w_vmap = axes[1] >= 0; + auto in = inputs[0]; + auto w = inputs[1]; + if (in_vmap && !w_vmap) { + // flatten / unflatten the batch dimension + // of the input / output + if (axes[0] > 0) { + in = moveaxis(in, axes[0], 0, stream()); + } + auto out = do_conv(flatten(in, 0, 1, stream()), w, groups_); + out = unflatten(out, 0, {in.shape(0), in.shape(1)}, stream()); + return {{out}, {0}}; + } else if (!in_vmap && w_vmap) { + // flatten into the output channels of w + // unflatten the channels of the output + if (axes[1] > 0) { + w = moveaxis(w, axes[1], 0, stream()); + } + auto out = do_conv(in, flatten(w, 0, 1, stream()), groups_); + out = unflatten(out, -1, {w.shape(0), w.shape(1)}, stream()); + return {{out}, {static_cast(out.ndim() - 2)}}; + } else if (in_vmap && w_vmap) { + // use a group convolution when both inputs are vmapped + auto b = in.shape(axes[0]); + in = moveaxis(in, axes[0], -2, stream()); + in = flatten(in, -2, -1, stream()); + if (axes[1] > 0) { + w = moveaxis(w, axes[1], 0, stream()); + } + auto c_out = w.shape(1); + w = flatten(w, 0, 1, stream()); + auto out = do_conv(in, w, groups_ * b); + out = unflatten(out, -1, {b, c_out}, stream()); + return {{out}, {static_cast(out.ndim() - 2)}}; + } else { + return {{do_conv(in, w, groups_)}, {-1}}; + } +} + bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); return padding_ == c_other.padding_ && diff --git a/mlx/primitives.h b/mlx/primitives.h index 997931f30..3753e43c5 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -711,6 +711,7 @@ class Convolution : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; + DEFINE_VMAP() DEFINE_PRINT(Convolution) bool is_equivalent(const Primitive& other) const override; auto state() const { diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 1a1ba23b3..e571678d3 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -669,6 +669,57 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8)) self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6)) + def test_vmap_conv(self): + # vmap input only + x = mx.random.uniform(shape=(2, 2, 5, 4)) + w = mx.random.uniform(shape=(8, 3, 4)) + + expected = mx.stack([mx.conv1d(xi, w) for xi in x]) + out = mx.vmap(mx.conv1d, in_axes=(0, None))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + x = mx.moveaxis(x, 0, 2) + out = mx.vmap(mx.conv1d, in_axes=(2, None))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # vmap weights only + x = 
mx.random.uniform(shape=(2, 5, 4)) + w = mx.random.uniform(shape=(3, 8, 3, 4)) + + expected = mx.stack([mx.conv1d(x, wi) for wi in w]) + out = mx.vmap(mx.conv1d, in_axes=(None, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + w = mx.moveaxis(w, 0, 1) + out = mx.vmap(mx.conv1d, in_axes=(None, 1))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # vmap weights and input + x = mx.random.uniform(shape=(3, 2, 5, 4)) + w = mx.random.uniform(shape=(3, 8, 3, 4)) + + expected = mx.stack([mx.conv1d(xi, wi) for xi, wi in zip(x, w)]) + out = mx.vmap(mx.conv1d, in_axes=(0, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + x = mx.random.uniform(shape=(2, 3, 5, 4)) + w = mx.random.uniform(shape=(8, 3, 4, 3)) + + expected = mx.stack([mx.conv1d(x[:, i], w[..., i]) for i in range(3)]) + out = mx.vmap(mx.conv1d, in_axes=(1, 3))(x, w) + self.assertTrue(mx.allclose(expected, out)) + + # Test with groups + x = mx.random.uniform(shape=(3, 2, 5, 8)) + w = mx.random.uniform(shape=(3, 2, 3, 4)) + + def gconv(x, w): + return mx.conv1d(x, w, groups=2) + + expected = mx.stack([gconv(xi, wi) for xi, wi in zip(x, w)]) + out = mx.vmap(gconv, in_axes=(0, 0))(x, w) + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": unittest.main() From fdadc4f22c19390fb171898098bde81b7d1d8db8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 21 Apr 2025 13:04:54 -0700 Subject: [PATCH 002/156] Add more complex unary ops (#2101) --- mlx/backend/metal/kernels/complex.h | 19 +++++++ mlx/backend/metal/kernels/unary.metal | 6 +++ mlx/backend/metal/kernels/unary_ops.h | 76 ++++++++++++++------------- python/tests/test_ops.py | 29 ++++++++++ 4 files changed, 93 insertions(+), 37 deletions(-) diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index fe8ec5c0f..c88002cb3 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -104,10 +104,22 @@ constexpr bool operator==(complex64_t a, complex64_t b) { constexpr complex64_t operator+(complex64_t a, complex64_t b) { return {a.real + b.real, a.imag + b.imag}; } +constexpr complex64_t operator+(float a, complex64_t b) { + return {a + b.real, b.imag}; +} +constexpr complex64_t operator+(complex64_t a, float b) { + return {a.real + b, a.imag}; +} constexpr complex64_t operator-(complex64_t a, complex64_t b) { return {a.real - b.real, a.imag - b.imag}; } +constexpr complex64_t operator-(float a, complex64_t b) { + return {a - b.real, -b.imag}; +} +constexpr complex64_t operator-(complex64_t a, float b) { + return {a.real - b, a.imag}; +} constexpr complex64_t operator*(complex64_t a, complex64_t b) { return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; @@ -120,6 +132,13 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) { return {x / denom, y / denom}; } +constexpr complex64_t operator/(float a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a * b.real; + auto y = -a * b.imag; + return {x / denom, y / denom}; +} + constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 2209b0665..d34c5a7ec 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -69,6 +69,9 @@ instantiate_unary_float(Round) instantiate_unary_int(BitwiseInvert) 
instantiate_unary_all_same(Abs, complex64, complex64_t) +instantiate_unary_all_same(ArcCos, complex64, complex64_t) +instantiate_unary_all_same(ArcSin, complex64, complex64_t) +instantiate_unary_all_same(ArcTan, complex64, complex64_t) instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) @@ -80,6 +83,9 @@ instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_all_same(Sinh, complex64, complex64_t) +instantiate_unary_all_same(Square, complex64, complex64_t) +instantiate_unary_all_same(Sqrt, complex64, complex64_t) +instantiate_unary_all_same(Rsqrt, complex64, complex64_t) instantiate_unary_all_same(Tan, complex64, complex64_t) instantiate_unary_all_same(Tanh, complex64, complex64_t) instantiate_unary_all_same(Round, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 52e126b40..09d9f6605 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -17,27 +17,21 @@ struct Abs { T operator()(T x) { return metal::abs(x); }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; - template <> complex64_t operator()(complex64_t x) { return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; }; @@ -48,6 +42,8 @@ struct ArcCos { T operator()(T x) { return metal::precise::acos(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcCosh { @@ -62,6 +58,8 @@ struct ArcSin { T operator()(T x) { return metal::precise::asin(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcSinh { @@ -76,6 +74,8 @@ struct ArcTan { T operator()(T x) { return metal::precise::atan(x); }; + + complex64_t operator()(complex64_t x); }; struct ArcTanh { @@ -97,39 +97,30 @@ struct Ceil { T operator()(T x) { return metal::ceil(x); }; - template <> int8_t operator()(int8_t x) { return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; @@ -141,7 +132,6 @@ struct Cos { return metal::precise::cos(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::cos(x.real) * metal::precise::cosh(x.imag), @@ -155,7 +145,6 @@ struct Cosh { return metal::precise::cosh(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::cosh(x.real) * metal::precise::cos(x.imag), @@ -188,7 +177,6 @@ struct Exp { T operator()(T x) { return metal::precise::exp(x); }; - template <> complex64_t operator()(complex64_t x) { auto m = metal::precise::exp(x.real); return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; @@ -207,39 +195,30 @@ struct Floor { T operator()(T x) { return metal::floor(x); }; - template <> int8_t operator()(int8_t x) { 
return x; }; - template <> int16_t operator()(int16_t x) { return x; }; - template <> int32_t operator()(int32_t x) { return x; }; - template <> int64_t operator()(int64_t x) { return x; }; - template <> uint8_t operator()(uint8_t x) { return x; }; - template <> uint16_t operator()(uint16_t x) { return x; }; - template <> uint32_t operator()(uint32_t x) { return x; }; - template <> uint64_t operator()(uint64_t x) { return x; }; - template <> bool operator()(bool x) { return x; }; @@ -258,7 +237,6 @@ struct Log { return metal::precise::log(x); }; - template <> complex64_t operator()(complex64_t x) { auto r = metal::precise::log(Abs{}(x).real); auto i = metal::precise::atan2(x.imag, x.real); @@ -272,7 +250,6 @@ struct Log2 { return metal::precise::log2(x); }; - template <> complex64_t operator()(complex64_t x) { auto y = Log{}(x); return {y.real / M_LN2_F, y.imag / M_LN2_F}; @@ -285,7 +262,6 @@ struct Log10 { return metal::precise::log10(x); }; - template <> complex64_t operator()(complex64_t x) { auto y = Log{}(x); return {y.real / M_LN10_F, y.imag / M_LN10_F}; @@ -325,7 +301,6 @@ struct Round { T operator()(T x) { return metal::rint(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; @@ -344,11 +319,9 @@ struct Sign { T operator()(T x) { return (x > T(0)) - (x < T(0)); }; - template <> uint32_t operator()(uint32_t x) { return x != 0; }; - template <> complex64_t operator()(complex64_t x) { if (x == complex64_t(0)) { return x; @@ -364,7 +337,6 @@ struct Sin { return metal::precise::sin(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::sin(x.real) * metal::precise::cosh(x.imag), @@ -378,7 +350,6 @@ struct Sinh { return metal::precise::sinh(x); }; - template <> complex64_t operator()(complex64_t x) { return { metal::precise::sinh(x.real) * metal::precise::cos(x.imag), @@ -398,6 +369,17 @@ struct Sqrt { T operator()(T x) { return metal::precise::sqrt(x); }; + + complex64_t operator()(complex64_t x) { + if (x.real == 0.0 && x.imag == 0.0) { + return {0.0, 0.0}; + } + auto r = Abs{}(x).real; + auto a = metal::precise::sqrt((r + x.real) / 2.0); + auto b_abs = metal::precise::sqrt((r - x.real) / 2.0); + auto b = metal::copysign(b_abs, x.imag); + return {a, b}; + } }; struct Rsqrt { @@ -405,6 +387,10 @@ struct Rsqrt { T operator()(T x) { return metal::precise::rsqrt(x); }; + + complex64_t operator()(complex64_t x) { + return 1.0 / Sqrt{}(x); + } }; struct Tan { @@ -413,7 +399,6 @@ struct Tan { return metal::precise::tan(x); }; - template <> complex64_t operator()(complex64_t x) { float tan_a = metal::precise::tan(x.real); float tanh_b = metal::precise::tanh(x.imag); @@ -429,7 +414,6 @@ struct Tanh { return metal::precise::tanh(x); }; - template <> complex64_t operator()(complex64_t x) { float tanh_a = metal::precise::tanh(x.real); float tan_b = metal::precise::tan(x.imag); @@ -438,3 +422,21 @@ struct Tanh { return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; }; }; + +complex64_t ArcCos::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcSin::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(i * x + Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcTan::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto ix = i * x; + return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); +}; diff --git 
a/python/tests/test_ops.py b/python/tests/test_ops.py index 4fcb31f18..31ea79345 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2934,6 +2934,35 @@ class TestOps(mlx_tests.MLXTestCase): out = a[::-1] self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) + def test_complex_ops(self): + x = mx.array( + [ + 3.0 + 4.0j, + -5.0 + 12.0j, + -8.0 + 0.0j, + 0.0 + 9.0j, + 0.0 + 0.0j, + ] + ) + + ops = ["arccos", "arcsin", "arctan", "square", "sqrt"] + for op in ops: + with self.subTest(op=op): + np_op = getattr(np, op) + mx_op = getattr(mx, op) + self.assertTrue(np.allclose(mx_op(x), np_op(x))) + + x = mx.array( + [ + 3.0 + 4.0j, + -5.0 + 12.0j, + -8.0 + 0.0j, + 0.0 + 9.0j, + 9.0 + 1.0j, + ] + ) + self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) + if __name__ == "__main__": unittest.main() From e8ac6bd2f53ea2fa6df1221d2f4c595ac2c7069c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Apr 2025 10:25:55 -0700 Subject: [PATCH 003/156] irfft throws instead of segfaults on scalars (#2109) --- mlx/fft.cpp | 3 +-- python/tests/test_fft.py | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index f0d41bf0f..02878af9c 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include @@ -109,7 +108,7 @@ array fft_impl( for (auto ax : axes) { n.push_back(a.shape(ax)); } - if (real && inverse) { + if (real && inverse && a.ndim() > 0) { n.back() = (n.back() - 1) * 2; } return fft_impl(a, n, axes, real, inverse, s); diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index ec9a48f00..c887cd968 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase): r_np = np.fft.ifft(segment, n=n_fft) self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) + def test_fft_throws(self): + x = mx.array(3.0) + with self.assertRaises(ValueError): + mx.fft.irfftn(x) + if __name__ == "__main__": unittest.main() From 1d2c9d6a07f76b82b3c80c6a6682ba7f7161865d Mon Sep 17 00:00:00 2001 From: Yury Popov Date: Wed, 23 Apr 2025 04:56:28 +0300 Subject: [PATCH 004/156] Complex scan (#2094) --- mlx/backend/cpu/binary.cpp | 5 +- mlx/backend/cpu/scan.cpp | 3 +- mlx/backend/cpu/simd/base_simd.h | 23 +++++++- mlx/backend/metal/kernels/binary.metal | 1 + mlx/backend/metal/kernels/binary_ops.h | 18 ++++++ mlx/backend/metal/kernels/scan.metal | 3 +- mlx/backend/metal/kernels/unary.metal | 1 + mlx/backend/metal/kernels/utils.h | 17 ++++++ python/tests/test_ops.py | 79 ++++++++++++++++++++++++++ 9 files changed, 146 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index dbdab6a06..35aa2a3e0 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -172,9 +172,12 @@ void binary_float( case bfloat16: binary_op(a, b, out, bopt); break; + case complex64: + binary_op(a, b, out, bopt); + break; default: throw std::runtime_error( - "[binary_float] Only supports non-complex floating point types."); + "[binary_float] Only supports floating point types."); } }); } diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 199dbab35..33addd161 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -330,7 +330,8 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { reduce_type_, in, out, axis_, reverse_, inclusive_); break; case complex64: - throw std::runtime_error("Scan ops do not support complex types yet"); + scan_dispatch( + reduce_type_, 
in, out, axis_, reverse_, inclusive_); break; } }); diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 7e82a4d56..17cd35b9a 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log10, std::log10) -DEFAULT_UNARY(log1p, std::log1p) DEFAULT_UNARY(sinh, std::sinh) DEFAULT_UNARY(sqrt, std::sqrt) DEFAULT_UNARY(tan, std::tan) DEFAULT_UNARY(tanh, std::tanh) +template +Simd log1p(Simd in) { + if constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((typeof(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + template Simd log2(Simd in) { if constexpr (is_complex) { diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 3ef8e6269..1d555fefa 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -71,6 +71,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) +instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 8f961c2cf..4aaf2b4da 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -130,6 +130,24 @@ struct LogAddExp { ? maxval : (maxval + log1p(metal::exp(minval - maxval))); }; + + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || + metal::isnan(y.imag)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr float inf = metal::numeric_limits::infinity(); + complex64_t maxval = x > y ? x : y; + complex64_t minval = x < y ? 
x : y; + if (minval.real == -inf || maxval.real == inf) + return maxval; + float m = metal::exp(minval.real - maxval.real); + complex64_t dexp{ + m * metal::cos(minval.imag - maxval.imag), + m * metal::sin(minval.imag - maxval.imag), + }; + return maxval + log1p(dexp); + } }; struct Maximum { diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 8fcd7f61b..f38f8757e 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -104,4 +104,5 @@ instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMi instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4) instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4) -instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on +instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) +instantiate_scan_helper(logaddexp_complex64_complex64, complex64_t, complex64_t, CumLogaddexp, 2) // clang-format on diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index d34c5a7ec..afced7eb7 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -77,6 +77,7 @@ instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_all_same(Log, complex64, complex64_t) +instantiate_unary_all_same(Log1p, complex64, complex64_t) instantiate_unary_all_same(Log2, complex64, complex64_t) instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_all_same(Negative, complex64, complex64_t) diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b31cd20d6..1170d5576 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -328,6 +328,23 @@ inline bfloat16_t log1p(bfloat16_t x) { return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); } +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + /////////////////////////////////////////////////////////////////////////////// // SIMD shuffle ops /////////////////////////////////////////////////////////////////////////////// diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 31ea79345..d0e52eab2 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -10,6 +10,47 @@ import mlx_tests import numpy as np +def np_wrap_between(x, a): + """Wraps `x` between `[-a, a]`.""" + two_a = 2 * a + zero = 0 + rem = np.remainder(np.add(x, a), two_a) + if isinstance(rem, np.ndarray): + rem = np.select(rem < zero, np.add(rem, two_a), rem) + else: + rem = np.add(rem, two_a) if rem < zero else rem + return np.subtract(rem, a) + + +def np_logaddexp(x1: np.ndarray, x2: np.ndarray): + amax = np.maximum(x1, x2) + if np.issubdtype(x1.dtype, np.floating): + delta = np.subtract(x1, x2) + if isinstance(delta, np.ndarray): + return np.select( + np.isnan(delta), 
+ np.add(x1, x2), + np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))), + ) + else: + return ( + np.add(x1, x2) + if np.isnan(delta) + else np.add(amax, np.log1p(np.exp(np.negative(np.abs(delta))))) + ) + else: + delta = np.subtract(np.add(x1, x2), np.multiply(amax, 2)) + out = np.add(amax, np.log1p(np.exp(delta))) + return np.real(out) + 1j * np_wrap_between(np.imag(out), np.pi) + + +def np_cumlogaddexp(x1: np.ndarray, axis: int = -1): + out = x1.copy() + for i in range(1, out.shape[axis]): + out[i] = np_logaddexp(out[i], out[i - 1]) + return out + + class TestOps(mlx_tests.MLXTestCase): def test_full_ones_zeros(self): x = mx.full(2, 3.0) @@ -853,6 +894,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + + a = mx.array([0, 1, 2, 9.0]) + 1j + b = mx.array([1, 0, 4, 2.5]) + 1j + + result = mx.logaddexp(a, b) + expected = np_logaddexp(np.array(a), np.array(b)) + + self.assertTrue(np.allclose(result, expected)) + a = mx.array([float("nan")]) b = mx.array([0.0]) self.assertTrue(math.isnan(mx.logaddexp(a, b).item())) @@ -977,6 +1028,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(result, expected)) + # Complex test + a = mx.array([1, 0.5, 10, 100]) + 1j + result = mx.log1p(a) + expected = np.log1p(a, dtype=np.complex64) + + self.assertTrue(np.allclose(result, expected)) + def test_sigmoid(self): a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0]) result = mx.sigmoid(a) @@ -1881,10 +1939,31 @@ class TestOps(mlx_tests.MLXTestCase): c_mlx = mxop(a_mlx, axis=0) self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + # Complex tests + + a_npy = np.array([1, 2, 3]).astype(np.float32) + 1j + a_mlx = mx.array(a_npy) + c_npy = np_cumlogaddexp(a_npy, axis=-1) + c_mlx = mxop(a_mlx, axis=-1) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: + npop = getattr(np, op) + mxop = getattr(mx, op) + for axis in (None, 0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) + + # Complex test + + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + 0.5j + a_mlx = mx.array(a_npy) + for op in ["cumsum", "cumprod"]: npop = getattr(np, op) mxop = getattr(mx, op) From 383644524150f39c96008ba6645201ec23bb44d6 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Wed, 23 Apr 2025 10:57:39 +0900 Subject: [PATCH 005/156] Add broadcast_shapes in python API (#2091) --- python/src/ops.cpp | 42 ++++++++++++++++++++++++++++++++++++++++ python/tests/test_ops.py | 40 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f98aa80aa..5969c5052 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5189,4 +5189,46 @@ void init_ops(nb::module_& m) { Returns: array: The row or col contiguous output. 
)pbdoc"); + m.def( + "broadcast_shapes", + [](const nb::args& shapes) { + if (shapes.size() == 0) + throw std::invalid_argument( + "[broadcast_shapes] Must provide at least one shape."); + + mx::Shape result = nb::cast(shapes[0]); + for (size_t i = 1; i < shapes.size(); ++i) { + if (!nb::isinstance(shapes[i]) && + !nb::isinstance(shapes[i])) + throw std::invalid_argument( + "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); + result = mx::broadcast_shapes(result, nb::cast(shapes[i])); + } + + return nb::tuple(nb::cast(result)); + }, + nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), + R"pbdoc( + Broadcast shapes. + + Returns the shape that results from broadcasting the supplied array shapes + against each other. + + Args: + *shapes (Sequence[int]): The shapes to broadcast. + + Returns: + tuple: The broadcasted shape. + + Raises: + ValueError: If the shapes cannot be broadcast. + + Example: + >>> mx.broadcast_shapes((1,), (3, 1)) + (3, 1) + >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,)) + (5, 6, 7) + >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) + (5, 3, 4) + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d0e52eab2..47fec3167 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3043,5 +3043,45 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) +class TestBroadcast(mlx_tests.MLXTestCase): + def test_broadcast_shapes(self): + # Basic broadcasting + self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3)) + self.assertEqual(mx.broadcast_shapes((4, 1, 6), (5, 6)), (4, 5, 6)) + self.assertEqual(mx.broadcast_shapes((5, 1, 4), (1, 3, 4)), (5, 3, 4)) + + # Multiple arguments + self.assertEqual(mx.broadcast_shapes((1, 1), (1, 8), (7, 1)), (7, 8)) + self.assertEqual( + mx.broadcast_shapes((6, 1, 5), (1, 7, 1), (6, 7, 5)), (6, 7, 5) + ) + + # Same shapes + self.assertEqual(mx.broadcast_shapes((3, 4, 5), (3, 4, 5)), (3, 4, 5)) + + # Single argument + self.assertEqual(mx.broadcast_shapes((2, 3)), (2, 3)) + + # Empty shapes + self.assertEqual(mx.broadcast_shapes((), ()), ()) + self.assertEqual(mx.broadcast_shapes((), (1,)), (1,)) + self.assertEqual(mx.broadcast_shapes((1,), ()), (1,)) + + # Broadcasting with zeroes + self.assertEqual(mx.broadcast_shapes((0,), (0,)), (0,)) + self.assertEqual(mx.broadcast_shapes((1, 0, 5), (3, 1, 5)), (3, 0, 5)) + self.assertEqual(mx.broadcast_shapes((5, 0), (0, 5, 0)), (0, 5, 0)) + + # Error cases + with self.assertRaises(ValueError): + mx.broadcast_shapes((3, 4), (4, 3)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes((2, 3, 4), (2, 5, 4)) + + with self.assertRaises(ValueError): + mx.broadcast_shapes() + + if __name__ == "__main__": unittest.main() From 600e87e03c2ec57bd71139331a4cd8ae0bb99929 Mon Sep 17 00:00:00 2001 From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com> Date: Wed, 23 Apr 2025 21:56:33 +0530 Subject: [PATCH 006/156] Added output_padding parameters in conv_transpose (#2092) --- mlx/ops.cpp | 13 +- mlx/ops.h | 3 + python/mlx/nn/layers/convolution_transpose.py | 44 +++- python/src/ops.cpp | 43 +++- python/tests/test_conv_transpose.py | 209 ++++++++++++++++++ tests/ops_tests.cpp | 68 +++++- 6 files changed, 366 insertions(+), 14 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 54ac62fef..c2aa4786f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3769,6 +3769,7 @@ array conv_transpose_general( std::vector stride, std::vector padding, std::vector dilation, + 
std::vector output_padding, int groups, StreamOrDevice s) { std::vector padding_lo(padding.size()); @@ -3782,7 +3783,8 @@ array conv_transpose_general( int in_size = 1 + (conv_output_shape - 1); int out_size = 1 + stride[i] * (input.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding[i]; + padding_hi[i] = in_size - out_size + padding[i] + + output_padding[i]; // Adjust with output_padding } return conv_general( @@ -3805,10 +3807,11 @@ array conv_transpose1d( int stride /* = 1 */, int padding /* = 0 */, int dilation /* = 1 */, + int output_padding /* = 0 */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( - in_, wt_, {stride}, {padding}, {dilation}, groups, s); + in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s); } /** 2D transposed convolution with a filter */ @@ -3818,6 +3821,7 @@ array conv_transpose2d( const std::pair& stride /* = {1, 1} */, const std::pair& padding /* = {0, 0} */, const std::pair& dilation /* = {1, 1} */, + const std::pair& output_padding /* = {0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3826,6 +3830,7 @@ array conv_transpose2d( {stride.first, stride.second}, {padding.first, padding.second}, {dilation.first, dilation.second}, + {output_padding.first, output_padding.second}, groups, s); } @@ -3837,6 +3842,7 @@ array conv_transpose3d( const std::tuple& stride /* = {1, 1, 1} */, const std::tuple& padding /* = {0, 0, 0} */, const std::tuple& dilation /* = {1, 1, 1} */, + const std::tuple& output_padding /* = {0, 0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3845,6 +3851,9 @@ array conv_transpose3d( {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)}, {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, + {std::get<0>(output_padding), + std::get<1>(output_padding), + std::get<2>(output_padding)}, groups, s); } diff --git a/mlx/ops.h b/mlx/ops.h index e79ea235d..12e896af6 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1291,6 +1291,7 @@ array conv_transpose1d( int stride = 1, int padding = 0, int dilation = 1, + int output_padding = 0, int groups = 1, StreamOrDevice s = {}); @@ -1301,6 +1302,7 @@ array conv_transpose2d( const std::pair& stride = {1, 1}, const std::pair& padding = {0, 0}, const std::pair& dilation = {1, 1}, + const std::pair& output_padding = {0, 0}, int groups = 1, StreamOrDevice s = {}); @@ -1311,6 +1313,7 @@ array conv_transpose3d( const std::tuple& stride = {1, 1, 1}, const std::tuple& padding = {0, 0, 0}, const std::tuple& dilation = {1, 1, 1}, + const std::tuple& output_padding = {0, 0, 0}, int groups = 1, StreamOrDevice s = {}); diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py index edacab061..a11c4cb40 100644 --- a/python/mlx/nn/layers/convolution_transpose.py +++ b/python/mlx/nn/layers/convolution_transpose.py @@ -25,6 +25,8 @@ class ConvTranspose1d(Module): padding (int, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int, optional): The dilation of the convolution. + output_padding(int, optional): Additional size added to one side of the + output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. 
Default: ``True`` """ @@ -37,6 +39,7 @@ class ConvTranspose1d(Module): stride: int = 1, padding: int = 0, dilation: int = 1, + output_padding: int = 0, bias: bool = True, ): super().__init__() @@ -53,18 +56,25 @@ class ConvTranspose1d(Module): self.padding = padding self.dilation = dilation self.stride = stride + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose1d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias @@ -90,6 +100,8 @@ class ConvTranspose2d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -102,13 +114,14 @@ class ConvTranspose2d(Module): stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, dilation: Union[int, tuple] = 1, + output_padding: Union[int, tuple] = 0, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, output_padding = map( lambda x: (x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, output_padding), ) scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1])) self.weight = mx.random.uniform( @@ -122,18 +135,25 @@ class ConvTranspose2d(Module): self.padding = padding self.stride = stride self.dilation = dilation + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose2d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias @@ -160,6 +180,8 @@ class ConvTranspose3d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. 
Default: ``True`` """ @@ -172,13 +194,14 @@ class ConvTranspose3d(Module): stride: Union[int, tuple] = 1, padding: Union[int, tuple] = 0, dilation: Union[int, tuple] = 1, + output_padding: Union[int, tuple] = 0, bias: bool = True, ): super().__init__() - kernel_size, stride, padding = map( + kernel_size, stride, padding, output_padding = map( lambda x: (x, x, x) if isinstance(x, int) else x, - (kernel_size, stride, padding), + (kernel_size, stride, padding, output_padding), ) scale = math.sqrt( 1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) @@ -194,18 +217,25 @@ class ConvTranspose3d(Module): self.padding = padding self.stride = stride self.dilation = dilation + self.output_padding = output_padding def _extra_repr(self): return ( f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " + f"output_padding={self.output_padding}, " f"bias={'bias' in self}" ) def __call__(self, x): y = mx.conv_transpose3d( - x, self.weight, self.stride, self.padding, self.dilation + x, + self.weight, + self.stride, + self.padding, + self.dilation, + self.output_padding, ) if "bias" in self: y = y + self.bias diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 5969c5052..60b6188ed 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3609,11 +3609,12 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 1D transposed convolution over an input with several channels @@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) { stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. + output_padding (int, optional): Output padding. Default: ``0``. groups (int, optional): Input feature groups. Default: ``1``. 
Returns: @@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) { const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, + const std::variant>& output_padding, int groups, mx::StreamOrDevice s) { std::pair stride_pair{1, 1}; std::pair padding_pair{0, 0}; std::pair dilation_pair{1, 1}; + std::pair output_padding_pair{0, 0}; if (auto pv = std::get_if(&stride); pv) { stride_pair = std::pair{*pv, *pv}; @@ -3659,19 +3663,33 @@ void init_ops(nb::module_& m) { dilation_pair = std::get>(dilation); } + if (auto pv = std::get_if(&output_padding); pv) { + output_padding_pair = std::pair{*pv, *pv}; + } else { + output_padding_pair = std::get>(output_padding); + } + return mx::conv_transpose2d( - input, weight, stride_pair, padding_pair, dilation_pair, groups, s); + input, + weight, + stride_pair, + padding_pair, + dilation_pair, + output_padding_pair, + groups, + s); }, nb::arg(), nb::arg(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 2D transposed convolution over an input with several channels @@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. 
Returns: @@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) { const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, + const std::variant>& output_padding, int groups, mx::StreamOrDevice s) { std::tuple stride_tuple{1, 1, 1}; std::tuple padding_tuple{0, 0, 0}; std::tuple dilation_tuple{1, 1, 1}; + std::tuple output_padding_tuple{0, 0, 0}; if (auto pv = std::get_if(&stride); pv) { stride_tuple = std::tuple{*pv, *pv, *pv}; @@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) { dilation_tuple = std::get>(dilation); } + if (auto pv = std::get_if(&output_padding); pv) { + output_padding_tuple = std::tuple{*pv, *pv, *pv}; + } else { + output_padding_tuple = + std::get>(output_padding); + } + return mx::conv_transpose3d( input, weight, stride_tuple, padding_tuple, dilation_tuple, + output_padding_tuple, groups, s); }, @@ -3739,11 +3770,12 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 3D transposed convolution over an input with several channels @@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. 
Returns: diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 1ac20cbb1..2085e09d7 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase): N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_tranpose_1d_output_padding(self): + def run_conv_transpose_1d_output_padding( + N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5 + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + iH=iH, + kH=kH, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 2, 1)) + wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1)) + + out_mx = mx.conv_transpose1d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + out_pt = torch.conv_transpose1d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.transpose(out_pt, 2, 1) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for iH, kH, stride, padding, output_padding in ( + (3, 2, 2, 0, 1), + (5, 3, 2, 1, 0), + (7, 4, 3, 1, 2), + ): + run_conv_transpose_1d_output_padding( + N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_transpose_2d_output_padding(self): + def run_conv_transpose_2d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iH, iW = idim + kH, kW = kdim + in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype) + wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)) + wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2)) + + out_mx = mx.conv_transpose2d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + out_pt = torch.conv_transpose2d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for idim, kdim, stride, padding, output_padding in ( + ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)), + ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)), + ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)), + ): + run_conv_transpose_2d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype=dtype, + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_transpose_3d_output_padding(self): + def 
run_conv_transpose_3d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + dtype=dtype, + N=N, + C=C, + O=O, + idim=idim, + kdim=kdim, + stride=stride, + padding=padding, + output_padding=output_padding, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iD, iH, iW = idim + kD, kH, kW = kdim + in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype( + np_dtype + ) + wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype( + np_dtype + ) + + in_mx, wt_mx = map(mx.array, (in_np, wt_np)) + in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3)) + wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3)) + + out_mx = mx.conv_transpose3d( + in_mx, + wt_mx, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.conv_transpose3d( + in_pt, + wt_pt, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True) + + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) + + for dtype in ("float32",): + for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)): + for idim, kdim, stride, padding, output_padding in ( + ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)), + ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)), + ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)), + ): + run_conv_transpose_3d_output_padding( + N, + C, + O, + idim, + kdim, + stride, + padding, + output_padding, + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index de0f3352c..c4f319d46 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3911,4 +3911,70 @@ TEST_CASE("test bitwise shift operations") { CHECK_EQ(right_shift_bool_result.dtype(), uint8); CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item()); -} \ No newline at end of file +} + +TEST_CASE("test conv_transpose1d with output_padding") { + auto in = array({1.0, 2.0, 3.0}, {1, 1, 3}); + auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3}); + int stride = 2; + int padding = 0; + int dilation = 1; + int output_padding = 1; + int groups = 1; + + auto out = conv_transpose1d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array({6.0, 0.0}, {1, 2, 1}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test conv_transpose2d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2}); + auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2}); + std::pair stride{2, 2}; + std::pair padding{0, 0}; + std::pair output_padding{1, 1}; + std::pair dilation{1, 1}; + int groups = 1; + + auto out = conv_transpose2d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array( + {3.0, + 3.0, + 0.0, + 0.0, + 7.0, + 7.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0}, + {1, 2, 4, 2}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test conv_transpose3d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2}); + auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2}); + std::tuple stride{2, 2, 2}; + std::tuple padding{0, 0, 0}; + std::tuple output_padding{1, 1, 1}; + std::tuple dilation{1, 1, 1}; + int groups = 1; + + auto out = conv_transpose3d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array( + {3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 11.0, 0.0, 15.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {1, 2, 4, 4, 1}); + CHECK(array_equal(out, expected).item()); +} From 38c1e720c262aa44e08edca30903fdcf81c17c05 Mon Sep 17 00:00:00 2001 From: hdeng-apple Date: Thu, 24 Apr 2025 00:53:13 +0800 Subject: [PATCH 007/156] Search mlx.metallib in macOS framework "Resources" dir (#2061) --------- Co-authored-by: Angelos Katharopoulos --- mlx/backend/metal/device.cpp | 48 ++++++++++++++++++++++-------------- mlx/backend/metal/device.h | 14 ++++------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 95aeb1cc9..43f82893b 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include #include @@ -15,6 +16,8 @@ #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" +namespace fs = std::filesystem; + namespace mlx::core::metal { namespace { @@ -79,12 +82,18 @@ MTL::Library* try_load_bundle( // Firstly, search for the metallib in the same path as this binary std::pair load_colocated_library( MTL::Device* device, - const std::string& lib_name) { - std::string lib_path = get_colocated_mtllib_path(lib_name); - if (lib_path.size() != 0) { - return load_library_from_path(device, lib_path.c_str()); + const std::string& relative_path) { + std::string binary_dir = get_binary_directory(); + if (binary_dir.size() == 0) { + return {nullptr, nullptr}; } - return {nullptr, nullptr}; + + auto path = fs::path(binary_dir) / relative_path; + if (!path.has_extension()) { + path.replace_extension(".metallib"); + } + + return load_library_from_path(device, path.c_str()); } std::pair load_swiftpm_library( @@ -109,33 +118,34 @@ std::pair load_swiftpm_library( } MTL::Library* load_default_library(MTL::Device* device) { - NS::Error *error1, *error2, *error3; + NS::Error* error[4]; MTL::Library* lib; // First try the colocated mlx.metallib - std::tie(lib, error1) = load_colocated_library(device, "mlx"); + std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); + if (lib) { + return lib; + } + + std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx"); if (lib) { return lib; } // Then try default.metallib in a SwiftPM bundle if we have one - std::tie(lib, error2) = load_swiftpm_library(device, "default"); + std::tie(lib, error[2]) = load_swiftpm_library(device, "default"); if (lib) { return lib; } // Finally try default_mtllib_path - std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path); + std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path); if (!lib) { std::ostringstream msg; msg << "Failed to load the default metallib. "; - if (error1 != nullptr) { - msg << error1->localizedDescription()->utf8String() << " "; - } - if (error2 != nullptr) { - msg << error2->localizedDescription()->utf8String() << " "; - } - if (error3 != nullptr) { - msg << error3->localizedDescription()->utf8String() << " "; + for (int i = 0; i < 4; i++) { + if (error[i] != nullptr) { + msg << error[i]->localizedDescription()->utf8String() << " "; + } } throw std::runtime_error(msg.str()); } @@ -188,8 +198,8 @@ MTL::Library* load_library( std::ostringstream msg; msg << "Failed to load the metallib " << lib_name << ".metallib. 
" - << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name) - << ">"; + << "We attempted to load it from <" << get_binary_directory() << "/" + << lib_name << ".metallib" << ">"; #ifdef SWIFTPM_BUNDLE msg << " and from the Swift PM bundle."; #endif diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bb0e93147..d60635e39 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -21,18 +21,14 @@ namespace mlx::core::metal { // Note, this function must be left inline in a header so that it is not // dynamically linked. -inline std::string get_colocated_mtllib_path(const std::string& lib_name) { +inline std::string get_binary_directory() { Dl_info info; - std::string mtllib_path; - std::string lib_ext = lib_name + ".metallib"; - - int success = dladdr((void*)get_colocated_mtllib_path, &info); + std::string directory; + int success = dladdr((void*)get_binary_directory, &info); if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); + directory = fs::path(info.dli_fname).remove_filename().c_str(); } - - return mtllib_path; + return directory; } using MTLFCList = From fbc89e3ced24f8a8bf0324bf691ce53da9243868 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 23 Apr 2025 13:08:28 -0700 Subject: [PATCH 008/156] fix pinv (#2110) --- mlx/linalg.cpp | 7 ++++++- mlx/types/limits.h | 6 ++++++ mlx/utils.cpp | 13 +++++++------ mlx/utils.h | 1 + python/src/array.cpp | 7 +++++++ python/tests/test_array.py | 2 ++ python/tests/test_linalg.py | 5 +++++ 7 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 5b9b51ad3..53f13486a 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) { // Prepare S S = expand_dims(S, -2, s); - return matmul(divide(V, S, s), U); + auto rcond = 10. 
* std::max(m, n) * finfo(a.dtype()).eps; + auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s); + auto rS = + where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s); + + return matmul(multiply(V, rS, s), U, s); } array cholesky_inv( diff --git a/mlx/types/limits.h b/mlx/types/limits.h index 7e0de15bc..5f2b1e9e0 100644 --- a/mlx/types/limits.h +++ b/mlx/types/limits.h @@ -33,6 +33,9 @@ struct numeric_limits { static constexpr float16_t max() { return bits_to_half(0x7BFF); } + static constexpr float16_t epsilon() { + return bits_to_half(0x1400); + } static constexpr float16_t infinity() { return bits_to_half(0x7C00); } @@ -56,6 +59,9 @@ struct numeric_limits { static constexpr bfloat16_t max() { return bits_to_bfloat(0x7F7F); } + static constexpr bfloat16_t epsilon() { + return bits_to_bfloat(0x3C00); + } static constexpr bfloat16_t infinity() { return bits_to_bfloat(0x7F80); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 188584174..0b2e66352 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -283,9 +283,10 @@ int get_var(const char* name, int default_value) { } // namespace env template -void set_finfo_limits(double& min, double& max) { +void set_finfo_limits(double& min, double& max, double& eps) { min = numeric_limits::lowest(); max = numeric_limits::max(); + eps = numeric_limits::epsilon(); } finfo::finfo(Dtype dtype) : dtype(dtype) { @@ -295,16 +296,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { throw std::invalid_argument(msg.str()); } if (dtype == float32) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == bfloat16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float64) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == complex64) { this->dtype = float32; - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } } diff --git a/mlx/utils.h b/mlx/utils.h index 19241e4c6..f0aa7c2de 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -65,6 +65,7 @@ struct finfo { Dtype dtype; double min; double max; + double eps; }; /** Holds information about integral types. */ diff --git a/python/src/array.cpp b/python/src/array.cpp index 467bd0fa5..5f8dbe021 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -197,6 +197,13 @@ void init_array(nb::module_& m) { "max", &mx::finfo::max, R"pbdoc(The largest representable number.)pbdoc") + .def_ro( + "eps", + &mx::finfo::eps, + R"pbdoc( + The difference between 1.0 and the next smallest + representable number larger than 1.0. 
+ )pbdoc") .def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") .def("__repr__", [](const mx::finfo& f) { std::ostringstream os; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index fa5784ea9..792e666d6 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -103,10 +103,12 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min) self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max) + self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps) self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32) self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min) self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) + self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps) self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) def test_iinfo(self): diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index ffa355c10..a9fe572af 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -232,6 +232,11 @@ class TestLinalg(mlx_tests.MLXTestCase): for M, M_plus in zip(AB, pinvs): self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3)) + # Test singular matrix + A = mx.array([[4.0, 1.0], [4.0, 1.0]]) + A_plus = mx.linalg.pinv(A, stream=mx.cpu) + self.assertTrue(mx.allclose(A @ A_plus @ A, A)) + def test_cholesky_inv(self): mx.random.seed(7) From 86984cad68da96624035eb02fa1298065bfd6989 Mon Sep 17 00:00:00 2001 From: hdeng-apple Date: Thu, 24 Apr 2025 21:14:49 +0800 Subject: [PATCH 009/156] Remove static initializers (#2059) * Remove static initializers in device.cpp, load.cpp, pocketfft.h * Remove static initializer InTracing::trace_stack * Remove static initializer of CompilerCache cache * Revert changes in pocketfft.h * Remove duplicate private section of thread_pool() --- mlx/backend/cpu/compiled.cpp | 21 +++++++++++++-------- mlx/device.cpp | 11 +++++++---- mlx/io/load.cpp | 8 ++++++-- mlx/io/load.h | 2 +- mlx/transforms.cpp | 5 ++++- mlx/transforms_impl.h | 12 ++++++------ 6 files changed, 37 insertions(+), 22 deletions(-) diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 9da9c14e8..e389e0df5 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -40,7 +40,10 @@ struct CompilerCache { std::shared_mutex mtx; }; -static CompilerCache cache{}; +static CompilerCache& cache() { + static CompilerCache cache_; + return cache_; +}; // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is also available. 
@@ -56,14 +59,16 @@ void* compile( const std::string& kernel_name, const std::function& source_builder) { { - std::shared_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::shared_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } } - std::unique_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::unique_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } std::string source_code = source_builder(); @@ -120,10 +125,10 @@ void* compile( } // load library - cache.libs.emplace_back(shared_lib_path); + cache().libs.emplace_back(shared_lib_path); // Load function - void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); + void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str()); if (!fun) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to load compiled function " @@ -131,7 +136,7 @@ void* compile( << dlerror(); throw std::runtime_error(msg.str()); } - cache.kernels.insert({kernel_name, fun}); + cache().kernels.insert({kernel_name, fun}); return fun; } diff --git a/mlx/device.cpp b/mlx/device.cpp index e635782e2..20d8675d8 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -5,11 +5,14 @@ namespace mlx::core { -static Device default_device_{ - metal::is_available() ? Device::gpu : Device::cpu}; +Device& mutable_default_device() { + static Device default_device{ + metal::is_available() ? Device::gpu : Device::cpu}; + return default_device; +} const Device& default_device() { - return default_device_; + return mutable_default_device(); } void set_default_device(const Device& d) { @@ -17,7 +20,7 @@ void set_default_device(const Device& d) { throw std::invalid_argument( "[set_default_device] Cannot set gpu device without gpu backend."); } - default_device_ = d; + mutable_default_device() = d; } bool operator==(const Device& lhs, const Device& rhs) { diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 59d91c007..2f9053f4d 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -335,7 +335,10 @@ ThreadPool& thread_pool() { return pool_; } -ThreadPool ParallelFileReader::thread_pool_{4}; +ThreadPool& ParallelFileReader::thread_pool() { + static ThreadPool thread_pool{4}; + return thread_pool; +} void ParallelFileReader::read(char* data, size_t n) { while (n != 0) { @@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) { break; } else { size_t m = batch_size_; - futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); + futs.emplace_back( + ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data)); data += m; n -= m; offset += m; diff --git a/mlx/io/load.h b/mlx/io/load.h index 138098e82..8b5dd95b6 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -101,7 +101,7 @@ class ParallelFileReader : public Reader { private: static constexpr size_t batch_size_ = 1 << 25; - static ThreadPool thread_pool_; + static ThreadPool& thread_pool(); int fd_; std::string label_; }; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..f9a5de031 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -42,7 +42,10 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. 
-std::vector> detail::InTracing::trace_stack{}; +std::vector>& detail::InTracing::trace_stack() { + static std::vector> trace_stack_; + return trace_stack_; +} int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 7f62c406b..46851fa3d 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -22,19 +22,19 @@ std::vector vmap_replace( struct InTracing { explicit InTracing(bool dynamic = false, bool grad = false) { grad_counter += grad; - trace_stack.push_back({dynamic, grad}); + trace_stack().push_back({dynamic, grad}); } ~InTracing() { - grad_counter -= trace_stack.back().second; - trace_stack.pop_back(); + grad_counter -= trace_stack().back().second; + trace_stack().pop_back(); } static bool in_tracing() { - return !trace_stack.empty(); + return !trace_stack().empty(); } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front().first; + return in_tracing() && trace_stack().front().first; } static bool in_grad_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector> trace_stack; + static std::vector>& trace_stack(); }; struct RetainGraph { From f0e70afff0789543dea792bcb6223ca658af8a89 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 24 Apr 2025 10:58:29 -0700 Subject: [PATCH 010/156] Fix swift pm load (#2117) --- mlx/backend/metal/device.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 43f82893b..1321fb8d9 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -69,8 +69,8 @@ MTL::Library* try_load_bundle( if (bundle != nullptr) { std::string resource_path = std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + - lib_name + ".metallib" auto [lib, error] = - load_library_from_path(device, resource_path.c_str()); + lib_name + ".metallib"; + auto [lib, error] = load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } @@ -108,7 +108,7 @@ std::pair load_swiftpm_library( auto bundles = NS::Bundle::allBundles(); for (int i = 0, c = (int)bundles->count(); i < c; i++) { auto bundle = reinterpret_cast(bundles->object(i)); - library = try_load_bundle(device, bundle->resourceURL()); + library = try_load_bundle(device, bundle->resourceURL(), lib_name); if (library != nullptr) { return {library, nullptr}; } From eaf709b83e559079e212699bfc9dd2f939d25c9a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 24 Apr 2025 16:11:07 -0700 Subject: [PATCH 011/156] patch (#2119) --- mlx/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/version.h b/mlx/version.h index fe47d96cc..8340e1e8c 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 0 +#define MLX_VERSION_PATCH 1 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From 6b2d5448f230c73ba70996fce188498ce2c83e3f Mon Sep 17 00:00:00 2001 From: 1ndig0 <1090891928@qq.com> Date: Sat, 26 Apr 2025 00:14:28 +0800 Subject: [PATCH 012/156] Fix the error message in `mx.right_shift` and `mx.left_shift` (#2121) * update right_shift and lef_shift * simplify --------- Co-authored-by: Awni Hannun --- mlx/ops.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mlx/ops.cpp 
b/mlx/ops.cpp index c2aa4786f..f8308c2d5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4882,8 +4882,9 @@ array bitwise_impl( const array& b, BitwiseBinary::Op op, const std::string& op_name, - const StreamOrDevice& s) { - auto out_type = promote_types(a.dtype(), b.dtype()); + const StreamOrDevice& s, + std::optional out_type_ = std::nullopt) { + auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype()); if (!(issubdtype(out_type, integer) || out_type == bool_)) { std::ostringstream msg; msg << "[" << op_name @@ -4928,12 +4929,7 @@ array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { if (t == bool_) { t = uint8; } - return bitwise_impl( - astype(a, t, s), - astype(b, t, s), - BitwiseBinary::Op::LeftShift, - "left_shift", - s); + return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t); } array operator<<(const array& a, const array& b) { return left_shift(a, b); @@ -4949,7 +4945,8 @@ array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { astype(b, t, s), BitwiseBinary::Op::RightShift, "right_shift", - s); + s, + t); } array operator>>(const array& a, const array& b) { return right_shift(a, b); From 99b986885997e578c0f04cb9975778a800012b08 Mon Sep 17 00:00:00 2001 From: charan-003 <85248228+charan-003@users.noreply.github.com> Date: Fri, 25 Apr 2025 13:18:30 -0600 Subject: [PATCH 013/156] Clarify dimension notation in conv1d, conv2d, and conv3d docstrings (#2123) * Clarify dimension notation in conv1d, conv2d, and conv3d docstrings * Updating transposed convs in conv1d, conv2d, and conv3d --------- Co-authored-by: Sai Charan Arvapally --- python/src/ops.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 60b6188ed..a1e77d681 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3455,8 +3455,8 @@ void init_ops(nb::module_& m) { 1D convolution over an input with several channels Args: - input (array): Input array of shape ``(N, H, C_in)``. - weight (array): Weight array of shape ``(C_out, H, C_in)``. + input (array): Input array of shape ``(N, L, C_in)``. + weight (array): Weight array of shape ``(C_out, K, C_in)``. stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. @@ -3514,7 +3514,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3586,7 +3586,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, D, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3619,8 +3619,8 @@ void init_ops(nb::module_& m) { 1D transposed convolution over an input with several channels Args: - input (array): Input array of shape ``(N, H, C_in)``. - weight (array): Weight array of shape ``(C_out, H, C_in)``. + input (array): Input array of shape ``(N, L, C_in)``. 
+ weight (array): Weight array of shape ``(C_out, K, C_in)``. stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. @@ -3697,7 +3697,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. @@ -3783,7 +3783,7 @@ void init_ops(nb::module_& m) { Args: input (array): Input array of shape ``(N, D, H, W, C_in)``. - weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``. + weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. stride (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. From 167b759a389cc5f826efb3e8528057ef66210f1e Mon Sep 17 00:00:00 2001 From: hdeng-apple Date: Tue, 29 Apr 2025 22:26:05 +0800 Subject: [PATCH 014/156] Fix typos (#2136) --- mlx/array.h | 2 +- mlx/backend/metal/kernels/fft/readwrite.h | 2 +- .../metal/kernels/steel/attn/kernels/steel_attention.h | 4 ++-- mlx/backend/metal/kernels/steel/attn/loader.h | 4 ++-- .../metal/kernels/steel/conv/kernels/steel_conv_general.h | 2 +- mlx/backend/metal/kernels/steel/gemm/loader.h | 2 +- mlx/backend/no_cpu/compiled.cpp | 2 +- mlx/ops.h | 2 +- mlx/random.cpp | 4 ++-- python/src/random.cpp | 2 +- 10 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index 66a4702a6..d9fcfc58e 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -356,7 +356,7 @@ class array { } enum Status { - // The ouptut of a computation which has not been scheduled. + // The output of a computation which has not been scheduled. // For example, the status of `x` in `auto x = a + b`. unscheduled, diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h index ab699e136..f6724820d 100644 --- a/mlx/backend/metal/kernels/fft/readwrite.h +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so read/write performance is important. Where possible, we read 128 bits sequentially in each thread, -coalesced with accesses from adajcent threads for optimal performance. +coalesced with accesses from adjacent threads for optimal performance. 
We implement specialized reading/writing for: - FFT diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 2e27ea06f..34d5bf58a 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -95,7 +95,7 @@ template < Q += tidl.z * params->Q_strides[0] + // Batch tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Seqeunce + tidl.x * BQ * params->Q_strides[2]; // Sequence ulong kv_head_idx = int(tid.y) / params->gqa_factor; K += tidl.z * params->K_strides[0] + // Batch @@ -106,7 +106,7 @@ template < O += tidl.z * params->O_strides[0] + // Batch tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Seqeunce + tidl.x * BQ * params->O_strides[2]; // Sequence if (has_mask) { mask += tidl.z * mask_params->M_strides[0] + // Batch diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h index 2849c00f1..7ec798146 100644 --- a/mlx/backend/metal/kernels/steel/attn/loader.h +++ b/mlx/backend/metal/kernels/steel/attn/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); @@ -240,7 +240,7 @@ struct BlockLoaderT { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h index e4b662cd3..8253638f1 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -141,7 +141,7 @@ implicit_gemm_conv_2d_general( // Store results to device memory { - // Adjust for simdgroup and thread locatio + // Adjust for simdgroup and thread location int offset_m = c_row + mma_op.sm; int offset_n = c_col + mma_op.sn; C += offset_n; diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h index 3f084d8ec..d421b2d1f 100644 --- a/mlx/backend/metal/kernels/steel/gemm/loader.h +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); diff --git a/mlx/backend/no_cpu/compiled.cpp b/mlx/backend/no_cpu/compiled.cpp index c1c42c735..2eeddab47 100644 --- a/mlx/backend/no_cpu/compiled.cpp +++ b/mlx/backend/no_cpu/compiled.cpp @@ -18,7 +18,7 @@ void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { throw std::runtime_error( - "[Compiled::eval_cpu] CPU compialtion not supported on the platform."); + "[Compiled::eval_cpu] CPU compilation not supported on the platform."); } } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 12e896af6..af3cdb5bd 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -569,7 +569,7 @@ inline array std(const array& a, StreamOrDevice s = {}) { return std(a, false, 0, to_stream(s)); } -/** Computes the standard deviatoin of the elements of an array along the given +/** Computes the standard deviation of the elements of an array along the given * axes */ array std( const array& a, diff --git a/mlx/random.cpp b/mlx/random.cpp index d6ce5bb0e..89a027b17 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -223,7 +223,7 @@ array multivariate_normal( auto n = mean.shape(-1); - // Check shapes comatibility of mean and cov + // Check shapes compatibility of mean and cov if (cov.shape(-1) != cov.shape(-2)) { throw std::invalid_argument( "[multivariate_normal] last two dimensions of cov must be equal."); @@ -402,7 +402,7 @@ array categorical( if (broadcast_shapes(shape, reduced_shape) != shape) { std::ostringstream msg; msg << "[categorical] Requested shape " << shape - << " is not broadcast compatable with reduced logits shape" + << " is not broadcast compatible with reduced logits shape" << reduced_shape << "."; throw std::invalid_argument(msg.str()); } diff --git a/python/src/random.cpp b/python/src/random.cpp index e9c0a87fc..22b706174 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -422,7 +422,7 @@ void init_random(nb::module_& parent_module) { axis (int, optional): The axis which specifies the distribution. Default: ``-1``. shape (list(int), optional): The shape of the output. This must - be broadcast compatable with ``logits.shape`` with the ``axis`` + be broadcast compatible with ``logits.shape`` with the ``axis`` dimension removed. Default: ``None`` num_samples (int, optional): The number of samples to draw from each of the categorical distributions in ``logits``. The output will have From b36dd472bb11caa233b007270ab08b6a7b1fc91b Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." 
<4198311+skyzh@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:30:36 -0400 Subject: [PATCH 015/156] return library if it is successfully loaded (#2131) --- mlx/backend/metal/device.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 1321fb8d9..cb851b57e 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -166,6 +166,7 @@ MTL::Library* load_library( << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } + return lib; } // We have been given a path so try to load from lib_path / lib_name.metallib @@ -178,6 +179,7 @@ MTL::Library* load_library( << "> with error " << error->localizedDescription()->utf8String(); throw std::runtime_error(msg.str()); } + return lib; } // Try to load the colocated library From 7bb063bcb3000cf9f57078c114fd385577074c57 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 29 Apr 2025 13:03:09 -0700 Subject: [PATCH 016/156] Enable vjp for quantized scale and bias (#2129) * Enable vjp for quantized scale and bias * higher tol --- mlx/primitives.cpp | 30 ++++++++++++++++++++++++++++-- python/tests/test_quantized.py | 25 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3d36f0881..7288a4885 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3056,6 +3056,7 @@ std::vector QuantizedMatmul::vjp( std::vector vjps; // We rely on the fact that w is always 2D so transpose is simple + std::optional dsb = std::nullopt; for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { @@ -3071,9 +3072,34 @@ std::vector QuantizedMatmul::vjp( } // gradient wrt to w_q, scales or biases - else { + else if (arg == 1) { throw std::runtime_error( - "[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet."); + "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); + } else { + if (!dsb) { + auto fc = flatten(cotangents[0], 0, -2, stream()); + auto fx = flatten(primals[0], 0, -2, stream()); + auto dw = transpose_ + ? 
matmul(swapaxes(fc, -1, -2, stream()), fx, stream()) + : matmul(swapaxes(fx, -1, -2, stream()), fc, stream()); + dsb = unflatten(dw, -1, {-1, group_size_}, stream()); + } + if (arg == 3) { + // biases + vjps.push_back(sum(*dsb, -1, false, stream())); + } else { + // scales + auto s = stream(); + auto wq = dequantize( + primals[1], + ones_like(primals[2], stream()), + zeros_like(primals[3], stream()), + group_size_, + bits_, + stream()); + wq = unflatten(wq, -1, {-1, group_size_}, stream()); + vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); + } } } return vjps; diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index eeefcd94f..60ab421c6 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + def test_vjp_scales_biases(self): + mx.random.seed(0) + x = mx.random.normal(shape=(2, 2, 512)) + w = mx.random.normal(shape=(512, 512)) + wq, s, b = mx.quantize(w, bits=4, group_size=64) + + def mm(sb, x, wq): + return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum() + + params = (s, b) + dparams = mx.grad(mm)((s, b), x, wq) + + eps = 8e-3 + # numerical grad check with a few indices + indices = [(0, 0), (11, 4), (22, 7)] + for idx in indices: + for p in [0, 1]: + params[p][idx] += eps + out_up = mm(params, x, wq) + params[p][idx] -= 2 * eps + out_down = mm(params, x, wq) + params[p][idx] += eps + num_ds = (out_up - out_down) / (2 * eps) + self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2) + if __name__ == "__main__": unittest.main() From bb6565ef14e54da7ab794a3b52e1000a7c744331 Mon Sep 17 00:00:00 2001 From: Aashiq Dheeraj Date: Wed, 30 Apr 2025 01:13:45 -0400 Subject: [PATCH 017/156] add fftshift and ifftshift fft helpers (#2135) * add fftshift and ifftshift fft helpers * address comments * axes have to be iterable * fix fp error in roll + add test --------- Co-authored-by: Aashiq Dheeraj --- docs/src/python/fft.rst | 2 ++ mlx/fft.cpp | 71 ++++++++++++++++++++++++++++++++++++++++ mlx/fft.h | 18 ++++++++++ mlx/ops.cpp | 7 ++-- python/src/fft.cpp | 51 +++++++++++++++++++++++++++++ python/tests/test_fft.py | 62 +++++++++++++++++++++++++++++++++++ python/tests/test_ops.py | 5 +++ tests/fft_tests.cpp | 58 ++++++++++++++++++++++++++++++++ tests/ops_tests.cpp | 3 ++ 9 files changed, 275 insertions(+), 2 deletions(-) diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 9e4be084b..36d9d7838 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,3 +20,5 @@ FFT irfft2 rfftn irfftn + fftshift + ifftshift diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 02878af9c..6510faec1 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -184,8 +184,79 @@ array irfftn( StreamOrDevice s /* = {} */) { return fft_impl(a, axes, true, true, s); } + array irfftn(const array& a, StreamOrDevice s /* = {} */) { return fft_impl(a, true, true, s); } +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? 
ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + shifts.push_back(a.shape(axis) / 2); + } + + return roll(a, shifts, axes, s); +} + +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + int size = a.shape(axis); + shifts.push_back(-(size / 2)); + } + + return roll(a, shifts, axes, s); +} + +// Default versions that operate on all axes +array fftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return fftshift(a, axes, s); +} + +array ifftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return ifftshift(a, axes, s); +} + } // namespace mlx::core::fft diff --git a/mlx/fft.h b/mlx/fft.h index 2f02da73b..163e06b80 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -145,5 +145,23 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, s); } +/** Shift the zero-frequency component to the center of the spectrum. */ +array fftshift(const array& a, StreamOrDevice s = {}); + +/** Shift the zero-frequency component to the center of the spectrum along + * specified axes. */ +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** The inverse of fftshift. */ +array ifftshift(const array& a, StreamOrDevice s = {}); + +/** The inverse of fftshift along specified axes. */ +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); } // namespace mlx::core::fft diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f8308c2d5..e7abe12db 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -5025,8 +5025,11 @@ array roll( } auto sh = shift[i]; - auto split_index = - (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); + auto size = a.shape(ax); + if (size == 0) { + continue; // skip rolling this axis if it has size 0 + } + auto split_index = (sh < 0) ? (-sh) % size : size - sh % size; auto parts = split(result, Shape{split_index}, ax, s); std::swap(parts[0], parts[1]); diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 5ad4702e2..026f8139d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -459,4 +459,55 @@ void init_fft(nb::module_& parent_module) { Returns: array: The real array containing the inverse of :func:`rfftn`. )pbdoc"); + m.def( + "fftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::fftshift(a, axes.value(), s); + } else { + return mx::fft::fftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + Shift the zero-frequency component to the center of the spectrum. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the shift. + If ``None``, shift all axes. 
+ + Returns: + array: The shifted array with the same shape as the input. + )pbdoc"); + m.def( + "ifftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::ifftshift(a, axes.value(), s); + } else { + return mx::fft::ifftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes, + the behavior differs for odd-length axes. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the inverse shift. + If ``None``, shift all axes. + + Returns: + array: The inverse-shifted array with the same shape as the input. + )pbdoc"); } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index c887cd968..f644944c7 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -199,6 +199,68 @@ class TestFFT(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.fft.irfftn(x) + def test_fftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c) + + def test_ifftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c) + + def test_fftshift_errors(self): + # Test invalid axes + x = mx.array(np.random.rand(4, 4).astype(np.float32)) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[2]) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[-3]) + + # Test empty array + x = mx.array([]) + self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 47fec3167..d840eac7d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2961,6 +2961,11 @@ class TestOps(mlx_tests.MLXTestCase): y2 = mx.roll(x, s, a) 
self.assertTrue(mx.array_equal(y1, y2).item()) + def test_roll_errors(self): + x = mx.array([]) + result = mx.roll(x, [0], [0]) + self.assertTrue(mx.array_equal(result, x)) + def test_real_imag(self): x = mx.random.uniform(shape=(4, 4)) out = mx.real(x) diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index c04dda1d5..0db3999c8 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -308,3 +308,61 @@ TEST_CASE("test fft grads") { .second; CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } + +TEST_CASE("test fftshift and ifftshift") { + // Test 1D array with even length + auto x = arange(8); + auto y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + // print y + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test 1D array with odd length + x = arange(7); + y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item()); + + // Test 2D array + x = reshape(arange(16), {4, 4}); + y = fft::fftshift(x); + auto expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test with specific axes + y = fft::fftshift(x, {0}); + expected = + array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + y = fft::fftshift(x, {1}); + expected = + array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test ifftshift (inverse operation) + x = arange(8); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test ifftshift with odd length (different from fftshift) + x = arange(7); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item()); + + // Test 2D ifftshift + x = reshape(arange(16), {4, 4}); + y = fft::ifftshift(x); + expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test error cases + CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument); +} diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c4f319d46..5e2bae5a0 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3859,6 +3859,9 @@ TEST_CASE("test roll") { y = roll(x, {1, 2}, {0, 1}); CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5})) .item()); + + y = roll(array({}), 0, 0); + CHECK(array_equal(y, array({})).item()); } TEST_CASE("test contiguous") { From 87720a8908a4f3fb90891659768ce2fb7cf98fa1 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 30 Apr 2025 22:04:07 +0900 Subject: [PATCH 018/156] Fix building with uv (#2141) --- .gitignore | 1 + MANIFEST.in | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index e748ee2bf..43629548d 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +uv.lock # vim *.swp diff --git a/MANIFEST.in b/MANIFEST.in index 9faafee45..d0daeb7ae 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,6 @@ include CMakeLists.txt +include mlx.pc.in recursive-include mlx/ * +include cmake/* include python/src/* include python/mlx/py.typed # support type hinting as in PEP-561 From f1606486d252124ab60cd45dc8761fcbfa6e3dc7 Mon 
Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 30 Apr 2025 09:08:17 -0700 Subject: [PATCH 019/156] Generalize gpu backend (#2138) * generalize gpu backend * fix no_gpu build * fix no_gpu build * generalize gpu backend --- mlx/CMakeLists.txt | 4 +- mlx/backend/cpu/CMakeLists.txt | 3 +- mlx/backend/cpu/available.cpp | 11 ++ mlx/backend/cpu/available.h | 9 ++ mlx/backend/gpu/available.h | 9 ++ .../{metal/metal_impl.h => gpu/eval.h} | 7 +- mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/allocator.cpp | 1 - mlx/backend/metal/device.cpp | 41 ------ mlx/backend/metal/device.h | 2 + mlx/backend/metal/eval.cpp | 102 +++++++++++++++ mlx/backend/metal/event.cpp | 1 - mlx/backend/metal/fence.cpp | 1 - mlx/backend/metal/metal.cpp | 117 ++++++------------ mlx/backend/metal/metal.h | 3 +- mlx/backend/metal/no_metal.cpp | 22 ++++ mlx/backend/metal/resident.cpp | 1 - mlx/backend/no_cpu/CMakeLists.txt | 3 +- mlx/backend/no_cpu/available.cpp | 11 ++ .../{no_metal => no_gpu}/CMakeLists.txt | 2 +- .../{no_metal => no_gpu}/allocator.cpp | 4 +- .../{no_metal => no_gpu}/apple_memory.h | 0 mlx/backend/no_gpu/eval.cpp | 28 +++++ mlx/backend/{no_metal => no_gpu}/event.cpp | 0 mlx/backend/{no_metal => no_gpu}/fence.cpp | 0 .../{no_metal => no_gpu}/linux_memory.h | 0 .../{no_metal => no_gpu}/primitives.cpp | 0 mlx/backend/no_metal/metal.cpp | 43 ------- mlx/device.cpp | 21 +++- mlx/device.h | 2 + mlx/scheduler.cpp | 11 +- mlx/scheduler.h | 7 +- mlx/transforms.cpp | 8 +- 33 files changed, 275 insertions(+), 200 deletions(-) create mode 100644 mlx/backend/cpu/available.cpp create mode 100644 mlx/backend/cpu/available.h create mode 100644 mlx/backend/gpu/available.h rename mlx/backend/{metal/metal_impl.h => gpu/eval.h} (63%) create mode 100644 mlx/backend/metal/eval.cpp create mode 100644 mlx/backend/metal/no_metal.cpp create mode 100644 mlx/backend/no_cpu/available.cpp rename mlx/backend/{no_metal => no_gpu}/CMakeLists.txt (82%) rename mlx/backend/{no_metal => no_gpu}/allocator.cpp (96%) rename mlx/backend/{no_metal => no_gpu}/apple_memory.h (100%) create mode 100644 mlx/backend/no_gpu/eval.cpp rename mlx/backend/{no_metal => no_gpu}/event.cpp (100%) rename mlx/backend/{no_metal => no_gpu}/fence.cpp (100%) rename mlx/backend/{no_metal => no_gpu}/linux_memory.h (100%) rename mlx/backend/{no_metal => no_gpu}/primitives.cpp (100%) delete mode 100644 mlx/backend/no_metal/metal.cpp diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index abf46a7d5..465954d6f 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -49,5 +49,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 152f33b17..96b3f1313 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -40,7 +40,8 @@ add_dependencies(mlx cpu_compiled_preamble) target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp diff --git a/mlx/backend/cpu/available.cpp b/mlx/backend/cpu/available.cpp new file mode 100644 index 
000000000..0449d49b9 --- /dev/null +++ b/mlx/backend/cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/available.h b/mlx/backend/cpu/available.h new file mode 100644 index 000000000..1df95def2 --- /dev/null +++ b/mlx/backend/cpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cpu { + +bool is_available(); + +} // namespace mlx::core::cpu diff --git a/mlx/backend/gpu/available.h b/mlx/backend/gpu/available.h new file mode 100644 index 000000000..476c7acf2 --- /dev/null +++ b/mlx/backend/gpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::gpu { + +bool is_available(); + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/metal_impl.h b/mlx/backend/gpu/eval.h similarity index 63% rename from mlx/backend/metal/metal_impl.h rename to mlx/backend/gpu/eval.h index 9ca8d2f80..f646c2ec9 100644 --- a/mlx/backend/metal/metal_impl.h +++ b/mlx/backend/gpu/eval.h @@ -8,14 +8,11 @@ #include "mlx/array.h" #include "mlx/stream.h" -namespace mlx::core::metal { +namespace mlx::core::gpu { void new_stream(Stream stream); - -std::unique_ptr> new_scoped_memory_pool(); - void eval(array& arr); void finalize(Stream s); void synchronize(Stream s); -} // namespace mlx::core::metal +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 332c560f8..d0c872451 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -93,6 +93,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0a69dd261..5d8bd90d5 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/resident.h" #include "mlx/memory.h" diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index cb851b57e..ebc3cc77f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -4,15 +4,12 @@ #include #include -#include - #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" @@ -772,42 +769,4 @@ std::unique_ptr> new_scoped_memory_pool() { NS::AutoreleasePool::alloc()->init(), dtor); } -void new_stream(Stream stream) { - if (stream.device == mlx::core::Device::gpu) { - device(stream.device).new_queue(stream.index); - } -} - -const std::unordered_map>& -device_info() { - auto init_device_info = []() - -> std::unordered_map> { - auto pool = new_scoped_memory_pool(); - auto raw_device = device(default_device()).mtl_device(); - auto name = std::string(raw_device->name()->utf8String()); - auto arch = std::string(raw_device->architecture()->name()->utf8String()); - - size_t memsize = 0; - size_t length = sizeof(memsize); - sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); - - size_t rsrc_limit = 0; - sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); - if (rsrc_limit == 0) { - rsrc_limit = 499000; - } - - return { - {"device_name", name}, - {"architecture", arch}, - {"max_buffer_length", raw_device->maxBufferLength()}, - {"max_recommended_working_set_size", - raw_device->recommendedMaxWorkingSetSize()}, - {"memory_size", memsize}, - {"resource_limit", rsrc_limit}}; - }; - static auto device_info_ = init_device_info(); - return device_info_; -} - } // namespace mlx::core::metal diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index d60635e39..26c9a0a28 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -266,4 +266,6 @@ class Device { Device& device(mlx::core::Device); +std::unique_ptr> new_scoped_memory_pool(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp new file mode 100644 index 000000000..49783200a --- /dev/null +++ b/mlx/backend/metal/eval.cpp @@ -0,0 +1,102 @@ +// Copyright © 2023-2024 Apple Inc. 
+#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream stream) { + if (stream.device == mlx::core::Device::gpu) { + metal::device(stream.device).new_queue(stream.index); + } +} + +inline void check_error(MTL::CommandBuffer* cbuf) { + if (cbuf->status() == MTL::CommandBufferStatusError) { + std::ostringstream msg; + msg << "[METAL] Command buffer execution failed: " + << cbuf->error()->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } +} + +void eval(array& arr) { + auto pool = metal::new_scoped_memory_pool(); + auto s = arr.primitive().stream(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + + debug_set_primitive_buffer_label(command_buffer, arr.primitive()); + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + + if (d.command_buffer_needs_commit(s.index)) { + d.end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + scheduler::notify_task_completion(s); + check_error(cbuf); + }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + check_error(cbuf); + }); + } +} + +void finalize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + d.end_encoding(s.index); + cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); + d.commit_command_buffer(s.index); + d.get_command_buffer(s.index); +} + +void synchronize(Stream s) { + auto pool = metal::new_scoped_memory_pool(); + auto& d = metal::device(s.device); + auto cb = d.get_command_buffer(s.index); + cb->retain(); + d.end_encoding(s.index); + d.commit_command_buffer(s.index); + cb->waitUntilCompleted(); + check_error(cb); + cb->release(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 246d6bcc5..eb7f1b58a 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -2,7 +2,6 @@ #include "mlx/event.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" namespace mlx::core { diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index e784d34ae..d4a88d983 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. 
#include "mlx/fence.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal_impl.h" #include "mlx/scheduler.h" #include "mlx/utils.h" diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a9a1bc4f6..888207322 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,11 +1,11 @@ // Copyright © 2023-2024 Apple Inc. #include +#include + #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" -#include "mlx/scheduler.h" -#include "mlx/utils.h" namespace mlx::core::metal { @@ -13,85 +13,6 @@ bool is_available() { return true; } -inline void check_error(MTL::CommandBuffer* cbuf) { - if (cbuf->status() == MTL::CommandBufferStatusError) { - std::ostringstream msg; - msg << "[METAL] Command buffer execution failed: " - << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); - } -} - -void eval(array& arr) { - auto pool = new_scoped_memory_pool(); - auto s = arr.primitive().stream(); - auto& d = metal::device(s.device); - auto command_buffer = d.get_command_buffer(s.index); - - auto outputs = arr.outputs(); - { - // If the array is a tracer hold a reference - // to its inputs so they don't get donated - std::vector inputs; - if (arr.is_tracer()) { - inputs = arr.inputs(); - } - - debug_set_primitive_buffer_label(command_buffer, arr.primitive()); - arr.primitive().eval_gpu(arr.inputs(), outputs); - } - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - - if (d.command_buffer_needs_commit(s.index)) { - d.end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); - } -} - -void finalize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - d.end_encoding(s.index); - cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); - d.commit_command_buffer(s.index); - d.get_command_buffer(s.index); -} - -void synchronize(Stream s) { - auto pool = new_scoped_memory_pool(); - auto& d = metal::device(s.device); - auto cb = d.get_command_buffer(s.index); - cb->retain(); - d.end_encoding(s.index); - d.commit_command_buffer(s.index); - cb->waitUntilCompleted(); - check_error(cb); - cb->release(); -} - void start_capture(std::string path, id object) { auto pool = new_scoped_memory_pool(); @@ -128,4 +49,36 @@ void stop_capture() { manager->stopCapture(); } +const std::unordered_map>& +device_info() { + auto init_device_info = []() + -> std::unordered_map> { + auto pool = new_scoped_memory_pool(); + auto raw_device = device(default_device()).mtl_device(); + auto name = std::string(raw_device->name()->utf8String()); + auto arch = std::string(raw_device->architecture()->name()->utf8String()); + + size_t memsize = 0; + size_t length = sizeof(memsize); + 
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + + size_t rsrc_limit = 0; + sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0); + if (rsrc_limit == 0) { + rsrc_limit = 499000; + } + + return { + {"device_name", name}, + {"architecture", arch}, + {"max_buffer_length", raw_device->maxBufferLength()}, + {"max_recommended_working_set_size", + raw_device->recommendedMaxWorkingSetSize()}, + {"memory_size", memsize}, + {"resource_limit", rsrc_limit}}; + }; + static auto device_info_ = init_device_info(); + return device_info_; +} + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index d162007d1..af2995b63 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -2,11 +2,10 @@ #pragma once +#include #include #include -#include "mlx/array.h" - namespace mlx::core::metal { /* Check if the Metal backend is available. */ diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp new file mode 100644 index 000000000..b6142b280 --- /dev/null +++ b/mlx/backend/metal/no_metal.cpp @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/metal/metal.h" + +namespace mlx::core::metal { + +bool is_available() { + return false; +} + +void start_capture(std::string) {} +void stop_capture() {} + +const std::unordered_map>& +device_info() { + throw std::runtime_error( + "[metal::device_info] Cannot get device info without metal backend"); +}; + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 0a9e1b861..798824c2f 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/resident.h" -#include "mlx/backend/metal/metal_impl.h" namespace mlx::core::metal { diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index e1524ec63..2e6960829 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,6 +1,7 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) diff --git a/mlx/backend/no_cpu/available.cpp b/mlx/backend/no_cpu/available.cpp new file mode 100644 index 000000000..04c1bac8e --- /dev/null +++ b/mlx/backend/no_cpu/available.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cpu/available.h" + +namespace mlx::core::cpu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cpu diff --git a/mlx/backend/no_metal/CMakeLists.txt b/mlx/backend/no_gpu/CMakeLists.txt similarity index 82% rename from mlx/backend/no_metal/CMakeLists.txt rename to mlx/backend/no_gpu/CMakeLists.txt index 962ceecb7..78e15ac69 100644 --- a/mlx/backend/no_metal/CMakeLists.txt +++ b/mlx/backend/no_gpu/CMakeLists.txt @@ -3,5 +3,5 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp similarity index 96% rename from mlx/backend/no_metal/allocator.cpp rename to mlx/backend/no_gpu/allocator.cpp index a8b260b6b..320d1a267 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -6,9 +6,9 @@ #include "mlx/allocator.h" #ifdef __APPLE__ -#include "mlx/backend/no_metal/apple_memory.h" +#include "mlx/backend/no_gpu/apple_memory.h" #elif defined(__linux__) -#include "mlx/backend/no_metal/linux_memory.h" +#include "mlx/backend/no_gpu/linux_memory.h" #else size_t get_memory_size() { return 0; diff --git a/mlx/backend/no_metal/apple_memory.h b/mlx/backend/no_gpu/apple_memory.h similarity index 100% rename from mlx/backend/no_metal/apple_memory.h rename to mlx/backend/no_gpu/apple_memory.h diff --git a/mlx/backend/no_gpu/eval.cpp b/mlx/backend/no_gpu/eval.cpp new file mode 100644 index 000000000..8bff86a98 --- /dev/null +++ b/mlx/backend/no_gpu/eval.cpp @@ -0,0 +1,28 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" + +namespace mlx::core::gpu { + +bool is_available() { + return false; +} + +void new_stream(Stream) {} + +void eval(array&) { + throw std::runtime_error("[gpu::eval] GPU backend is not available"); +} + +void finalize(Stream) { + throw std::runtime_error("[gpu::finalize] GPU backend is not available"); +} + +void synchronize(Stream) { + throw std::runtime_error("[gpu::synchronize] GPU backend is not available"); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/no_metal/event.cpp b/mlx/backend/no_gpu/event.cpp similarity index 100% rename from mlx/backend/no_metal/event.cpp rename to mlx/backend/no_gpu/event.cpp diff --git a/mlx/backend/no_metal/fence.cpp b/mlx/backend/no_gpu/fence.cpp similarity index 100% rename from mlx/backend/no_metal/fence.cpp rename to mlx/backend/no_gpu/fence.cpp diff --git a/mlx/backend/no_metal/linux_memory.h b/mlx/backend/no_gpu/linux_memory.h similarity index 100% rename from mlx/backend/no_metal/linux_memory.h rename to mlx/backend/no_gpu/linux_memory.h diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp similarity index 100% rename from mlx/backend/no_metal/primitives.cpp rename to mlx/backend/no_gpu/primitives.cpp diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp deleted file mode 100644 index ef9af8800..000000000 --- a/mlx/backend/no_metal/metal.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. 
- -#include - -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" -namespace mlx::core::metal { - -bool is_available() { - return false; -} - -void new_stream(Stream) {} - -std::unique_ptr> new_scoped_memory_pool() { - return nullptr; -} - -void eval(array&) { - throw std::runtime_error( - "[metal::eval] Cannot eval on GPU without metal backend"); -} - -void finalize(Stream) { - throw std::runtime_error( - "[metal::finalize] Cannot finalize GPU without metal backend"); -} - -void synchronize(Stream) { - throw std::runtime_error( - "[metal::synchronize] Cannot synchronize GPU without metal backend"); -} - -void start_capture(std::string) {} -void stop_capture() {} - -const std::unordered_map>& -device_info() { - throw std::runtime_error( - "[metal::device_info] Cannot get device info without metal backend"); -}; - -} // namespace mlx::core::metal diff --git a/mlx/device.cpp b/mlx/device.cpp index 20d8675d8..ec17a509a 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -1,13 +1,15 @@ // Copyright © 2023 Apple Inc. +#include + +#include "mlx/backend/cpu/available.h" +#include "mlx/backend/gpu/available.h" #include "mlx/device.h" -#include "mlx/backend/metal/metal.h" namespace mlx::core { Device& mutable_default_device() { - static Device default_device{ - metal::is_available() ? Device::gpu : Device::cpu}; + static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu}; return default_device; } @@ -16,7 +18,7 @@ const Device& default_device() { } void set_default_device(const Device& d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[set_default_device] Cannot set gpu device without gpu backend."); } @@ -31,4 +33,15 @@ bool operator!=(const Device& lhs, const Device& rhs) { return !(lhs == rhs); } +bool is_available(const Device& d) { + switch (d.type) { + case Device::cpu: + return cpu::is_available(); + case Device::gpu: + return gpu::is_available(); + } + // appease compiler + return false; +} + } // namespace mlx::core diff --git a/mlx/device.h b/mlx/device.h index a11e40e9d..80c624c1c 100644 --- a/mlx/device.h +++ b/mlx/device.h @@ -26,4 +26,6 @@ void set_default_device(const Device& d); bool operator==(const Device& lhs, const Device& rhs); bool operator!=(const Device& lhs, const Device& rhs); +bool is_available(const Device& d); + } // namespace mlx::core diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index 7bd128c10..b19f6434a 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -1,12 +1,13 @@ // Copyright © 2023 Apple Inc. 
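The is_available(Device) helper added in device.cpp above gives callers a backend-agnostic capability check; a small usage sketch (illustrative only, mirroring how the Scheduler below calls it):

// Hypothetical caller: prefer the GPU when a GPU backend was compiled in.
#include "mlx/device.h"

using namespace mlx::core;

void prefer_gpu_if_present() {
  if (is_available(Device::gpu)) {
    set_default_device(Device::gpu);
  } else {
    set_default_device(Device::cpu);
  }
}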
#include "mlx/scheduler.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/gpu/eval.h" namespace mlx::core { Stream default_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[default_stream] Cannot get gpu stream without gpu backend."); } @@ -14,7 +15,7 @@ Stream default_stream(Device d) { } void set_default_stream(Stream s) { - if (!metal::is_available() && s.device == Device::gpu) { + if (!gpu::is_available() && s.device == Device::gpu) { throw std::invalid_argument( "[set_default_stream] Cannot set gpu stream without gpu backend."); } @@ -26,7 +27,7 @@ Stream get_stream(int index) { } Stream new_stream(Device d) { - if (!metal::is_available() && d == Device::gpu) { + if (!gpu::is_available() && d == Device::gpu) { throw std::invalid_argument( "[new_stream] Cannot make gpu stream without gpu backend."); } @@ -44,7 +45,7 @@ void synchronize(Stream s) { scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); }); f.wait(); } else { - metal::synchronize(s); + gpu::synchronize(s); } } diff --git a/mlx/scheduler.h b/mlx/scheduler.h index b2c6b842b..877fdd5f6 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -8,8 +8,7 @@ #include #include -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/device.h" #include "mlx/stream.h" @@ -67,7 +66,7 @@ struct StreamThread { class Scheduler { public: Scheduler() : n_active_tasks_(0) { - if (metal::is_available()) { + if (is_available(Device::gpu)) { default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); } default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); @@ -83,7 +82,7 @@ class Scheduler { streams_.emplace_back(streams_.size(), d); if (d == Device::gpu) { threads_.push_back(nullptr); - metal::new_stream(streams_.back()); + gpu::new_stream(streams_.back()); } else { threads_.push_back(new StreamThread{}); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index f9a5de031..2d9942eda 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -10,7 +10,7 @@ #include #include "mlx/backend/cpu/eval.h" -#include "mlx/backend/metal/metal_impl.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/fence.h" #include "mlx/memory.h" #include "mlx/ops.h" @@ -218,7 +218,7 @@ array eval_impl(std::vector outputs, bool async) { } if (arr.primitive().device() == Device::gpu) { - metal::eval(arr); + gpu::eval(arr); } else { cpu::eval(arr); } @@ -229,7 +229,7 @@ array eval_impl(std::vector outputs, bool async) { // Commit any open streams for (auto& [_, e] : events) { if (e.stream().device == Device::gpu) { - metal::finalize(e.stream()); + gpu::finalize(e.stream()); } } scheduler::wait_for_one(); @@ -267,7 +267,7 @@ array eval_impl(std::vector outputs, bool async) { auto s = e.stream(); e.signal(s); if (s.device == Device::gpu) { - metal::finalize(s); + gpu::finalize(s); } } From aa5d84f102e8e2fff0f3db3f1d23b61fb1e1a2d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 30 Apr 2025 09:08:29 -0700 Subject: [PATCH 020/156] Allow quant layer to be unfrozen (#2142) --- python/mlx/nn/layers/quantized.py | 6 ------ python/tests/test_nn.py | 9 ++++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 823a0084f..2d6dc0882 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -193,12 +193,6 @@ 
class QuantizedLinear(Module): # Freeze this model's parameters self.freeze() - def unfreeze(self, *args, **kwargs): - """Wrap unfreeze so that we unfreeze any layers we might contain but - our parameters will remain frozen.""" - super().unfreeze(*args, **kwargs) - self.freeze(recurse=False) - def _extra_repr(self): out_dims, in_dims = self.weight.shape in_dims *= 32 // self.bits diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9cfa25dae..826d53d96 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -8,7 +8,7 @@ import mlx.core as mx import mlx.nn as nn import mlx_tests import numpy as np -from mlx.utils import tree_flatten, tree_map +from mlx.utils import tree_flatten, tree_map, tree_reduce class TestBase(mlx_tests.MLXTestCase): @@ -198,6 +198,13 @@ class TestBase(mlx_tests.MLXTestCase): self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + def test_quantize_freeze(self): + lin = nn.Linear(512, 512) + qlin = lin.to_quantized() + qlin.unfreeze(keys=["scales"]) + size = tree_reduce(lambda acc, p: acc + p.size, qlin.trainable_parameters(), 0) + self.assertTrue(size > 0) + def test_grad_of_module(self): class Model(nn.Module): def __init__(self): From ea890d87103b154d9d8f6c48fe07ba3eebe89e37 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 1 May 2025 01:08:39 +0900 Subject: [PATCH 021/156] Remove metal-only tests (#2139) --- tests/CMakeLists.txt | 2 +- tests/{metal_tests.cpp => gpu_tests.cpp} | 31 +++++++++--------------- 2 files changed, 12 insertions(+), 21 deletions(-) rename tests/{metal_tests.cpp => gpu_tests.cpp} (95%) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index be4479e70..cf0ba3d5d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,7 +10,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) if(MLX_BUILD_METAL) - set(METAL_TEST_SOURCES metal_tests.cpp) + set(METAL_TEST_SOURCES gpu_tests.cpp) endif() include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) diff --git a/tests/metal_tests.cpp b/tests/gpu_tests.cpp similarity index 95% rename from tests/metal_tests.cpp rename to tests/gpu_tests.cpp index 7aabdf36d..f0ef969cf 100644 --- a/tests/metal_tests.cpp +++ b/tests/gpu_tests.cpp @@ -1,11 +1,8 @@ // Copyright © 2023-2024 Apple Inc. 
#include -#include "doctest/doctest.h" -#include "mlx/backend/metal/allocator.h" -#include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/metal.h" +#include "doctest/doctest.h" #include "mlx/mlx.h" using namespace mlx::core; @@ -13,13 +10,7 @@ using namespace mlx::core; static const std::array types = {bool_, uint32, int32, int64, float32}; -TEST_CASE("test metal device") { - // Make sure the device and library can load - CHECK(metal::is_available()); - auto& device = metal::device(Device::gpu); -} - -TEST_CASE("test metal arange") { +TEST_CASE("test gpu arange") { for (auto t : types) { if (t == bool_) { continue; @@ -34,7 +25,7 @@ TEST_CASE("test metal arange") { } } -TEST_CASE("test metal full") { +TEST_CASE("test gpu full") { for (auto t : types) { auto out_cpu = full({4, 4}, 2, t, Device::cpu); auto out_gpu = full({4, 4}, 2, t, Device::gpu); @@ -63,7 +54,7 @@ TEST_CASE("test metal full") { } } -TEST_CASE("test metal astype") { +TEST_CASE("test gpu astype") { array x = array({-4, -3, -2, -1, 0, 1, 2, 3}); // Check all types work for (auto t : types) { @@ -80,7 +71,7 @@ TEST_CASE("test metal astype") { } } -TEST_CASE("test metal reshape") { +TEST_CASE("test gpu reshape") { array x = array({0, 1, 2, 3, 4, 5, 6, 7}); auto out_cpu = reshape(x, {2, 2, 2}); auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu); @@ -96,7 +87,7 @@ TEST_CASE("test metal reshape") { CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); } -TEST_CASE("test metal reduce") { +TEST_CASE("test gpu reduce") { { array a(true); CHECK_EQ(all(a, Device::gpu).item(), true); @@ -190,7 +181,7 @@ TEST_CASE("test metal reduce") { } } -TEST_CASE("test metal binary ops") { +TEST_CASE("test gpu binary ops") { // scalar-scalar { array a(2.0f); @@ -338,7 +329,7 @@ TEST_CASE("test metal binary ops") { } } -TEST_CASE("test metal unary ops") { +TEST_CASE("test gpu unary ops") { // contiguous { array x({-1.0f, 0.0f, 1.0f}); @@ -392,7 +383,7 @@ TEST_CASE("test metal unary ops") { } } -TEST_CASE("test metal random") { +TEST_CASE("test gpu random") { { auto key = random::key(0); auto x = random::bits({}, 4, key, Device::gpu); @@ -415,7 +406,7 @@ TEST_CASE("test metal random") { } } -TEST_CASE("test metal matmul") { +TEST_CASE("test gpu matmul") { { auto a = ones({2, 2}); auto b = ones({2, 2}); @@ -440,7 +431,7 @@ TEST_CASE("test metal matmul") { } } -TEST_CASE("test metal validation") { +TEST_CASE("test gpu validation") { // Run this test with Metal validation enabled // METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \ // -tc="test metal validation" \ From e496c5a4b4d06e07ad204bdf3226df4f1bc3e259 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 30 Apr 2025 09:28:56 -0700 Subject: [PATCH 022/156] fix integer overflow in qmm (#2143) --- mlx/backend/metal/kernels/quantized.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b2b0d8d8f..ba4fb2426 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * 
BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); From a3a632d567912e369eaccd9690231deff40973a9 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 1 May 2025 12:56:09 -0700 Subject: [PATCH 023/156] Fix the launcher when ran locally (#2147) --- python/mlx/distributed_run.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 9c946005b..404ecc349 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): # Repeat the stdout and stderr to the local machine to_read = [p.stdout.fileno(), p.stderr.fileno()] - to_write = [p.stdin.fileno()] + to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] pidfile = "" stdin_buffer = b"" + stdout_buffer = b"" + stderr_buffer = b"" while p.poll() is None: try: stdin_buffer += input_queue.get_nowait() @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): pass rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: - is_stdout = fd == p.stdout.fileno() - outfile = sys.stdout if is_stdout else sys.stderr msg = os.read(fd, 8192).decode(errors="ignore") # Fetch the PID file first if we haven't already @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): pidfile, *msg = msg.split("\n", maxsplit=1) msg = msg[0] if msg else "" - outfile.write(msg) - outfile.flush() + is_stdout = fd == p.stdout.fileno() + if is_stdout: + stdout_buffer += msg.encode() + else: + stderr_buffer += msg.encode() for fd in wlist: - if len(stdin_buffer) > 0: + if fd == p.stdin.fileno() and len(stdin_buffer) > 0: n = os.write(fd, stdin_buffer) stdin_buffer = stdin_buffer[n:] + elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: + n = os.write(fd, stdout_buffer) + stdout_buffer = stdout_buffer[n:] + elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: + n = os.write(fd, stderr_buffer) + stderr_buffer = stderr_buffer[n:] if stop: p.terminate() break From 9daa6b003f548569e9186502d5acb9b74b91bcbe Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 1 May 2025 15:02:02 -0700 Subject: [PATCH 024/156] fix shapeless export (#2148) --- mlx/export.cpp | 3 +++ python/tests/test_export_import.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mlx/export.cpp b/mlx/export.cpp index effc7a0c1..c9139e156 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -470,6 +470,9 @@ bool FunctionTable::match( if (x.dtype() != y.dtype()) { return false; } + if (x.ndim() != y.ndim()) { + return false; + } if (!shapeless && x.shape() != y.shape()) { return false; } diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 2b4b425ca..0190827bd 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -242,6 +242,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: @@ -267,6 +268,24 @@ class TestExportImport(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_export_import_shapeless(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + 
def fun(*args): + return sum(args) + + with mx.exporter(path, fun, shapeless=True) as exporter: + exporter(mx.array(1)) + exporter(mx.array(1), mx.array(2)) + exporter(mx.array(1), mx.array(2), mx.array(3)) + + f2 = mx.import_function(path) + self.assertEqual(f2(mx.array(1))[0].item(), 1) + self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2) + self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3) + with self.assertRaises(ValueError): + f2(mx.array(10), mx.array([5, 10, 20])) + if __name__ == "__main__": unittest.main() From 481349495b8c3d094eb699e678077bbe1406392d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 18 Feb 2025 13:43:09 -0800 Subject: [PATCH 025/156] GPU Hadamard for large N (#1879) --- mlx/backend/common/hadamard.h | 6 +- mlx/backend/metal/hadamard.cpp | 240 ++++++++++++++------------- mlx/backend/metal/kernels/hadamard.h | 41 +++-- mlx/ops.cpp | 13 +- python/tests/test_ops.py | 26 ++- 5 files changed, 198 insertions(+), 128 deletions(-) diff --git a/mlx/backend/common/hadamard.h b/mlx/backend/common/hadamard.h index a8fed76b0..ba5c4e41e 100644 --- a/mlx/backend/common/hadamard.h +++ b/mlx/backend/common/hadamard.h @@ -99,7 +99,11 @@ inline std::pair decompose_hadamard(int n) { "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); } } + if (n > (1 << 26)) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where k <= 26"); + } return {n, m}; } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index a7dfc5f17..89b970fce 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -1,9 +1,7 @@ // Copyright © 2024 Apple Inc. -#include - -#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/hadamard.h" +#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -15,7 +13,6 @@ namespace mlx::core { constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256; -constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB std::string gen_hadamard_codelet(int m) { // Generate a O(m^2) hadamard codelet for a given M @@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) { return source.str(); } -void Hadamard::eval_gpu(const std::vector& inputs, array& out) { - auto& s = stream(); +void hadamard_mn_contiguous( + const array& x, + array& y, + int m, + int n1, + int n2, + float scale, + metal::Device& d, + const Stream& s) { + int n = n1 * n2; + int read_width_n1 = n1 == 2 ? 2 : 4; + int read_width_n2 = n2 == 2 ? 2 : 4; + int read_width_m = (n == 2 || m == 28) ? 2 : 4; + int max_radix_1 = std::min(n1, 16); + int max_radix_2 = std::min(n2, 16); + float scale_n1 = 1.0; + float scale_n2 = (m == 1) ? 
scale : 1.0; + float scale_m = scale; - auto& in = inputs[0]; + // n2 is a row contiguous power of 2 hadamard transform + MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1); + MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1); - std::vector copies; - // Only support the last axis for now - int axis = in.ndim() - 1; - auto check_input = [&copies, &s](const array& x) { - // TODO(alexbarron) pass strides to kernel to relax this constraint - bool no_copy = x.flags().row_contiguous; - if (no_copy) { - return x; - } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + // n1 is a strided power of 2 hadamard transform with stride n2 + MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1); + MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2); + + // m is a strided hadamard transform with stride n = n1 * n2 + MTL::Size group_dims_m( + std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1); + MTL::Size grid_dims_m( + group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1); + + // Make the kernel + std::string kname; + kname.reserve(32); + concatenate(kname, "hadamard_", n * m, "_", type_to_name(x)); + auto lib = d.get_library(kname, [&]() { + std::string kernel; + concatenate( + kernel, + metal::utils(), + gen_hadamard_codelet(m), + metal::hadamard(), + get_template_definition( + "n2" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n2, + max_radix_2, + read_width_n2)); + if (n1 > 1) { + kernel += get_template_definition( + "n1" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n1, + max_radix_1, + read_width_n1, + n2); } - }; - const array& in_contiguous = check_input(in); - - if (in_contiguous.is_donatable()) { - out.copy_shared_buffer(in_contiguous); - } else { - out.set_data(allocator::malloc(out.nbytes())); - } - - int n, m; - std::tie(n, m) = decompose_hadamard(in.shape(axis)); - - if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) { - throw std::invalid_argument( - "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI"); - } - - int max_radix = std::min(n, 16); - // Use read_width 2 for m = 28 to avoid register spilling - int read_width = (n == 2 || m == 28) ? 
2 : 4; - - std::ostringstream kname; - kname << "hadamard_" << n * m << "_" << type_to_name(out); - auto kernel_name = kname.str(); - auto& d = metal::device(s.device); - const auto& lib_name = kernel_name; - auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - auto codelet = gen_hadamard_codelet(m); - kernel_source << metal::utils() << codelet << metal::hadamard(); - kernel_source << get_template_definition( - "n" + kernel_name, - "hadamard_n", - get_type_string(in.dtype()), - n, - max_radix, - read_width); - kernel_source << get_template_definition( - "m" + kernel_name, - "hadamard_m", - get_type_string(in.dtype()), - n, - m, - read_width); - return kernel_source.str(); + if (m > 1) { + kernel += get_template_definition( + "m" + kname, + "hadamard_m", + get_type_string(x.dtype()), + n, + m, + read_width_m); + } + return kernel; }); - int batch_size = in.size() / n; - int threads_per = n / max_radix; - - auto& compute_encoder = d.get_command_encoder(s.index); - - auto launch_hadamard = [&](const array& in, - array& out, - const std::string& kernel_name, - float scale) { - auto kernel = d.get_kernel(kernel_name, lib); - assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); - + // Launch the strided transform for n1 + if (n1 > 1) { + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n1" + kname, lib); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(scale, 2); - - MTL::Size group_dims = MTL::Size(1, threads_per, 1); - MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - }; - - if (m > 1) { - // When m is greater than 1, we decompose the - // computation into two uploads to the GPU: - // - // e.g. len(x) = 12*4 = 48, m = 12, n = 4 - // - // y = h48 @ x - // - // Upload 1: - // tmp = a.reshape(12, 4) @ h4 - // - // Upload 2: - // y = h12 @ tmp - array temp(in.shape(), in.dtype(), nullptr, {}); - temp.set_data(allocator::malloc(temp.nbytes())); - copies.push_back(temp); - - launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0); - - // Metal sometimes reports 256 max threads per group for hadamard_m kernel - threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP); - batch_size = in.size() / m / read_width / threads_per; - launch_hadamard(temp, out, "m" + kernel_name, scale_); - } else { - launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n1, 2); + compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1); } - d.add_temporaries(std::move(copies), s.index); + // Launch the transform for n2 + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n2" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(n1 > 1 ? 
y : x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n2, 2); + compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2); + + // Launch the strided transform for m + if (m > 1) { + auto kernel = d.get_kernel("m" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(y, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_m, 2); + compute_encoder.dispatch_threads(grid_dims_m, group_dims_m); + } +} + +void Hadamard::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + // Split the hadamard transform so that all of them work on vectors smaller + // than 8192 elements. + // + // We decompose it in the following way: + // + // n = m * n1 * n2 = m * 2^k1 * 2^k2 + // + // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192 + auto [n, m] = decompose_hadamard(in.shape().back()); + int n1 = 1, n2 = n; + if (n > 8192) { + for (n2 = 2; n2 * n2 < n; n2 *= 2) { + } + n1 = n / n2; + } + + if (in.flags().row_contiguous) { + if (in.is_donatable()) { + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s); + } else { + copy_gpu(in, out, CopyType::General, s); + hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h index 93e2fb8a8..9f2311c10 100644 --- a/mlx/backend/metal/kernels/hadamard.h +++ b/mlx/backend/metal/kernels/hadamard.h @@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) { } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -46,18 +46,25 @@ template constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; // Read values from device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } @@ -113,12 +120,20 @@ template } // Write values to device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/mlx/ops.cpp 
b/mlx/ops.cpp index e7abe12db..4aa5e88b7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -473,8 +473,19 @@ array hadamard_transform( std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) - float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1)); + int n = a.ndim() > 0 ? a.shape(-1) : 1; + float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n); auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32; + + // Nothing to do for a scalar + if (n == 1) { + if (scale == 1) { + return a; + } + + return multiply(a, array(scale, dtype), s); + } + return array( a.shape(), dtype, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d840eac7d..d9e143d82 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2868,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase): h28 = parse_h_string(h28_str) + x = mx.array(5) + y = mx.hadamard_transform(x) + self.assertEqual(y.item(), 5) + + x = mx.array(5) + y = mx.hadamard_transform(x, scale=0.2) + self.assertEqual(y.item(), 1) + + x = mx.random.normal((8, 8, 1)) + y = mx.hadamard_transform(x) + self.assertTrue(mx.all(y == x).item()) + + # Too slow to compare to numpy so let's compare CPU to GPU + if mx.default_device() == mx.gpu: + rk = mx.random.key(42) + for k in range(14, 17): + for m in [1, 3, 5, 7]: + x = mx.random.normal((4, m * 2**k), key=rk) + y1 = mx.hadamard_transform(x, stream=mx.cpu) + y2 = mx.hadamard_transform(x, stream=mx.gpu) + self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6) + np.random.seed(7) - tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15)) + tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14)) for dtype, m, k in tests: # skip large m=28 cases because they're very slow in NumPy - if (m > 1 and k > 8) or (dtype != np.float16 and k == 14): + if m > 1 and k > 8: continue with self.subTest(dtype=dtype, m=m, k=k): n = m * 2**k From 9c5e7da5079cf98f48df150c8bed5c3c0043d22c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 2 May 2025 15:08:50 -0700 Subject: [PATCH 026/156] fix compile merging (#2150) --- mlx/compile.cpp | 9 +++++++++ tests/compile_tests.cpp | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7ff5c8f9e..2baeb6fcf 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) { parent.first.inputs()[parent.second] = dst; pairs.push_back(parent); } + + // If src is a parent of dst, remove it from dst's parents + for (auto it = pairs.begin(); it != pairs.end();) { + if (it->first.id() == src.id()) { + it = pairs.erase(it); + } else { + it++; + } + } // Remove the source from the map to avoid fusing with it again parents_map.erase(src_parents); } diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 66511682d..96552ef9d 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") { out = cfun2({array(0)}); CHECK_EQ(out[0].item(), 3); } + +TEST_CASE("test compile with no-ops") { + auto fun = [](const std::vector& inputs) { + return std::vector{abs(stop_gradient(abs(inputs[0])))}; + }; + auto in = array(1.0); + auto out = compile(fun)({in})[0]; + CHECK_EQ(out.inputs()[0].id(), in.id()); +} From 825124af8ffd32d0f2f7d8f8eca83c8c3eb510a7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 5 May 2025 06:15:04 -0700 Subject: [PATCH 027/156] 
fix bw for elementwise ops (#2151) * fix bw for elementwise ops * add compile * fix * fix * fix * fix --- mlx/backend/metal/binary.cpp | 15 ++++-- mlx/backend/metal/compiled.cpp | 38 +++++++++---- mlx/backend/metal/copy.cpp | 27 +++++++--- mlx/backend/metal/kernels/binary.h | 51 ++++++++++++------ mlx/backend/metal/kernels/binary_two.h | 75 ++++++++++++++++---------- mlx/backend/metal/kernels/copy.h | 34 ++++++++---- mlx/backend/metal/kernels/ternary.h | 17 ++++-- mlx/backend/metal/kernels/unary.h | 17 ++++-- mlx/backend/metal/kernels/utils.h | 8 +++ mlx/backend/metal/ternary.cpp | 14 +++-- mlx/backend/metal/unary.cpp | 17 ++++-- mlx/backend/metal/utils.h | 8 +++ 12 files changed, 232 insertions(+), 89 deletions(-) diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index f80f8c3e4..c3c67e4d5 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -90,7 +90,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(a.dtype()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); @@ -137,13 +137,20 @@ void binary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 154273233..db20f938c 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -64,6 +64,7 @@ inline void build_kernel( cnt++); } + std::string idx_type = use_big_index ? "int64_t" : "uint"; if (add_indices) { os += fmt::format( " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); @@ -83,6 +84,9 @@ inline void build_kernel( " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); + } else { + os += fmt::format( + " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++); } if (dynamic_dims) { os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); @@ -92,13 +96,14 @@ inline void build_kernel( os += " uint3 pos [[thread_position_in_grid]],\n"; os += " uint3 grid [[threads_per_grid]]) {\n"; - std::string idx_type = use_big_index ? 
"int64_t" : "uint"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; + os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; + } else if (contiguous) { + os += " uint index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); @@ -110,6 +115,9 @@ inline void build_kernel( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } + if (work_per_thread > 1 && contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } // Read constant / contiguous inputs in tmps std::vector nc_inputs; @@ -193,7 +201,7 @@ inline void build_kernel( } // Open per-thread loop - if (work_per_thread > 1) { + if (work_per_thread > 1 && !contiguous) { os += " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } @@ -272,6 +280,7 @@ void Compiled::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); std::string kernel = metal::utils(); concatenate( kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); @@ -284,7 +293,9 @@ void Compiled::eval_gpu( constant_ids_, /* contiguous = */ true, /* ndim = */ 0, - /* dynamic_dims = */ false); + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -295,7 +306,8 @@ void Compiled::eval_gpu( /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, - /* use_big_index = */ true); + /* use_big_index = */ true, + /* work_per_thread = */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -468,6 +480,13 @@ void Compiled::eval_gpu( if (!contiguous) { compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); + } else { + auto size = outputs[0].data_size(); + if (large) { + compute_encoder.set_bytes(size, cnt++); + } else { + compute_encoder.set_bytes(size, cnt++); + } } // Put the number of dims in if it is dynamic @@ -477,12 +496,13 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - size_t nthreads = outputs[0].data_size(); + int work_per_thread = get_work_per_thread(outputs[0].dtype()); + size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = large - ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) + ? get_2d_grid_dims( + outputs[0].shape(), outputs[0].strides(), work_per_thread) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 3399201de..ee004359f 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -104,6 +104,8 @@ void copy_gpu_inplace( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } + } else { + work_per_thread = get_work_per_thread(in.dtype()); } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? 
get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -165,13 +167,19 @@ void copy_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -214,14 +222,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); + int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 91a02c818..ffc33ad82 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -9,64 +9,85 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } } -template 
+template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 8f6b3392d..e261d33c4 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -12,82 +12,103 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < 
size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } template diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index b1367cf4f..2469d1f3d 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,39 +1,53 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 4b3adcc80..5251dc7e9 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -1,25 +1,32 @@ // Copyright © 2024 Apple Inc. 
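The elementwise kernels in this patch all share the same work-per-thread scheme; a worked sketch of the values involved (illustration only, evaluating N = max(1, 8 / sizeof(U)) as defined by WorkPerThread and get_work_per_thread later in this patch):

// sizeof(U) == 8  (int64, uint64, complex64)   -> N = 1
// sizeof(U) == 4  (float32, int32, uint32)     -> N = 2
// sizeof(U) == 2  (float16, bfloat16, int16)   -> N = 4
// sizeof(U) == 1  (bool, int8, uint8)          -> N = 8
//
// Each contiguous elementwise kernel then strides by N, e.g.:
//   index *= N;
//   for (int i = 0; i < N && (index + i) < size; ++i) {
//     out[index + i] = Op()(in[index + i]);
//   }
// and the host side launches ceildiv(data_size, N) threads rather than
// data_size threads, passing the size so the tail is bounds-checked.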
-template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 69828599f..b5eaab2e9 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,21 +1,28 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + out[index + i] = Op()(in[index + i]); + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } template < diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 1170d5576..c30d186b8 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -15,6 +15,14 @@ typedef half float16_t; +// Work per thread values for different types. The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 36bfd3e2b..0b821151e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(b.dtype()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -106,13 +106,19 @@ void ternary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? 
get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index be43c41c2..368e693a9 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -34,18 +34,19 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - size_t nthreads = contig ? in.data_size() : in.size(); bool large; if (!contig) { large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; } else { large = in.data_size() > UINT32_MAX; } - int work_per_thread = !contig && large ? 4 : 1; + int work_per_thread; std::string kernel_name; if (contig) { + work_per_thread = get_work_per_thread(in.dtype()); kernel_name = (large ? "v2" : "v"); } else { + work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); if (large) { kernel_name += "large"; @@ -75,12 +76,20 @@ void unary_op_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { + size_t nthreads = ceildiv(in.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 079d15f17..f9245a6d6 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) { concatenate(acc, args...); } +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} + +inline size_t ceildiv(size_t n, size_t m) { + return (n + m - 1) / m; +} + } // namespace mlx::core From af705590ac9335105a5a026de4fc68ee6e747a9d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 5 May 2025 13:13:03 -0700 Subject: [PATCH 028/156] fix batched vector sdpa (#2152) --- mlx/backend/metal/kernels/sdpa_vector.h | 12 +- .../metal/scaled_dot_product_attention.cpp | 103 ++++++++++-------- python/tests/test_fast_sdpa.py | 40 +++++++ 3 files changed, 105 insertions(+), 50 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index c4c0f6456..8258e9c14 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -56,9 +56,9 @@ template const int head_idx = tid.x; const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? 
tpg.x * q_seq_idx + head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; @@ -213,9 +213,9 @@ template const int block_idx = tid.z; const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; const int kv_head_idx = head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; @@ -358,8 +358,8 @@ template // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int n_heads = tpg.x; - const int q_offset = n_heads * q_seq_idx + head_idx; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 845962d01..d75e6d87d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -154,9 +154,9 @@ void sdpa_vector( int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(1024, 1, 1); @@ -199,11 +199,10 @@ void sdpa_vector( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 11 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 13); compute_encoder.set_bytes(q_seq_stride, 14); compute_encoder.set_bytes(head_stride, 15); @@ -238,9 +237,10 @@ void sdpa_vector_2pass( int N = k.shape(2); int blocks = 32; int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(8 * 32, 1, 1); MTL::Size grid_dims(B, q.shape(2), blocks); @@ -302,11 +302,10 @@ void sdpa_vector_2pass( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 13 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? 
m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 15); compute_encoder.set_bytes(q_seq_stride, 16); compute_encoder.set_bytes(head_stride, 17); @@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu( } }; - // Checks if arr is row contiguous or the sequence and head dimension are - // transposed - auto is_contiguous_or_head_seq_transposed = [](const array& arr) { - if (arr.flags().row_contiguous) { - return true; - } - auto& strides = arr.strides(); - auto& shape = arr.shape(); - return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) && - (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]); - }; - // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { return arr.strides(-1) == 1; @@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) <= 8) { - const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); - const auto& k = copy_unless(is_matrix_contiguous, k_pre); - const auto& v = copy_unless(is_matrix_contiguous, v_pre); + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + // If either the batch or head dimension is a singleton, the other can + // be transposed with the sequence dimension + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + // keys and values should be copied if: + // - the last dimension is not contiguous + // - the batch and head dim are not contiguous + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& v = copy_unless(kv_copy_unless, v_pre); // Donate the query if possible - if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && - q.size() == o.size()) { + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { o.copy_shared_buffer(q); } else { - if (o.shape(2) == 1) { - o.set_data(allocator::malloc(o.nbytes())); - } else { - auto strides = o.strides(); - strides[2] = o.shape(1) * o.shape(3); - strides[1] = o.shape(3); - auto flags = q.flags(); - flags.row_contiguous = q.shape(1) == 1; - o.set_data( - allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags); - } + o.set_data(allocator::malloc(o.nbytes())); } - auto mask = - inputs.size() > 3 ? std::optional{inputs[3]} : std::nullopt; + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + + auto mask = inputs.size() > 3 + ? 
std::optional{copy_unless(mask_copy_unless, inputs[3])} + : std::nullopt; // We route to the 2 pass fused attention if // - The device is large and the sequence length long diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index d35a2b1da..8f55d41e3 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_sdpa_vector_batched(self): + D = 64 + q = mx.random.normal(shape=(2, 1, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2) + v = mx.random.normal(shape=(2, 2, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) + ref = mlx_ref_attn(q, k, v, mask=mask) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + class TestSDPA(mlx_tests.MLXTestCase): @property From 1683975acf2f007ba94a0a53241149474f0c070b Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 6 May 2025 05:45:29 +0900 Subject: [PATCH 029/156] Move common gpu primitives to backend/gpu (#2145) --- mlx/CMakeLists.txt | 1 + mlx/backend/gpu/CMakeLists.txt | 5 + mlx/backend/gpu/copy.cpp | 49 ++++ mlx/backend/{metal => gpu}/copy.h | 2 + mlx/backend/gpu/primitives.cpp | 217 ++++++++++++++++++ mlx/backend/gpu/slicing.cpp | 44 ++++ mlx/backend/{metal => gpu}/slicing.h | 0 mlx/backend/metal/conv.cpp | 2 +- mlx/backend/metal/copy.cpp | 44 +--- mlx/backend/metal/custom_kernel.cpp | 2 +- mlx/backend/metal/distributed.cpp | 2 +- mlx/backend/metal/fft.cpp | 4 +- mlx/backend/metal/hadamard.cpp | 2 +- mlx/backend/metal/indexing.cpp | 2 +- mlx/backend/metal/logsumexp.cpp | 2 +- mlx/backend/metal/matmul.cpp | 2 +- mlx/backend/metal/normalization.cpp | 2 +- mlx/backend/metal/primitives.cpp | 182 +-------------- mlx/backend/metal/quantized.cpp | 2 +- mlx/backend/metal/reduce.cpp | 2 +- mlx/backend/metal/rope.cpp | 2 +- .../metal/scaled_dot_product_attention.cpp | 2 +- mlx/backend/metal/scan.cpp | 2 +- mlx/backend/metal/slicing.cpp | 39 +--- mlx/backend/metal/softmax.cpp | 2 +- mlx/backend/metal/sort.cpp | 2 +- 26 files changed, 340 
insertions(+), 277 deletions(-) create mode 100644 mlx/backend/gpu/CMakeLists.txt create mode 100644 mlx/backend/gpu/copy.cpp rename mlx/backend/{metal => gpu}/copy.h (98%) create mode 100644 mlx/backend/gpu/primitives.cpp create mode 100644 mlx/backend/gpu/slicing.cpp rename mlx/backend/{metal => gpu}/slicing.h (100%) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 465954d6f..00898e73e 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,6 +47,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt new file mode 100644 index 000000000..0396ae03a --- /dev/null +++ b/mlx/backend/gpu/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp new file mode 100644 index 000000000..6127ac921 --- /dev/null +++ b/mlx/backend/gpu/copy.cpp @@ -0,0 +1,49 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu(const array& in, array& out, CopyType ctype) { + copy_gpu(in, out, ctype, out.primitive().stream()); +} + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/gpu/copy.h similarity index 98% rename from mlx/backend/metal/copy.h rename to mlx/backend/gpu/copy.h index 37c60df42..020f579e4 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/gpu/copy.h @@ -5,6 +5,8 @@ #include "mlx/backend/common/copy.h" #include "mlx/stream.h" +#include + namespace mlx::core { // Generic copy inplace diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp new file mode 100644 index 000000000..cd9296075 --- /dev/null +++ b/mlx/backend/gpu/primitives.cpp @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/primitives.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +#include + +#define MLX_PROFILER_RANGE(message) + +namespace mlx::core { + +namespace { + +void reshape(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +} // namespace + +void AsStrided::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsStrided::eval_gpu"); + eval(inputs, out); +} + +void AsType::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsType::eval_gpu"); + CopyType ctype = + inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(inputs[0], out, ctype); +} + +void Broadcast::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Broadcast::eval_gpu"); + eval(inputs, out); +} + +void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu"); + eval(inputs, out); +} + +void Concatenate::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Concatenate::eval_gpu"); + concatenate_gpu(inputs, out, axis_, stream()); +} + +void Contiguous::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Contiguous::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { + out.copy_shared_buffer(in); + } else { + copy_gpu(in, out, CopyType::General); + } +} + +void Copy::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Copy::eval_gpu"); + eval(inputs, out); +} + +void CustomTransforms::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("CustomTransforms::eval_gpu"); + eval(inputs, outputs); +} + +void Depends::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Depends::eval_gpu"); + eval(inputs, outputs); +} + +void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); + eval(inputs, out); +} + +void Full::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Full::eval_gpu"); + auto in = inputs[0]; + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy_gpu(in, out, ctype); +} + +void Flatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Flatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("NumberOfElements::eval_gpu"); + eval(inputs, out); +} + +void Pad::eval_gpu(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + 
pad_gpu(in, val, out, axes_, low_pad_size_, stream()); +} + +void Reshape::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Reshape::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void Split::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Split::eval_gpu"); + eval(inputs, outputs); +} + +void Slice::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Slice::eval_gpu"); + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + slice_gpu(in, out, start_indices_, strides_, stream()); +} + +void Squeeze::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Squeeze::eval_gpu"); + eval(inputs, out); +} + +void StopGradient::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("StopGradient::eval_gpu"); + eval(inputs, out); +} + +void Transpose::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Transpose::eval_gpu"); + eval(inputs, out); +} + +void Unflatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Unflatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void View::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("View::eval_gpu"); + auto& in = inputs[0]; + auto ibytes = size_of(in.dtype()); + auto obytes = size_of(out.dtype()); + // Conditions for buffer copying (disjunction): + // - type size is the same + // - type size is smaller and the last axis is contiguous + // - the entire array is row contiguous + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || + in.flags().row_contiguous) { + auto strides = in.strides(); + for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { + strides[i] *= ibytes; + strides[i] /= obytes; + } + out.copy_shared_buffer( + in, strides, in.flags(), in.data_size() * ibytes / obytes); + } else { + auto tmp = array(in.shape(), in.dtype(), nullptr, {}); + tmp.set_data(allocator::malloc(tmp.nbytes())); + copy_gpu_inplace(in, tmp, CopyType::General, stream()); + + auto flags = out.flags(); + flags.contiguous = true; + flags.row_contiguous = true; + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp new file mode 100644 index 000000000..fde2a01cd --- /dev/null +++ b/mlx/backend/gpu/slicing.cpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s) { + slice(in, out, start_indices, strides); +} + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s) { + // Fill output with val + fill_gpu(val, out, s); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes.size(); i++) { + auto ax = axes[i] < 0 ? 
out.ndim() + axes[i] : axes[i]; + data_offset += out.strides()[ax] * low_pad_size[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/gpu/slicing.h similarity index 100% rename from mlx/backend/metal/slicing.h rename to mlx/backend/gpu/slicing.h diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9075ea4c5..ae31a6cff 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -5,7 +5,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index ee004359f..8dfe15c11 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -1,35 +1,15 @@ // Copyright © 2023-2024 Apple Inc. -#include - +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} - -void copy_gpu(const array& in, array& out, CopyType ctype) { - copy_gpu(in, out, ctype, out.primitive().stream()); -} - void copy_gpu_inplace( const array& in, array& out, @@ -184,28 +164,6 @@ void copy_gpu_inplace( } } -void copy_gpu_inplace( - const array& in, - array& out, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); -} - -void copy_gpu_inplace( - const array& in, - array& out, - const Strides& i_strides, - int64_t i_offset, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); -} - void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 8a672289a..ea4f258cc 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,6 +1,6 @@ // Copyright © 2024 Apple Inc. 
-#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 82e8fff7d..a800d2e0f 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -4,7 +4,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 153c62c02..011eb7ebb 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -7,10 +7,10 @@ #include "mlx/3rdparty/pocketfft.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/binary.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 89b970fce..65a877151 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -3,7 +3,7 @@ #include "mlx/backend/common/hadamard.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a263051..cccfd908a 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp index 4901190e1..e53bc58d9 100644 --- a/mlx/backend/metal/logsumexp.cpp +++ b/mlx/backend/metal/logsumexp.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index f55d20c9f..71221f8d9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -7,7 +7,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c1d993d2a..21142183e 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. 
#include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6946ffb9e..860e9ddd7 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -7,10 +7,10 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - static array compute_dynamic_offset( const array& indices, const Strides& strides, @@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } -void AsType::eval_gpu(const std::vector& inputs, array& out) { - CopyType ctype = - inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; - copy_gpu(inputs[0], out, ctype); -} - -void AsStrided::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Broadcast::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Concatenate::eval_gpu(const std::vector& inputs, array& out) { - concatenate_gpu(inputs, out, axis_, stream()); -} - -void Contiguous::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; - if (in.buffer_size() <= out.nbytes() + extra_bytes && - (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous))) { - out.copy_shared_buffer(in); - } else { - copy_gpu(in, out, CopyType::General); - } -} - -void Copy::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void CustomTransforms::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Depends::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Full::eval_gpu(const std::vector& inputs, array& out) { - auto in = inputs[0]; - CopyType ctype; - if (in.data_size() == 1) { - ctype = CopyType::Scalar; - } else if (in.flags().contiguous) { - ctype = CopyType::Vector; - } else { - ctype = CopyType::General; - } - copy_gpu(in, out, ctype); -} - -void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Flatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Unflatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - void Load::eval_gpu(const std::vector& inputs, array& out) { throw 
std::runtime_error("[Load::eval_gpu] Not implemented."); } -void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Pad::eval_gpu(const std::vector& inputs, array& out) { - // Inputs must be base input array and scalar val array - assert(inputs.size() == 2); - auto& in = inputs[0]; - auto& val = inputs[1]; - - // Padding value must be a scalar - assert(val.size() == 1); - - // Padding value, input and output must be of the same type - assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); - - pad_gpu(in, val, out, axes_, low_pad_size_, stream()); -} - void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void Reshape::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Split::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Slice::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - slice_gpu(in, out, start_indices_, strides_, stream()); -} - void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { out.set_data(nullptr); @@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { /* const Stream& s = */ stream()); } -void Squeeze::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void StopGradient::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Transpose::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -537,35 +390,4 @@ void LUF::eval_gpu( throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } -void View::eval_gpu(const std::vector& inputs, array& out) { - auto& in = inputs[0]; - auto ibytes = size_of(in.dtype()); - auto obytes = size_of(out.dtype()); - // Conditions for buffer copying (disjunction): - // - type size is the same - // - type size is smaller and the last axis is contiguous - // - the entire array is row contiguous - if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || - in.flags().row_contiguous) { - auto strides = in.strides(); - for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { - strides[i] *= ibytes; - strides[i] /= obytes; - } - out.copy_shared_buffer( - in, strides, in.flags(), in.data_size() * ibytes / obytes); - } else { - auto tmp = array(in.shape(), in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc(tmp.nbytes())); - copy_gpu_inplace(in, tmp, CopyType::General, stream()); - - auto flags = out.flags(); - flags.contiguous = true; - flags.row_contiguous = true; - auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); - flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); - } -} - } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 6f5807543..11a2355cc 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -4,7 +4,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include 
"mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index c5650bdd7..8cb55ba58 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 060758333..d8201afe6 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d75e6d87d..3c7b7ff19 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/steel/attn/params.h" diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index b1800fea9..3c4051105 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 6ab08a108..3e1a8b541 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,21 +2,12 @@ #include -#include "mlx/backend/common/slicing.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" namespace mlx::core { -void slice_gpu( - const array& in, - array& out, - const Shape& start_indices, - const Shape& strides, - const Stream& s) { - slice(in, out, start_indices, strides); -} - void concatenate_gpu( const std::vector& inputs, array& out, @@ -48,30 +39,4 @@ void concatenate_gpu( } } -void pad_gpu( - const array& in, - const array& val, - array& out, - const std::vector& axes, - const Shape& low_pad_size, - const Stream& s) { - // Fill output with val - fill_gpu(val, out, s); - - // Find offset for start of input values - size_t data_offset = 0; - for (int i = 0; i < axes.size(); i++) { - auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; - data_offset += out.strides()[ax] * low_pad_size[i]; - } - - // Extract slice from output where input will be pasted - array out_slice(in.shape(), out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out, out.strides(), out.flags(), out_slice.size(), data_offset); - - // Copy input values into the slice - copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); -} - } // namespace mlx::core diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 224721a50..59662b05d 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 543dfd180..3c84022f2 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -2,7 +2,7 @@ #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" From 5a1a5d5ed16f69af7c3ce56dd94e4502661e1565 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 5 May 2025 17:30:50 -0700 Subject: [PATCH 030/156] fix input coherent kernel launch (#2153) --- mlx/backend/metal/fence.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index d4a88d983..5abdf7309 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -138,7 +138,7 @@ void Fence::update(Stream stream, const array& x) { compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_bytes(nthreads, 1); - compute_encoder.dispatch_threadgroups(group_dims, grid_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Barrier on previous kernels compute_encoder.barrier(); From 0cae0bdac83bbf5b3d1da3ca53f1f7eb95981d30 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 7 May 2025 13:26:46 +0900 Subject: [PATCH 031/156] CUDA backend: backbone (#2075) --- CMakeLists.txt | 5 + mlx/CMakeLists.txt | 10 +- mlx/backend/cuda/CMakeLists.txt | 57 ++++++ mlx/backend/cuda/allocator.cpp | 154 ++++++++++++++ mlx/backend/cuda/allocator.h | 58 ++++++ mlx/backend/cuda/copy.cpp | 26 +++ mlx/backend/cuda/device.cpp | 117 +++++++++++ mlx/backend/cuda/device.h | 131 ++++++++++++ mlx/backend/cuda/dtype_utils.cuh | 35 ++++ mlx/backend/cuda/eval.cpp | 68 +++++++ mlx/backend/cuda/event.cu | 265 +++++++++++++++++++++++++ mlx/backend/cuda/event.h | 66 ++++++ mlx/backend/cuda/fence.cu | 70 +++++++ mlx/backend/cuda/kernels/arange.cuh | 15 ++ mlx/backend/cuda/kernels/fp16_math.cuh | 107 ++++++++++ mlx/backend/cuda/primitives.cu | 163 +++++++++++++++ mlx/backend/cuda/slicing.cpp | 15 ++ mlx/backend/cuda/utils.cpp | 26 +++ mlx/backend/cuda/utils.h | 36 ++++ mlx/backend/cuda/worker.cpp | 90 +++++++++ mlx/backend/cuda/worker.h | 68 +++++++ tests/CMakeLists.txt | 2 +- 22 files changed, 1582 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/CMakeLists.txt create mode 100644 mlx/backend/cuda/allocator.cpp create mode 100644 mlx/backend/cuda/allocator.h create mode 100644 mlx/backend/cuda/copy.cpp create mode 100644 mlx/backend/cuda/device.cpp create mode 100644 mlx/backend/cuda/device.h create mode 100644 mlx/backend/cuda/dtype_utils.cuh create mode 100644 mlx/backend/cuda/eval.cpp create mode 100644 mlx/backend/cuda/event.cu create mode 100644 mlx/backend/cuda/event.h create mode 100644 mlx/backend/cuda/fence.cu create mode 100644 mlx/backend/cuda/kernels/arange.cuh create mode 100644 mlx/backend/cuda/kernels/fp16_math.cuh create mode 100644 mlx/backend/cuda/primitives.cu create mode 100644 mlx/backend/cuda/slicing.cpp create mode 100644 mlx/backend/cuda/utils.cpp create mode 100644 mlx/backend/cuda/utils.h create mode 100644 mlx/backend/cuda/worker.cpp create mode 100644 mlx/backend/cuda/worker.h diff --git a/CMakeLists.txt b/CMakeLists.txt index e2002fc94..ab8aea443 100644 --- 
a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) +option(MLX_BUILD_CUDA "Build cuda backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -83,6 +84,10 @@ if(MLX_BUILD_METAL) set(QUARTZ_LIB "-framework QuartzCore") endif() +if(MLX_BUILD_CUDA) + enable_language(CUDA) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 00898e73e..4ba9b33dd 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,10 +47,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) +endif() + +if(MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +endif() + +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) +else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt new file mode 100644 index 000000000..54d651005 --- /dev/null +++ b/mlx/backend/cuda/CMakeLists.txt @@ -0,0 +1,57 @@ +# Filename rules in cuda backend: +# +# * Use .cu/.cuh if code contains device code, and .cpp/.h if not. +# * Device-only kernel code should be put in kernels/ subdir. +# * Files in kernels/ subdir should not include files outside. +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PUBLIC MLX_USE_CUDA) + +# Enable defining device lambda functions. +target_compile_options(mlx + PRIVATE "$<$:--extended-lambda>") + +# Compute capability 7 is required for synchronization between CPU/GPU with +# managed memory. TODO: Add more architectures for potential performance gain. +set(MLX_CUDA_ARCHITECTURES + "75;80" + CACHE STRING "CUDA architectures") +message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") +set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES + "${MLX_CUDA_ARCHITECTURES}") + +# Use fixed version of CCCL. +FetchContent_Declare( + cccl + URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") +FetchContent_MakeAvailable(cccl) +target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include") + +# Use fixed version of NVTX. +FetchContent_Declare( + nvtx3 + GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git + GIT_TAG v3.1.1 + GIT_SHALLOW TRUE + SOURCE_SUBDIR c EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(nvtx3) +target_link_libraries(mlx PUBLIC $) + +# Make cuda runtime APIs available in non-cuda files. 
+find_package(CUDAToolkit REQUIRED) +target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + +# Suppress nvcc warnings on MLX headers. +target_compile_options(mlx PRIVATE $<$:-Xcudafe + --diag_suppress=997>) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp new file mode 100644 index 000000000..203534e21 --- /dev/null +++ b/mlx/backend/cuda/allocator.cpp @@ -0,0 +1,154 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/worker.h" + +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +CudaAllocator::CudaAllocator() { + // TODO: Set memory limit for multi-device. + size_t free, total; + CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); + memory_limit_ = total * 0.8; +} + +Buffer CudaAllocator::malloc(size_t size) { + // TODO: Check memory limit. + auto* buf = new CudaBuffer{nullptr, size}; + cudaError_t err = cudaMallocManaged(&buf->data, size); + if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { + throw std::runtime_error( + fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + } + std::lock_guard lock(mutex_); + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; +} + +void CudaAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + // If free() is called from a unregistered thread, reschedule the call to + // worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([buffer]() { allocator().free(buffer); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + size_t size = buf->size; + cudaFree(buf->data); + delete buf; + std::lock_guard lock(mutex_); + active_memory_ -= size; +} + +size_t CudaAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void CudaAllocator::register_this_thread() { + std::lock_guard lock(worker_mutex_); + allowed_threads_.insert(std::this_thread::get_id()); +} + +size_t CudaAllocator::get_active_memory() const { + return active_memory_; +} + +size_t CudaAllocator::get_peak_memory() const { + return peak_memory_; +} + +void CudaAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t CudaAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t CudaAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +CudaAllocator& allocator() { + // By creating the |allocator_| on heap, the destructor of CudaAllocator + // will not be called on exit and buffers in the cache will be leaked. This + // can save some time at program exit. 
+ static CudaAllocator* allocator_ = new CudaAllocator; + return *allocator_; +} + +} // namespace cu + +namespace allocator { + +Allocator& allocator() { + return cu::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + return static_cast(ptr_)->data; +} + +} // namespace allocator + +size_t get_active_memory() { + return cu::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return cu::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return cu::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return cu::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return cu::allocator().get_memory_limit(); +} + +// TODO: Implement buffer cache. +size_t get_cache_memory() { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} +void clear_cache() {} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h new file mode 100644 index 000000000..6c418ee7e --- /dev/null +++ b/mlx/backend/cuda/allocator.h @@ -0,0 +1,58 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class Worker; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Register current thread as safe to free buffers. + // In cuda freeing a buffer implicitly synchronizes stream, and for threads + // that may be waited by gpu stream (for example cpu stream threads), freeing + // buffers there would result in dead lock. + void register_this_thread(); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + + private: + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex worker_mutex_; + std::unique_ptr worker_; + std::set allowed_threads_; + + std::mutex mutex_; + size_t memory_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; +}; + +CudaAllocator& allocator(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp new file mode 100644 index 000000000..d0413d989 --- /dev/null +++ b/mlx/backend/cuda/copy.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& data_shape, + const Strides& strides_in_pre, + const Strides& strides_out_pre, + int64_t inp_offset, + int64_t out_offset, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_i_offset /* = std::nullopt */, + const std::optional& dynamic_o_offset /* = std::nullopt */) { + throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); +} + +void fill_gpu(const array& val, array& out, const Stream& s) { + throw std::runtime_error("fill_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp new file mode 100644 index 000000000..a28ffa35e --- /dev/null +++ b/mlx/backend/cuda/device.cpp @@ -0,0 +1,117 @@ +// Copyright © 2025 Apple Inc. 
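The allocator above defers free() calls coming from unregistered threads to a Worker, since freeing CUDA managed memory can implicitly synchronize a stream and would deadlock threads that a GPU stream is waiting on. A minimal sketch of the registration contract, assuming a user-managed helper thread and an arbitrary buffer size (neither appears in this patch):

#include "mlx/backend/cuda/allocator.h"
#include <thread>

void register_and_free_example() {
  std::thread t([] {
    // Mark this thread as safe to free buffers directly.
    mlx::core::cu::allocator().register_this_thread();
    auto buf = mlx::core::allocator::malloc(1024);
    // Freed inline here; an unregistered thread would instead have the free
    // rescheduled onto the allocator's Worker.
    mlx::core::allocator::free(buf);
  });
  t.join();
}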
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/metal/metal.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} + +void DeviceStream::synchronize() { + cudaStreamSynchronize(stream_); +} + +cudaStream_t DeviceStream::schedule_cuda_stream() { + // TODO: Return a stream that maximizes parallelism. + return stream_; +} + +cudaStream_t DeviceStream::last_cuda_stream() { + return stream_; +} + +CommandEncoder& DeviceStream::get_encoder() { + if (!encoder_) { + encoder_ = std::make_unique(*this); + } + return *encoder_; +} + +Device::Device(int device) : device_(device) { + // Validate the requirements of device. + int attr = 0; + cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_); + if (attr != 1) { + throw std::runtime_error(fmt::format( + "Device {} does not support synchronization in managed memory.", + device_)); + } +} + +void Device::make_current() { + // We need to set/get current CUDA device very frequently, cache it to reduce + // actual calls of CUDA APIs. This function assumes single-thread in host. + static int current = 0; + if (current != device_) { + CHECK_CUDA_ERROR(cudaSetDevice(device_)); + current = device_; + } +} + +DeviceStream& Device::get_stream(Stream s) { + auto it = streams_.find(s.index); + if (it == streams_.end()) { + it = streams_.try_emplace(s.index, *this).first; + } + return it->second; +} + +CommandEncoder::CommandEncoder(DeviceStream& s) + : device_(s.device()), stream_(s) {} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.add_task(std::move(task)); +} + +void CommandEncoder::end_encoding() { + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + + // There is no kernel running, run completion handlers immediately. + if (!has_gpu_work_) { + worker_.consume_in_this_thread(); + return; + } + has_gpu_work_ = false; + + // Put completion handlers in a batch. + worker_.end_batch(); + + // Signaling kernel completion is expensive, delay until enough batches. + // TODO: This number is arbitrarily picked, profile for a better stragety. + if (worker_.uncommited_batches() > 8) { + commit(); + } +} + +void CommandEncoder::commit() { + worker_.commit(stream_.last_cuda_stream()); +} + +Device& device(mlx::core::Device device) { + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +DeviceStream& get_stream(Stream s) { + return device(s.device).get_stream(s); +} + +CommandEncoder& get_command_encoder(Stream s) { + return get_stream(s).get_encoder(); +} + +} // namespace cu + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h new file mode 100644 index 000000000..a65a87d54 --- /dev/null +++ b/mlx/backend/cuda/device.h @@ -0,0 +1,131 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include + +#include + +namespace mlx::core::cu { + +class Device; +class CommandEncoder; + +class DeviceStream { + public: + explicit DeviceStream(Device& device); + + DeviceStream(const DeviceStream&) = delete; + DeviceStream& operator=(const DeviceStream&) = delete; + + // Wait until kernels in the stream complete. 
+ void synchronize(); + + // Return a cuda stream for launching kernels. + cudaStream_t schedule_cuda_stream(); + + // Return the last cuda stream used. + cudaStream_t last_cuda_stream(); + + CommandEncoder& get_encoder(); + + Device& device() { + return device_; + } + + private: + Device& device_; + CudaStream stream_; + std::unique_ptr encoder_; +}; + +class Device { + public: + explicit Device(int device); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, required by some cuda calls. + void make_current(); + + DeviceStream& get_stream(Stream s); + + int cuda_device() const { + return device_; + } + + private: + int device_; + std::unordered_map streams_; +}; + +class CommandEncoder { + public: + explicit CommandEncoder(DeviceStream& stream); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr) {} + void set_output_array(const array& arr) {} + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void end_encoding(); + void commit(); + + // Schedule a cuda stream for |fun| to launch kernels, and check error + // afterwards. + template + void launch_kernel(F&& fun) { + launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); + } + + template + void launch_kernel(cudaStream_t stream, F&& fun) { + device_.make_current(); + fun(stream); + check_cuda_error("kernel launch", cudaGetLastError()); + has_gpu_work_ = true; + } + + Device& device() { + return device_; + } + + DeviceStream& stream() { + return stream_; + } + + bool has_gpu_work() const { + return has_gpu_work_; + } + + private: + Device& device_; + DeviceStream& stream_; + Worker worker_; + bool has_gpu_work_{false}; + std::vector> temporaries_; +}; + +Device& device(mlx::core::Device device); +DeviceStream& get_stream(Stream s); +CommandEncoder& get_command_encoder(Stream s); + +// Return an execution policy that does not sync for result. +// Note that not all thrust APIs support async policy, confirm before using. +inline auto thrust_policy(cudaStream_t stream) { + // TODO: Connect thrust's custom allocator with mlx's allocator. + return thrust::cuda::par_nosync.on(stream); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/dtype_utils.cuh b/mlx/backend/cuda/dtype_utils.cuh new file mode 100644 index 000000000..9b7f8ba65 --- /dev/null +++ b/mlx/backend/cuda/dtype_utils.cuh @@ -0,0 +1,35 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +// Maps CPU types to CUDA types. +template +struct CTypeToCudaType { + using type = T; +}; + +template <> +struct CTypeToCudaType { + using type = __half; +}; + +template <> +struct CTypeToCudaType { + using type = __nv_bfloat16; +}; + +template <> +struct CTypeToCudaType { + using type = cuComplex; +}; + +template +using cuda_type_t = typename CTypeToCudaType::type; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp new file mode 100644 index 000000000..b309ad60e --- /dev/null +++ b/mlx/backend/cuda/eval.cpp @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. 
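CTypeToCudaType above is only a traits table mapping MLX scalar types to CUDA's native types. A minimal sketch of how a kernel template might use it; the kernel name and launch shape are illustrative assumptions, and it assumes the specializations cover MLX's float16_t, bfloat16_t, and complex64_t:

#include "mlx/backend/cuda/dtype_utils.cuh"

namespace mlx::core {

// Copies n elements, with T spelled as an MLX type on the host side and the
// corresponding CUDA type (e.g. __half for float16_t) used on the device side.
template <typename T>
__global__ void copy_kernel(const cuda_type_t<T>* in, cuda_type_t<T>* out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i];
  }
}

} // namespace mlx::core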
+ +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream s) { + // Force initalization of cuda, so cuda runtime get destroyed at last. + cudaFree(nullptr); + // Ensure the static stream objects get created. + cu::get_command_encoder(s); + // The main thread is safe to free buffers. + cu::allocator().register_this_thread(); +} + +void eval(array& arr) { + nvtx3::scoped_range r("gpu::eval"); + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + + auto& encoder = cu::get_command_encoder(arr.primitive().stream()); + if (encoder.has_gpu_work()) { + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + } + encoder.end_encoding(); +} + +void finalize(Stream s) { + nvtx3::scoped_range r("gpu::finalize"); + cu::get_command_encoder(s).commit(); +} + +void synchronize(Stream s) { + nvtx3::scoped_range r("gpu::synchronize"); + cu::get_stream(s).synchronize(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu new file mode 100644 index 000000000..a487f45b4 --- /dev/null +++ b/mlx/backend/cuda/event.cu @@ -0,0 +1,265 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/event.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace cu { + +/////////////////////////////////////////////////////////////////////////////// +// CudaEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +// Cuda event managed with RAII. 
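eval.cpp above is what drives the encoder lifecycle for every primitive on a CUDA stream. A condensed sketch of how an eval_gpu implementation is expected to use it; the function name, arrays, and the empty kernel launch are placeholders, not code from this commit:

#include "mlx/backend/cuda/device.h"

void example_eval_gpu(const mlx::core::array& in, mlx::core::array& out, mlx::core::Stream s) {
  auto& encoder = mlx::core::cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  // launch_kernel() makes the device current, runs the lambda on a scheduled
  // cudaStream_t, and checks cudaGetLastError() afterwards.
  encoder.launch_kernel([&](cudaStream_t stream) {
    // enqueue a CUDA kernel on `stream` here
  });
  // gpu::eval() then calls encoder.end_encoding() to batch completion handlers,
  // and gpu::finalize() calls encoder.commit() to hand them to the Worker.
}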
+class CudaEventHandle { + public: + CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventCreateWithFlags( + &event_, cudaEventDisableTiming | cudaEventBlockingSync)); + } + + ~CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventDestroy(event_)); + } + + CudaEventHandle(const CudaEventHandle&) = delete; + CudaEventHandle& operator=(const CudaEventHandle&) = delete; + + operator cudaEvent_t() const { + return event_; + } + + private: + cudaEvent_t event_; +}; + +CudaEvent::CudaEvent() : event_(std::make_shared()) {} + +void CudaEvent::wait() { + nvtx3::scoped_range r("cu::CudaEvent::wait"); + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaEventSynchronize(*event_); +} + +void CudaEvent::wait(cudaStream_t stream) { + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaStreamWaitEvent(stream, *event_); +} + +void CudaEvent::wait(Stream s) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this]() mutable { wait(); }); + } else { + wait(cu::get_stream(s).last_cuda_stream()); + } +} + +void CudaEvent::record(cudaStream_t stream) { + cudaEventRecord(*event_, stream); + recorded_ = true; +} + +void CudaEvent::record(Stream s) { + if (s.device == mlx::core::Device::cpu) { + throw std::runtime_error("CudaEvent can not wait on cpu stream."); + } else { + record(cu::get_stream(s).last_cuda_stream()); + } +} + +bool CudaEvent::completed() const { + return cudaEventQuery(*event_) == cudaSuccess; +} + +/////////////////////////////////////////////////////////////////////////////// +// SharedEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { + uint64_t current; + while ((current = ac->load()) < value) { + ac->wait(current); + } +} + +__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) { + ac->store(value); + ac->notify_all(); +} + +__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_wait(ac, value); +} + +__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_signal(ac, value); +} + +} // namespace + +SharedEvent::SharedEvent() { + // Allocate cuda::atomic on managed memory. 
+ allocator::Buffer buffer = allocator::malloc(sizeof(Atomic)); + Atomic* ac = static_cast(buffer.raw_ptr()); + new (ac) Atomic(0); + ac_ = std::shared_ptr(ac, [buffer](Atomic* ptr) { + ptr->~Atomic(); + allocator::free(buffer); + }); +} + +void SharedEvent::wait(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait"); + event_wait(ac_.get(), value); +} + +void SharedEvent::wait(cudaStream_t stream, uint64_t value) { + event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::wait(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +void SharedEvent::signal(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal"); + event_signal(ac_.get(), value); +} + +void SharedEvent::signal(cudaStream_t stream, uint64_t value) { + event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::signal(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +bool SharedEvent::is_signaled(uint64_t value) const { + nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); + return ac_->load() >= value; +} + +uint64_t SharedEvent::value() const { + nvtx3::scoped_range r("cu::SharedEvent::value"); + return ac_->load(); +} + +} // namespace cu + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct EventImpl { + // CudaEvent is preferred when possible because it is fast, however we have + // to fallback to SharedEvent in following cases: + // 1. the event is used to wait/signal a cpu stream; + // 2. signal value other than 1 has been specified. 
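  //
  // (A CudaEvent is a binary recorded/completed marker: it can only express a
  //  signal value of 1 and can only be recorded on a GPU stream. SharedEvent
  //  keeps a 64-bit counter in managed memory, so it covers both cases above,
  //  at the cost of launching small wait/signal kernels.)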
+ std::unique_ptr cuda; + std::unique_ptr shared; + + bool is_created() const { + return cuda || shared; + } + + void ensure_created(Stream s, uint64_t signal_value) { + if (is_created()) { + return; + } + if (s.device == mlx::core::Device::cpu || signal_value > 1) { + nvtx3::mark("Using slow SharedEvent"); + shared = std::make_unique(); + } else { + cuda = std::make_unique(); + } + } +}; + +} // namespace + +Event::Event(Stream s) : stream_(s) { + event_ = std::shared_ptr( + new EventImpl(), [](void* ptr) { delete static_cast(ptr); }); +} + +void Event::wait() { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(); + } else { + event->shared->wait(value()); + } +} + +void Event::wait(Stream s) { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->shared->wait(s, value()); + } +} + +void Event::signal(Stream s) { + auto* event = static_cast(event_.get()); + event->ensure_created(s, value()); + if (event->cuda) { + assert(value() == 1); + event->cuda->record(s); + } else { + event->shared->signal(s, value()); + } +} + +bool Event::is_signaled() const { + auto* event = static_cast(event_.get()); + if (!event->is_created()) { + return false; + } + if (event->cuda) { + assert(value() == 1); + return event->cuda->recorded() && event->cuda->completed(); + } else { + return event->shared->is_signaled(value()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h new file mode 100644 index 000000000..4b56e2e3b --- /dev/null +++ b/mlx/backend/cuda/event.h @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core::cu { + +class CudaEventHandle; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(); + + void wait(); + void wait(cudaStream_t stream); + void wait(Stream s); + void record(cudaStream_t stream); + void record(Stream s); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + bool recorded() const { + return recorded_; + } + + private: + bool recorded_{false}; + std::shared_ptr event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. +class SharedEvent { + public: + using Atomic = cuda::atomic; + + SharedEvent(); + + void wait(uint64_t value); + void wait(cudaStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(cudaStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + const std::shared_ptr& atomic() const { + return ac_; + } + + private: + std::shared_ptr ac_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/fence.cu b/mlx/backend/cuda/fence.cu new file mode 100644 index 000000000..091b252c1 --- /dev/null +++ b/mlx/backend/cuda/fence.cu @@ -0,0 +1,70 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/fence.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace { + +__host__ __device__ void busy_wait(cuda::atomic* ac, uint64_t value) { + while (true) { + // In theory the atomic_thread_fence is not needed, but for CUDA 11 without + // it the load() may never return new value. + cuda::atomic_thread_fence(cuda::memory_order_seq_cst); + uint64_t current = ac->load(); + if (current >= value) { + break; + } + } +} + +__global__ void busy_wait_kernel(cuda::atomic* ac, uint64_t value) { + busy_wait(ac, value); +} + +} // namespace + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + // We can't use SharedEvent::wait because it could hang in CUDA 11, see also: + // https://github.com/ml-explore/mlx/issues/2137 + const auto& ac = fence->event.atomic(); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [ac, count = fence->count]() { + nvtx3::scoped_range r("Fence::wait()"); + busy_wait(ac.get(), count); + }); + } else { + nvtx3::scoped_range r("Fence::wait(s)"); + auto& encoder = cu::get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) { + busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count); + }); + encoder.add_completed_handler([ac]() {}); + encoder.end_encoding(); + } +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/arange.cuh b/mlx/backend/cuda/kernels/arange.cuh new file mode 100644 index 000000000..53c261e34 --- /dev/null +++ b/mlx/backend/cuda/kernels/arange.cuh @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::cu { + +template +struct Arange { + const T start; + const T step; + + __device__ T operator()(uint32_t i) const { + return start + i * step; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh new file mode 100644 index 000000000..931c55ff7 --- /dev/null +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -0,0 +1,107 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Missing C++ operator overrides for CUDA 7. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +#define MLX_DEFINE_BF16_OP(OP) \ + __forceinline__ __device__ __nv_bfloat16 operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +#define MLX_DEFINE_BF16_CMP(OP) \ + __forceinline__ __device__ bool operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +MLX_DEFINE_BF16_OP(+) +MLX_DEFINE_BF16_OP(-) +MLX_DEFINE_BF16_OP(*) +MLX_DEFINE_BF16_OP(/) +MLX_DEFINE_BF16_CMP(>) +MLX_DEFINE_BF16_CMP(<) +MLX_DEFINE_BF16_CMP(>=) +MLX_DEFINE_BF16_CMP(<=) + +#undef MLX_DEFINE_BF16_OP +#undef MLX_DEFINE_BF16_CMP + +#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +/////////////////////////////////////////////////////////////////////////////// +// Additional C++ operator overrides between half types and native types. +/////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool is_integral_except = + cuda::std::is_integral_v && !cuda::std::is_same_v; + +template +constexpr bool is_arithmetic_except = + cuda::std::is_arithmetic_v && !cuda::std::is_same_v; + +#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(HALF x, T y) { \ + return FLOAT2HALF(HALF2FLOAT(x) OP static_cast(y)); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(T x, HALF y) { \ + return FLOAT2HALF(static_cast(x) OP HALF2FLOAT(y)); \ + } + +#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(HALF x, T y) { \ + return HALF2FLOAT(x) OP static_cast(y); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(T x, HALF y) { \ + return static_cast(y) OP HALF2FLOAT(x); \ + } + +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /) +MLX_DEFINE_HALF_CMP(__half, __half2float, <) +MLX_DEFINE_HALF_CMP(__half, __half2float, >) +MLX_DEFINE_HALF_CMP(__half, __half2float, <=) +MLX_DEFINE_HALF_CMP(__half, __half2float, >=) +MLX_DEFINE_HALF_CMP(__half, __half2float, ==) +MLX_DEFINE_HALF_CMP(__half, __half2float, !=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=) + +#undef MLX_DEFINE_HALF_OP +#undef MLX_DEFINE_HALF_CMP + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu new file mode 100644 index 000000000..dc6edf606 
--- /dev/null +++ b/mlx/backend/cuda/primitives.cu @@ -0,0 +1,163 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/dtype_utils.cuh" +#include "mlx/backend/cuda/kernels/arange.cuh" +#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/distributed/primitives.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Arange::eval_gpu"); + assert(inputs.size() == 0); + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_output_array(out); + encoder.launch_kernel([&, this](cudaStream_t stream) { + MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); + }); + }); +} + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +NO_GPU(Abs) +NO_GPU(Add) +NO_GPU(AddMM) +NO_GPU(ArcCos) +NO_GPU(ArcCosh) +NO_GPU(ArcSin) +NO_GPU(ArcSinh) +NO_GPU(ArcTan) +NO_GPU(ArcTan2) +NO_GPU(ArcTanh) +NO_GPU(ArgPartition) +NO_GPU(ArgReduce) +NO_GPU(ArgSort) +NO_GPU(BitwiseBinary) +NO_GPU(BitwiseInvert) +NO_GPU(BlockMaskedMM) +NO_GPU(Ceil) +NO_GPU_MULTI(Compiled) +NO_GPU(Conjugate) +NO_GPU(Convolution) +NO_GPU(Cos) +NO_GPU(Cosh) +NO_GPU(Divide) +NO_GPU_MULTI(DivMod) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) +NO_GPU(Remainder) +NO_GPU(Equal) +NO_GPU(Erf) +NO_GPU(ErfInv) +NO_GPU(Exp) +NO_GPU(Expm1) +NO_GPU(FFT) +NO_GPU(Floor) +NO_GPU(Gather) +NO_GPU(GatherAxis) +NO_GPU(GatherMM) +NO_GPU(GatherQMM) +NO_GPU(Greater) +NO_GPU(GreaterEqual) +NO_GPU(Hadamard) +NO_GPU(Imag) +NO_GPU(Less) +NO_GPU(LessEqual) +NO_GPU(Load) +NO_GPU(Log) +NO_GPU(Log1p) +NO_GPU(LogicalNot) +NO_GPU(LogicalAnd) +NO_GPU(LogicalOr) +NO_GPU(LogAddExp) +NO_GPU(LogSumExp) +NO_GPU_MULTI(LUF) +NO_GPU(Matmul) +NO_GPU(Maximum) +NO_GPU(Minimum) +NO_GPU(Multiply) +NO_GPU(Negative) +NO_GPU(NotEqual) +NO_GPU(Partition) +NO_GPU(Power) +NO_GPU_MULTI(QRF) +NO_GPU(QuantizedMatmul) +NO_GPU(RandomBits) +NO_GPU(Real) +NO_GPU(Reduce) +NO_GPU(Round) +NO_GPU(Scan) +NO_GPU(Scatter) +NO_GPU(ScatterAxis) +NO_GPU(Select) +NO_GPU(Sigmoid) +NO_GPU(Sign) +NO_GPU(Sin) +NO_GPU(Sinh) +NO_GPU(SliceUpdate) +NO_GPU(Softmax) +NO_GPU(Sort) +NO_GPU(Square) +NO_GPU(Sqrt) +NO_GPU(Subtract) +NO_GPU_MULTI(SVD) +NO_GPU(Tan) +NO_GPU(Tanh) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eigh) + +namespace fast { +NO_GPU_MULTI(LayerNorm) +NO_GPU_MULTI(LayerNormVJP) +NO_GPU_MULTI(RMSNorm) +NO_GPU_MULTI(RMSNormVJP) +NO_GPU_MULTI(RoPE) +NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(CustomKernel) +} // namespace fast + +namespace distributed { +NO_GPU_MULTI(AllReduce) +NO_GPU_MULTI(AllGather) +NO_GPU_MULTI(Send) +NO_GPU_MULTI(Recv) +} // namespace distributed + +} // 
namespace mlx::core diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp new file mode 100644 index 000000000..bfa742c74 --- /dev/null +++ b/mlx/backend/cuda/slicing.cpp @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s) { + throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp new file mode 100644 index 000000000..2a11a518e --- /dev/null +++ b/mlx/backend/cuda/utils.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); +} + +CudaStream::~CudaStream() { + CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); +} + +void check_cuda_error(const char* name, cudaError_t err) { + if (err != cudaSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, cudaGetErrorString(err))); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h new file mode 100644 index 000000000..58d508765 --- /dev/null +++ b/mlx/backend/cuda/utils.h @@ -0,0 +1,36 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +namespace cu { +class Device; +} + +// Cuda stream managed with RAII. +class CudaStream { + public: + explicit CudaStream(cu::Device& device); + ~CudaStream(); + + CudaStream(const CudaStream&) = delete; + CudaStream& operator=(const CudaStream&) = delete; + + operator cudaStream_t() const { + return stream_; + } + + private: + cudaStream_t stream_; +}; + +// Throw exception if the cuda API does not succeed. +void check_cuda_error(const char* name, cudaError_t err); + +// The macro version that prints the command that failed. +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) + +} // namespace mlx::core diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp new file mode 100644 index 000000000..64b5c7679 --- /dev/null +++ b/mlx/backend/cuda/worker.cpp @@ -0,0 +1,90 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +Worker::Worker() + : signal_stream_(device(mlx::core::Device::gpu)), + worker_(&Worker::thread_fn, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(worker_mutex_); + stop_ = true; + } + worker_event_.signal(batch_ + 1); + worker_.join(); +} + +void Worker::add_task(std::function task) { + pending_tasks_.push_back(std::move(task)); +} + +void Worker::consume_in_this_thread() { + for (auto& task : pending_tasks_) { + task(); + } + pending_tasks_.clear(); +} + +void Worker::end_batch() { + batch_++; + { + std::lock_guard lock(worker_mutex_); + worker_tasks_[batch_] = std::move(pending_tasks_); + } + uncommited_batches_++; +} + +void Worker::commit() { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + worker_event_.signal(batch_); +} + +void Worker::commit(cudaStream_t stream) { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + // Signal the |worker_event_| in |signal_stream_| after the kernels in + // |stream_| finish running. 
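  // That is: record signal_event_ on |stream|, have |signal_stream_| wait on
  // that record, then enqueue the SharedEvent signal on |signal_stream_| so
  // the worker thread only wakes once |stream|'s queued kernels complete.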
+ signal_event_.record(stream); + signal_event_.wait(signal_stream_); + worker_event_.signal(signal_stream_, batch_); +} + +void Worker::thread_fn() { + // The worker thread is safe to free buffers. + allocator().register_this_thread(); + + while (!stop_) { + uint64_t batch = worker_event_.value(); + Tasks tasks; + { + std::lock_guard lock(worker_mutex_); + // Move tasks in signaled batches. + auto end = worker_tasks_.upper_bound(batch); + for (auto it = worker_tasks_.begin(); it != end; ++it) { + if (tasks.empty()) { + tasks = std::move(it->second); + } else { + std::move( + it->second.begin(), it->second.end(), std::back_inserter(tasks)); + } + } + worker_tasks_.erase(worker_tasks_.begin(), end); + } + for (auto& task : tasks) { + task(); + } + worker_event_.wait(batch + 1); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h new file mode 100644 index 000000000..d28e22e95 --- /dev/null +++ b/mlx/backend/cuda/worker.h @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in worker thread, synchronized with cuda stream. +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or commited. + void add_task(std::function task); + + // Run pending tasks immediately in current thread. + void consume_in_this_thread(); + + // Put pending tasks in a batch. + void end_batch(); + + // Inform worker thread to run current batches now. + void commit(); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + // Return how many batches have been added but not committed yet. + size_t uncommited_batches() const { + return uncommited_batches_; + } + + private: + void thread_fn(); + + uint64_t batch_{0}; + size_t uncommited_batches_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + // Worker thread. + SharedEvent worker_event_; + std::thread worker_; + std::mutex worker_mutex_; + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. 
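  //
  // Rough lifecycle sketch (illustrative only, not taken from this patch;
  // `stream` is assumed to be a cudaStream_t):
  //
  //   Worker w;
  //   w.add_task([] { /* e.g. release buffers, run host callbacks */ });
  //   w.end_batch();      // pending tasks become batch N
  //   w.commit(stream);   // worker runs batch N after `stream` finishes
  //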
+ using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; +}; + +} // namespace mlx::core::cu diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cf0ba3d5d..cb174865d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,7 +9,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) -if(MLX_BUILD_METAL) +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) set(METAL_TEST_SOURCES gpu_tests.cpp) endif() From a7fae8a176fad114c89ca66ed0e0be8f3064e3e8 Mon Sep 17 00:00:00 2001 From: ATurker <53705368+aturker1@users.noreply.github.com> Date: Fri, 9 May 2025 20:26:52 +0300 Subject: [PATCH 032/156] fix: conv_general differences between gpu, cpu (#2070) * fix general_conv padding * fix bugs * add test --------- Co-authored-by: Awni Hannun --- mlx/backend/cpu/conv.cpp | 574 +++++++++++++++++++++---------------- mlx/backend/metal/conv.cpp | 6 +- mlx/ops.cpp | 1 + mlx/primitives.cpp | 48 ++-- mlx/primitives.h | 12 +- python/tests/test_conv.py | 42 +++ 6 files changed, 413 insertions(+), 270 deletions(-) diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index d52f92f8b..e5636b3b8 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -22,7 +22,8 @@ void slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -60,7 +61,8 @@ void slow_conv_1D( out_stride_O = out.strides()[2], flip, - padding = padding[0], + padding_lo = padding_lo[0], + padding_hi = padding_hi[0], wt_stride = wt_strides[0], wt_dilation = wt_dilation[0], in_dilation = in_dilation[0]]() mutable { @@ -77,7 +79,7 @@ void slow_conv_1D( const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = flip ? 
(wH - wh - 1) : wh; - int ih = oh * wt_stride - padding + wh_flip * wt_dilation; + int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation; auto ih_div = std::div(ih, in_dilation); @@ -109,7 +111,8 @@ void slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -120,230 +123,235 @@ void slow_conv_2D( encoder.set_input_array(wt); encoder.set_output_array(out); - encoder.dispatch([st_wt_ptr = wt.data(), - st_in_ptr = in.data(), - st_out_ptr = out.data(), + encoder.dispatch( + [st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - N = in.shape( - 0), // Batch size, should be the same as out.shape(0) - iH = 1 + - in_dilation[0] * (in.shape(1) - 1), // Input spatial dim - iW = 1 + - in_dilation[1] * (in.shape(2) - 1), // Input spatial dim - C = in.shape(3), // In channels - oH = out.shape(1), // Output spatial dim - oW = out.shape(2), // Output spatial dim - O = wt.shape(0), // Out channels - wH = wt.shape(1), // Weight spatial dim - wW = wt.shape(2), // Weight spatial dim + N = in.shape(0), // Batch size, should be the same as out.shape(0) + iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + C = in.shape(3), // In channels + oH = out.shape(1), // Output spatial dim + oW = out.shape(2), // Output spatial dim + O = wt.shape(0), // Out channels + wH = wt.shape(1), // Weight spatial dim + wW = wt.shape(2), // Weight spatial dim - groups = in.shape(3) / wt.shape(3), - C_per_group = wt.shape(3), + groups = in.shape(3) / wt.shape(3), + C_per_group = wt.shape(3), - in_stride_N = in.strides()[0], - in_stride_H = in.strides()[1], - in_stride_W = in.strides()[2], - in_stride_C = in.strides()[3], + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_W = in.strides()[2], + in_stride_C = in.strides()[3], - wt_stride_O = wt.strides()[0], - wt_stride_H = wt.strides()[1], - wt_stride_W = wt.strides()[2], - wt_stride_C = wt.strides()[3], + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_W = wt.strides()[2], + wt_stride_C = wt.strides()[3], - out_stride_N = out.strides()[0], - out_stride_H = out.strides()[1], - out_stride_W = out.strides()[2], - out_stride_O = out.strides()[3], + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_W = out.strides()[2], + out_stride_O = out.strides()[3], - padding, - wt_strides, - wt_dilation, - in_dilation, - flip]() mutable { - bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; - const int O_per_group = O / groups; - auto pt_conv_no_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int oh, - int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + const int O_per_group = O / groups; + auto pt_conv_no_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) 
* O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c - } // ww - } // wh + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; - int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); - int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); - int f_wgt_jump_h = - std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_w = - std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_h = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_h = + std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = + std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[0] - padding[0] + init_h; + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h; - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[0] != 0) { - wh_base++; - ih_loop += jump_h; - } + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } - base_h[i] = wh_base; - } + base_h[i] = wh_base; + } - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[1] - padding[1] + init_w; + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w; - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[1] != 0) { - ww_base++; - iw_loop += jump_w; - } + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } - base_w[j] = ww_base; - } + base_w[j] = ww_base; + } - auto pt_conv_all_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - int ih_dil = !is_idil_one ? 
(ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + + iw_dil * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; - ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - } // ih, iw check - } // ww - } // wh + } // ih, iw check + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); - int oH_border_3 = oH; + int oH_border_0 = 0; + int oH_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oH; + int oH_border_2 = std::max( + oH_border_1, + (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]); + int oH_border_3 = oH; - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); - int oW_border_3 = oW; + int oW_border_0 = 0; + int oW_border_1 = is_idil_one + ? 
((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oW; + int oW_border_2 = std::max( + oW_border_1, + (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]); + int oW_border_3 = oW; - for (int n = 0; n < N; ++n) { - // Case 1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - // Case 2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case a: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case b: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case c: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - } // oh + } // oh - // Case 3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + // Case 3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; - } // n - }); + } // n + }); } template @@ -351,7 +359,8 @@ void slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -400,7 +409,8 @@ void slow_conv_3D( out_stride_H = out.strides()[2], out_stride_W = out.strides()[3], out_stride_O = out.strides()[4], - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -415,9 +425,9 @@ void slow_conv_3D( int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; for (int o = 0; o < O; ++o) { float r = 0.; @@ -478,7 +488,7 @@ void slow_conv_3D( std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { - int 
id_loop = i * wt_strides[0] - padding[0] + init_d; + int id_loop = i * wt_strides[0] - padding_lo[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { @@ -490,7 +500,7 @@ void slow_conv_3D( } for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[1] - padding[1] + init_h; + int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { @@ -502,7 +512,7 @@ void slow_conv_3D( } for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[2] - padding[2] + init_w; + int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { @@ -521,9 +531,9 @@ void slow_conv_3D( int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; @@ -573,24 +583,30 @@ void slow_conv_3D( }; int oD_border_0 = 0; - int oD_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oD; int oD_border_2 = std::max( - oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + oD_border_1, + (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_1 = is_idil_one + ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oH; int oH_border_2 = std::max( - oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + oH_border_1, + (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_1 = is_idil_one + ? 
((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2]) + : oW; int oW_border_2 = std::max( - oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + oW_border_1, + (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { @@ -658,7 +674,8 @@ void dispatch_slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -669,7 +686,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -680,7 +698,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -691,7 +710,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -707,7 +727,8 @@ void dispatch_slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -718,7 +739,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -729,7 +751,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -740,7 +763,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -756,7 +780,8 @@ void dispatch_slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -767,7 +792,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -778,7 +804,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -789,7 +816,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], C}; + Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = padding[0] * in_padded.strides()[1]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const 
std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; + Shape padded_shape = { + N, + iH + padding_lo[0] + padding_hi[0], + iW + padding_lo[1] + padding_hi[1], + C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = - padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1] + + padding_lo[1] * in_padded.strides()[2]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip, @@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu( Shape padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { - padded_shape[i + 1] = iDim[i] + 2 * padding[i]; + padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); @@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu( // Pick input slice from padded size_t data_offset = 0; - for (size_t i = 0; i < padding.size(); i++) { - data_offset += padding[i] * in_padded.strides()[i + 1]; + for (size_t i = 0; i < padding_lo.size(); i++) { + data_offset += padding_lo[i] * in_padded.strides()[i + 1]; } + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1261,7 +1297,8 @@ void conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1270,22 +1307,40 @@ void conv_1D_cpu( const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( - in, wt, out, padding, wt_strides, wt_dilation, stream); + in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1295,18 +1350,35 @@ void conv_2D_cpu( if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, 
wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } - return dispatch_slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_3D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1317,11 +1389,28 @@ void conv_3D_cpu( in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } } // namespace @@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index ae31a6cff..35ed3d44e 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4aa5e88b7..e8c260425 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3974,6 +3974,7 @@ array conv_general( to_stream(s), stride, padding_lo, + padding_hi, kernel_dilation, input_dilation, groups, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 7288a4885..03ca06bdd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1055,7 +1055,8 @@ array conv_weight_backward_patches( const array& wt, const array& cotan, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, StreamOrDevice s) { // Resolve Padded input shapes and strides Shape padding_starts(in.ndim(), 0); @@ -1064,9 +1065,9 @@ array conv_weight_backward_patches( // padded shape for (int i = 1; i < in.ndim() - 1; i++) { - in_padded_shape[i] += 2 * padding[i - 1]; - padding_ends[i] += padding[i - 1]; - padding_starts[i] += padding[i - 1]; + in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1]; + padding_ends[i] += padding_lo[i - 1]; + padding_starts[i] += padding_lo[i - 1]; } // padded 
strides (contiguous) @@ -1078,9 +1079,16 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_(padding.begin(), padding.end()); - auto in_padded = pad( - in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); + Shape padding_lo_(padding_lo.begin(), padding_lo.end()); + Shape padding_hi_(padding_hi.begin(), padding_hi.end()); + auto in_padded = + pad(in, + padded_axes, + padding_lo_, + padding_hi_, + array(0, in.dtype()), + "constant", + s); // Resolve strided patches @@ -1147,16 +1155,16 @@ std::vector Convolution::vjp( for (int a : argnums) { // Grads for input if (a == 0) { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; for (int i = 0; i < padding_lo.size(); ++i) { int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_lo[i] = wt_size - padding_[i] - 1; + padding_lo[i] = wt_size - padding_lo_[i] - 1; int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding_[i]; + padding_hi[i] = in_size - out_size + padding_hi_[i]; } // Check for negative padding @@ -1226,18 +1234,12 @@ std::vector Convolution::vjp( if (no_dilation && !flip_ && groups_ == 1) { auto grad = conv_weight_backward_patches( - in, wt, cotan, kernel_strides_, padding_, stream()); + in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; - } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1283,7 +1285,8 @@ std::pair, std::vector> Convolution::vmap( in, w, kernel_strides_, - padding_, + padding_lo_, + padding_hi_, kernel_dilation_, input_dilation_, groups, @@ -1332,7 +1335,8 @@ std::pair, std::vector> Convolution::vmap( bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); - return padding_ == c_other.padding_ && + return padding_lo_ == c_other.padding_lo_ && + padding_hi_ == c_other.padding_hi_ && kernel_strides_ == c_other.kernel_strides_ && kernel_dilation_ == c_other.kernel_dilation_ && input_dilation_ == c_other.input_dilation_ && diff --git a/mlx/primitives.h b/mlx/primitives.h index 3753e43c5..2caed8477 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive { explicit Convolution( Stream stream, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& kernel_dilation, const std::vector& input_dilation, const int groups = 1, const bool flip = false) : UnaryPrimitive(stream), - padding_(padding), + padding_lo_(padding_lo), + padding_hi_(padding_hi), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), input_dilation_(input_dilation), @@ -716,7 +718,8 @@ class Convolution : 
public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive { } private: - std::vector padding_; + std::vector padding_lo_; + std::vector padding_hi_; std::vector kernel_strides_; std::vector kernel_dilation_; std::vector input_dilation_; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 671c86a32..35dcf42ac 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase): atol=2e-5 if dtype == np.float32 else 5e-4, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_asymmetric_padding(self): + inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) + strides = (2, 2, 2) + + pt_out = torch.conv3d( + torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), + torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), + stride=strides, + padding=2, + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=strides, + padding=([0, 0, 0], [1, 1, 1]), + ) + + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) + + pt_out = torch.conv2d( + torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), + torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), + stride=1, + padding=(1, 0), + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=1, + padding=([0, 0], [1, 0]), + ) + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + if __name__ == "__main__": unittest.main() From 6661387066b38ef7221d29d7dad6c25d07d6e96a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 May 2025 14:25:12 -0700 Subject: [PATCH 033/156] Fix fft for integer overflow (#2161) --- mlx/backend/metal/fft.cpp | 4 +--- mlx/backend/metal/kernels/fft/readwrite.h | 28 ++++++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 011eb7ebb..1e23160a6 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -632,7 +632,7 @@ void fft_op( func_consts.push_back(make_int(&rader_m, 3)); // The overall number of FFTs we're going to compute for this input - int size = out.dtype() == float32 ? out.size() : in.size(); + size_t size = out.dtype() == float32 ? out.size() : in.size(); if (real && inverse && four_step_params.required) { size = out.size(); } @@ -659,8 +659,6 @@ void fft_op( // We can perform 2 RFFTs at once so the batch size is halved. batch_size = (batch_size + 2 - 1) / 2; } - int out_buffer_size = out.size(); - auto& compute_encoder = d.get_command_encoder(s.index); auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? 
"float" : "float2"; diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h index f6724820d..0dc62992e 100644 --- a/mlx/backend/metal/kernels/fft/readwrite.h +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -98,7 +98,7 @@ struct ReadWriter { } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -121,7 +121,7 @@ struct ReadWriter { } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -144,7 +144,7 @@ struct ReadWriter { // Padded IO for Bluestein's algorithm METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; @@ -161,7 +161,7 @@ struct ReadWriter { } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -283,7 +283,8 @@ template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; @@ -317,7 +318,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; @@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter::load_padded( int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + 
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -503,7 +505,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; From 659a51919fd3d70798e91e9e112075680b95556e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 May 2025 14:35:14 -0700 Subject: [PATCH 034/156] patch bump (#2162) --- mlx/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/version.h b/mlx/version.h index 8340e1e8c..c573c45c9 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 1 +#define MLX_VERSION_PATCH 2 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From caaa3f1f8ceac3faee5068c04ea0e574af24f829 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Sun, 11 May 2025 15:03:47 +0200 Subject: [PATCH 035/156] Small typos in mx.metal deprecations (#2176) --- python/src/metal.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/src/metal.cpp b/python/src/metal.cpp index a13dd2a03..54642409c 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -49,21 +49,21 @@ void init_metal(nb::module_& m) { metal.def( "set_memory_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit"); + DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit"); return mx::set_memory_limit(limit); }, "limit"_a); metal.def( "set_cache_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit"); + DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit"); return mx::set_cache_limit(limit); }, "limit"_a); metal.def( "set_wired_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit"); + DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit"); return mx::set_wired_limit(limit); }, "limit"_a); From 8f3d208dcef00c5085dd3acfde2a6abb18585f07 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 12 May 2025 10:48:57 -0700 Subject: [PATCH 036/156] Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177) * handle hadamard and addmm on empty inputs * fix --- mlx/backend/cpu/matmul.cpp | 8 +++++++- mlx/backend/metal/matmul.cpp | 17 +++++++++++++++++ mlx/ops.cpp | 8 ++++++++ python/tests/test_blas.py | 17 +++++++++++++++++ python/tests/test_ops.py | 3 +++ 5 files changed, 52 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 8ae99ab2d..b944aacc0 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[AddMM::eval_cpu] Currently only supports float32."); } + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } // Fill output with C auto& c = inputs[2]; @@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { ? CopyType::Scalar : (c.flags().row_contiguous ? 
CopyType::Vector : CopyType::General); copy(c, out, ctype, stream()); - + if (inputs[0].shape(-1) == 0) { + return; + } matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 71221f8d9..e0ff44200 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -716,6 +716,23 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } + + // Return 0s if either input is empty + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + + // Copy c into out and return + if (inputs[0].shape(-1) == 0) { + copy_gpu( + inputs[2], + out, + inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream()); + return; + } + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e8c260425..922680110 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -472,6 +472,10 @@ array hadamard_transform( const array& a, std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { + if (a.size() == 0) { + throw std::invalid_argument( + "[hadamard_transform] Does not support empty arrays."); + } // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) int n = a.ndim() > 0 ? a.shape(-1) : 1; float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n); @@ -4326,6 +4330,10 @@ array addmm( c = reshape(c, c_reshape, s); } + if (c.shape() != out_shape) { + throw std::invalid_argument( + "[addmm] input c must broadcast to the output shape"); + } auto out = array( std::move(out_shape), diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 6fca4885b..df459eadc 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -589,6 +589,10 @@ class TestBlas(mlx_tests.MLXTestCase): alpha = 0.5 beta = 2.0 + # c must broadcast to the output shape + with self.assertRaises(ValueError): + mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2))) + # Regular batched case a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32) @@ -745,6 +749,19 @@ class TestBlas(mlx_tests.MLXTestCase): mx.eval(c) self.assertEqual(c.shape, (0, 0)) + c = mx.array(1.0, dtype=mx.float32) + a = mx.array([], dtype=mx.float32) + b = mx.array([], dtype=mx.float32) + out = mx.addmm(c, a, b) + self.assertEqual(out.item(), 1.0) + self.assertEqual(out.shape, ()) + + a = mx.zeros(shape=(5, 0)) + b = mx.zeros(shape=(0, 5)) + c = mx.random.uniform(shape=(5, 5)) + out = mx.addmm(c, a, b) + self.assertTrue(mx.allclose(out, c)) + def test_block_masked_matmul(self): def ref_block_masked_mm( a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d9e143d82..0921de788 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2830,6 +2830,9 @@ class TestOps(mlx_tests.MLXTestCase): return H def test_hadamard(self): + with self.assertRaises(ValueError): + mx.hadamard_transform(mx.array([])) + h28_str = """ +------++----++-+--+-+--++-- -+-----+++-----+-+--+-+--++- From 3aa9cf3f9ed7e1dd508b0d98b07834f5ac5c43cf Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 13 May 2025 14:27:53 -0700 Subject: [PATCH 037/156] Fix put_along_axis for empty arrays (#2181) --- mlx/ops.cpp | 4 ++++ 
python/tests/test_ops.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 922680110..0c18cccfe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3175,6 +3175,10 @@ array scatter_axis( throw std::invalid_argument(msg.str()); } + if (a.size() == 0) { + return a; + } + auto upd = astype(values, a.dtype(), s); // Squeeze leading singletons out of update diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 0921de788..f3d48dda3 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1255,6 +1255,12 @@ class TestOps(mlx_tests.MLXTestCase): np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2) self.assertTrue(np.array_equal(out_np, np.array(out_mlx))) + a = mx.array([], mx.float32) + b = mx.put_along_axis(a, a, a, axis=None) + mx.eval(b) + self.assertEqual(b.size, 0) + self.assertEqual(b.shape, a.shape) + def test_split(self): a = mx.array([1, 2, 3]) splits = mx.split(a, 3) From eca2f3eb974b86d37da170023040c5ac9a148c18 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 14 May 2025 09:09:56 +0900 Subject: [PATCH 038/156] Add remove_index utility (#2173) --- mlx/backend/common/utils.h | 7 +++++++ mlx/backend/cpu/arg_reduce.cpp | 6 ++---- mlx/backend/cpu/indexing.cpp | 28 ++++++++++------------------ mlx/backend/metal/indexing.cpp | 29 +++++++---------------------- 4 files changed, 26 insertions(+), 44 deletions(-) diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 20a65d7b1..a4bdaa5ca 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -165,4 +165,11 @@ void shared_buffer_reshape( const array& in, const Strides& out_strides, array& out); + +template +inline std::vector remove_index(std::vector vec, size_t index) { + vec.erase(std::next(vec.begin(), index)); + return vec; +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index a8ba3efe2..66468912d 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -14,10 +14,8 @@ template void arg_reduce(const array& in, array& out, const OpT& op, int axis) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; - Strides strides = in.strides(); - Shape shape = in.shape(); - strides.erase(strides.begin() + axis); - shape.erase(shape.begin() + axis); + Strides strides = remove_index(in.strides(), axis); + Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); auto out_ptr = out.data(); diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 70d6b3eb7..5f99093e5 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -257,15 +257,11 @@ void gather_axis( const array& ind, array& out, const int axis) { - auto strides = ind.strides(); - strides.erase(strides.begin() + axis); - auto shape = ind.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator ind_it(shape, strides, src.ndim() - 1); - - strides = src.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator src_it(shape, strides, src.ndim() - 1); + auto shape = remove_index(ind.shape(), axis); + ContiguousIterator ind_it( + shape, remove_index(ind.strides(), axis), src.ndim() - 1); + ContiguousIterator src_it( + shape, remove_index(src.strides(), axis), src.ndim() - 1); auto ind_ptr = ind.data(); auto src_ptr = src.data(); @@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { template void scatter_axis(array& out, const array idx, const array& upd, int 
axis) { - auto strides = idx.strides(); - strides.erase(strides.begin() + axis); - auto shape = idx.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); - - strides = upd.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator upd_it(shape, strides, upd.ndim() - 1); + auto shape = remove_index(idx.shape(), axis); + ContiguousIterator idx_it( + shape, remove_index(idx.strides(), axis), upd.ndim() - 1); + ContiguousIterator upd_it( + shape, remove_index(upd.strides(), axis), upd.ndim() - 1); auto idx_ptr = idx.data(); auto upd_ptr = upd.data(); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index cccfd908a..d2a601b1e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,6 +2,7 @@ #include #include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" @@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = src.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(src.shape(axis_), 8); @@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = upd.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8); From 0751263dec5a210eb2ba097c108e8d78aa58124c Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 14 May 2025 12:19:54 +0900 Subject: [PATCH 039/156] Fix typo in row_reduce_small (#2179) --- mlx/backend/metal/kernels/reduction/reduce_row.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index c8973429f..936d75bb5 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -224,7 +224,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. 
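// The fix below: out_idx linearizes the 2D threadgroup coordinates with x as
// the fastest-moving dimension, so tid.y must be scaled by the grid's
// x-extent (tsize.x); scaling it by tsize.y computed the wrong offset
// whenever the two extents differed.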
- IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { From 130df35e1b520061a053c052fba07122dc390c6a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 13 May 2025 22:43:45 -0700 Subject: [PATCH 040/156] Add random normal distribution for complex numbers (#2182) --- mlx/random.cpp | 45 +++++++++++++++++++++++++++++-------- mlx/random.h | 18 ++++++++++++--- python/src/random.cpp | 35 +++++++++++++++++++---------- python/tests/test_random.py | 35 +++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 24 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 89a027b17..6c6d1eb95 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -176,24 +176,51 @@ array uniform( array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s)); } +inline array complex_normal( + Shape shape, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s) { + auto stream = to_stream(s); + auto low = above_minus_one_with_default(float32); + auto high = array(1.0f, float32); + shape.push_back(2); + auto samples = + erfinv(uniform(low, high, shape, float32, key, stream), stream); + samples = squeeze(view(samples, complex64, stream), -1, stream); + if (scale.has_value()) { + samples = multiply(*scale, samples, stream); + } + if (loc.has_value()) { + samples = add(*loc, samples, stream); + } + return samples; +} + array normal( const Shape& shape, Dtype dtype, - const float loc /* = 0.0 */, - const float scale /* = 1.0 */, - const std::optional& key /*= nullopt */, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, StreamOrDevice s /* = {} */) { + if (dtype == complex64) { + return complex_normal(shape, loc, scale, key, s); + } + auto stream = to_stream(s); auto low = above_minus_one_with_default(dtype); auto high = array(1.0f, dtype); auto samples = uniform(low, high, shape, dtype, key, stream); - samples = - multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); - if (scale != 1.0) { - samples = multiply(array(scale, dtype), samples, stream); + auto applied_scale = array(std::sqrt(2.0), dtype); + if (scale.has_value()) { + applied_scale = + multiply(applied_scale, astype(*scale, dtype, stream), stream); } - if (loc != 0.0) { - samples = add(array(loc, dtype), samples, stream); + samples = multiply(applied_scale, erfinv(samples, stream), stream); + if (loc.has_value()) { + samples = add(astype(*loc, dtype, stream), samples, stream); } return samples; } diff --git a/mlx/random.h b/mlx/random.h index b2c821736..0dfdab7a1 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -94,12 +94,24 @@ inline array uniform( /** Generate samples from the standard normal distribution. */ array normal( + const Shape& shape, + Dtype dtype, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s = {}); +inline array normal( const Shape& shape, Dtype dtype, const float loc, const float scale, const std::optional& key = std::nullopt, - StreamOrDevice s = {}); + StreamOrDevice s = {}) { + auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype)); + auto scale_ = + scale == 1 ? 
std::nullopt : std::make_optional(array(scale, dtype)); + return normal(shape, dtype, loc_, scale_, key, s); +} inline array normal( const Shape& shape, const float loc, @@ -113,13 +125,13 @@ inline array normal( const Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, dtype, 0.0, 1.0, key, s); + return normal(shape, dtype, std::nullopt, std::nullopt, key, s); } inline array normal( const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, float32, 0.0, 1.0, key, s); + return normal(shape, float32, std::nullopt, std::nullopt, key, s); } /** Generate samples from a multivariate normal distribution. **/ diff --git a/python/src/random.cpp b/python/src/random.cpp index 22b706174..837f91616 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) { "normal", [](const mx::Shape& shape, std::optional type, - float loc, - float scale, + const std::optional& loc_, + const std::optional& scale_, const std::optional& key_, mx::StreamOrDevice s) { + auto dtype = type.value_or(mx::float32); auto key = key_ ? key_.value() : default_key().next(); - return mx::random::normal( - shape, type.value_or(mx::float32), loc, scale, key, s); + auto loc = + loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt; + auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype)) + : std::nullopt; + return mx::random::normal(shape, dtype, loc, scale, key, s); }, "shape"_a = mx::Shape{}, "dtype"_a.none() = mx::float32, - "loc"_a = 0.0, - "scale"_a = 1.0, + "loc"_a = nb::none(), + "scale"_a = nb::none(), "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Generate normally distributed random numbers. + If ``loc`` and ``scale`` are not provided the "standard" normal + distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for + real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0, + \frac{1}{2})$ for complex numbers. + Args: - shape (list(int), optional): Shape of the output. Default is ``()``. - dtype (Dtype, optional): Type of the output. Default is ``float32``. - loc (float, optional): Mean of the distribution. Default is ``0.0``. - scale (float, optional): Standard deviation of the distribution. Default is ``1.0``. - key (array, optional): A PRNG key. Default: None. + shape (list(int), optional): Shape of the output. Default: ``()``. + dtype (Dtype, optional): Type of the output. Default: ``float32``. + loc (scalar or array, optional): Mean of the distribution. + Default: ``None``. + scale (scalar or array, optional): Standard deviation of the + distribution. Default: ``None``. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The output array of random values. 
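A minimal usage sketch of the sampling behavior exercised by the tests below (illustrative only, not part of the patch; shapes and values are arbitrary):

    import mlx.core as mx

    # Standard complex normal: Re(x) and Im(x) are each ~ N(0, 1/2).
    z = mx.random.normal((4, 4), dtype=mx.complex64)

    # loc and scale may be scalars or arrays that broadcast with the shape.
    loc = mx.random.normal((10, 2))
    scale = mx.abs(mx.random.normal((10, 2)))
    x = mx.random.normal((2, 10, 2), loc=loc, scale=scale)

    # A complex location also works, e.g. centering at 3 + 1j.
    zc = mx.random.normal((4,), dtype=mx.complex64, loc=3.0 + 1j, scale=2.0)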
diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 9efbfb5f6..2fc768651 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase): x = mx.random.permutation(mx.array([[1]])) self.assertEqual(x.shape, (1, 1)) + def test_complex_normal(self): + sample = mx.random.normal(tuple(), dtype=mx.complex64) + self.assertEqual(sample.shape, tuple()) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal( + (1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j + ) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + def test_broadcastable_scale_loc(self): + b = mx.random.normal((10, 2)) + sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (2, 10, 2)) + + with self.assertRaises(ValueError): + b = mx.random.normal((10,)) + sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + + b = mx.random.normal((3, 1, 2)) + sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (3, 4, 2)) + self.assertEqual(sample.dtype, mx.float16) + if __name__ == "__main__": unittest.main() From cf6c939e868f6db3421396fda3fde31708e6f1eb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 14 May 2025 23:37:12 -0700 Subject: [PATCH 041/156] Fix some complex vjps (#2178) --- mlx/primitives.cpp | 89 ++++++++++++++++++++++++++++++---------- python/tests/test_fft.py | 57 +++++++++++++++++++++++++ tests/autograd_tests.cpp | 46 +++++++++++++++------ tests/fft_tests.cpp | 16 ++++---- 4 files changed, 166 insertions(+), 42 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 03ca06bdd..e1924e66c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1488,14 +1488,16 @@ std::vector Divide::vjp( const std::vector& argnums, const std::vector&) { std::vector vjps; + array denominator_bar = conjugate(primals[1], stream()); for (auto arg : argnums) { if (arg == 0) { - vjps.push_back(divide(cotangents[0], primals[1], stream())); + vjps.push_back(divide(cotangents[0], denominator_bar, stream())); } else { vjps.push_back(negative( divide( - multiply(cotangents[0], primals[0], stream()), - square(primals[1], stream()), + multiply( + cotangents[0], conjugate(primals[0], stream()), stream()), + square(denominator_bar, stream()), stream()), stream())); } @@ -1950,30 +1952,74 @@ std::vector FFT::vjp( assert(argnums.size() == 1); auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); + + // TODO: Add it as an option to do an unnormalized or scaled fft so that this + // isn't part of the graph. + double n_elements = 1; + for (auto ax : axes) { + n_elements *= inverse_ ? 
cotangents[0].shape(ax) : primals[0].shape(ax); + } + if (real_ && inverse_) { - auto out = fft::fftn(cotangents[0], axes, stream()); - auto start = Shape(out.ndim(), 0); - auto stop = in.shape(); - out = slice(out, start, stop, stream()); - auto mask_shape = out.shape(); - mask_shape[axes_.back()] -= 2; - auto mask = full(mask_shape, 2.0f, stream()); - auto pad_shape = out.shape(); - pad_shape[axes_.back()] = 1; - auto pad = full(pad_shape, 1.0f, stream()); - mask = concatenate({pad, mask, pad}, axes_.back(), stream()); - return {multiply(mask, out, stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets doubled. + int N = in.shape(axes_.back()); + bool odd = cotangents[0].shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1 / n_elements, in.dtype()); + array two(2 / n_elements, in.dtype()); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + two, + one, + stream()); + return { + multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; } else if (real_) { Shape n; for (auto ax : axes_) { - n.push_back(in.shape()[ax]); + n.push_back(in.shape(ax)); } - return {astype( - fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets halved. + int N = cotangents[0].shape(axes_.back()); + bool odd = in.shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1, complex64); + array half(0.5, complex64); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + half, + one, + stream()); + return {multiply( + fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), + array(n_elements, in.dtype()), + stream())}; } else if (inverse_) { - return {fft::ifftn(cotangents[0], axes, stream())}; + return {multiply( + fft::fftn(cotangents[0], axes, stream()), + array(1 / n_elements, complex64), + stream())}; } else { - return {fft::fftn(cotangents[0], axes, stream())}; + return {multiply( + fft::ifftn(cotangents[0], axes, stream()), + array(n_elements, complex64), + stream())}; } } @@ -2776,7 +2822,8 @@ std::vector Multiply::vjp( const std::vector&) { std::vector vjps; for (auto arg : argnums) { - vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); + vjps.push_back(multiply( + conjugate(primals[1 - arg], stream()), cotangents[0], stream())); } return vjps; } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index f644944c7..df9d25edc 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -7,6 +7,13 @@ import mlx.core as mx import mlx_tests import numpy as np +try: + import torch + + has_torch = True +except ImportError as e: + has_torch = False + class TestFFT(mlx_tests.MLXTestCase): def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs): @@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase): x = mx.array([]) self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + @unittest.skipIf(not has_torch, "requires 
PyTorch") + def test_fft_grads(self): + real = [True, False] + inverse = [True, False] + axes = [ + (-1,), + (-2, -1), + ] + shapes = [ + (4, 4), + (2, 4), + (2, 7), + (7, 7), + ] + + mxffts = { + (True, True): mx.fft.irfftn, + (True, False): mx.fft.rfftn, + (False, True): mx.fft.ifftn, + (False, False): mx.fft.fftn, + } + tffts = { + (True, True): torch.fft.irfftn, + (True, False): torch.fft.rfftn, + (False, True): torch.fft.ifftn, + (False, False): torch.fft.fftn, + } + + for r, i, ax, sh in itertools.product(real, inverse, axes, shapes): + + def f(x): + y = mxffts[r, i](x) + return (mx.abs(y) ** 2).sum() + + def g(x): + y = tffts[r, i](x) + return (torch.abs(y) ** 2).sum() + + if r and not i: + x = mx.random.normal(sh) + else: + x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze() + fx = f(x) + gx = g(torch.tensor(x)) + self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4) + + dfdx = mx.grad(f)(x) + dgdx = torch.func.grad(g)(torch.tensor(x)) + self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index c992c3c6d..5b3454bfc 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") { } { + auto multiply_fn = + [](const std::vector& inputs) -> std::vector { + return {multiply(inputs[0], inputs[1])}; + }; + // Compute jvp auto x = array(complex64_t{2.0, 4.0}); auto y = array(3.0f); - auto x_tan = array(complex64_t{1.0, 2.0}); auto y_tan = array(2.0f); + auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ(jvp_out[0].item(), complex64_t{7.0, 14.0}); - auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second; - CHECK_EQ(out.item(), complex64_t{4.0, 8.0}); - - out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second; - CHECK_EQ(out.item(), complex64_t{3.0, 6.0}); - + // Compute vjp auto cotan = array(complex64_t{2.0, 3.0}); - out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second; - CHECK_EQ(out.dtype(), float32); - CHECK_EQ(out.item(), -8.0); + auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].dtype(), complex64); + CHECK_EQ(vjp_out[0].item(), complex64_t{6.0, 9.0}); + CHECK_EQ(vjp_out[1].dtype(), float32); + CHECK_EQ(vjp_out[1].item(), 16); + } - out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second; - CHECK_EQ(out.item(), complex64_t{6.0, 9.0}); + { + auto divide_fn = + [](const std::vector& inputs) -> std::vector { + return {divide(inputs[0], inputs[1])}; + }; + + // Compute jvp + auto x = array(complex64_t{2.0, 3.0}); + auto y = array(complex64_t{1.0, 2.0}); + auto x_tan = array(complex64_t{3.0, 4.0}); + auto y_tan = array(complex64_t{4.0, -2.0}); + auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ( + jvp_out[0].item(), doctest::Approx(complex64_t{2.6, 2.8})); + + // Compute vjp + auto cotan = array(complex64_t{2.0, -4.0}); + auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].item(), complex64_t{2.0, 0.0}); + CHECK_EQ(vjp_out[1].item(), complex64_t{-3.2, -0.4}); } } diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 0db3999c8..b9e2d1bcc 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -243,7 +243,7 @@ TEST_CASE("test fft grads") { auto fft_fn = [](array x) { return fft::fft(x); }; auto cotangent = astype(arange(10), complex64); auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; - 
CHECK(array_equal(fft::fft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item()); auto tangent = astype(arange(10), complex64); auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; @@ -252,7 +252,7 @@ TEST_CASE("test fft grads") { // Inverse auto ifft_fn = [](array x) { return fft::ifft(x); }; vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::ifft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item()); jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; CHECK(array_equal(fft::ifft(tangent), jvp_out).item()); @@ -261,7 +261,8 @@ TEST_CASE("test fft grads") { auto rfft_fn = [](array x) { return fft::rfft(x); }; cotangent = astype(arange(6), complex64); vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; - auto expected = astype(fft::fft(cotangent, 10, 0), float32); + array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64); + auto expected = fft::irfft(cotangent * mask, 10, 0) * 10; CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), float32); @@ -272,12 +273,9 @@ TEST_CASE("test fft grads") { auto irfft_fn = [](array x) { return fft::irfft(x); }; cotangent = astype(arange(10), float32); vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; - expected = fft::fft(cotangent, 10, 0); - auto o_splits = split(vjp_out, {1, 5}); - auto e_splits = split(expected, {1, 5, 6}); - CHECK_EQ(e_splits[0].item(), o_splits[0].item()); - CHECK(array_equal(2 * e_splits[1], o_splits[1]).item()); - CHECK_EQ(e_splits[2].item(), o_splits[2].item()); + mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32); + expected = fft::rfft(cotangent) * mask; + CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), complex64); jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second; From c1eb9d05d98a16e1e22f5c9b5c683d50c4188e54 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 13:01:44 -0700 Subject: [PATCH 042/156] non-symmetric eig and eigh (#2188) --- docs/src/python/linalg.rst | 2 + mlx/backend/cpu/CMakeLists.txt | 1 + mlx/backend/cpu/eig.cpp | 174 ++++++++++++++++++++++++++++++ mlx/backend/cpu/lapack.h | 1 + mlx/backend/metal/primitives.cpp | 8 +- mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_gpu/primitives.cpp | 1 + mlx/export.cpp | 1 + mlx/linalg.cpp | 26 ++++- mlx/linalg.h | 4 + mlx/primitives.cpp | 37 +++++++ mlx/primitives.h | 23 ++++ python/src/linalg.cpp | 72 ++++++++++++- python/tests/test_linalg.py | 77 +++++++++++++ 14 files changed, 423 insertions(+), 5 deletions(-) create mode 100644 mlx/backend/cpu/eig.cpp diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index b01f74117..495380c46 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -16,6 +16,8 @@ Linear Algebra cross qr svd + eigvals + eig eigvalsh eigh lu diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 96b3f1313..9d322c4c4 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -46,6 +46,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp new file mode 100644 index 000000000..c89003fc0 --- /dev/null +++ b/mlx/backend/cpu/eig.cpp 
@@ -0,0 +1,174 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/lapack.h" +#include "mlx/linalg.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void eig_impl( + array& a, + array& vectors, + array& values, + bool compute_eigenvectors, + Stream stream) { + using OT = std::complex; + auto a_ptr = a.data(); + auto eig_ptr = values.data(); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(values); + OT* vec_ptr = nullptr; + if (compute_eigenvectors) { + encoder.set_output_array(vectors); + vec_ptr = vectors.data(); + } + encoder.dispatch([a_ptr, + vec_ptr, + eig_ptr, + compute_eigenvectors, + N = vectors.shape(-1), + size = vectors.size()]() mutable { + // Work query + char jobr = 'N'; + char jobl = compute_eigenvectors ? 'V' : 'N'; + int n_vecs_r = 1; + int n_vecs_l = compute_eigenvectors ? N : 1; + int lwork = -1; + int info; + { + T work; + int iwork; + geev( + &jobl, + &jobr, + &N, + nullptr, + &N, + nullptr, + nullptr, + nullptr, + &n_vecs_l, + nullptr, + &n_vecs_r, + &work, + &lwork, + &info); + lwork = static_cast(work); + } + + auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)}; + auto vec_tmp_data = + array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)}; + auto eig_tmp = static_cast(eig_tmp_data.buffer.raw_ptr()); + auto vec_tmp = static_cast(vec_tmp_data.buffer.raw_ptr()); + auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; + for (size_t i = 0; i < size / (N * N); ++i) { + geev( + &jobl, + &jobr, + &N, + a_ptr, + &N, + eig_tmp, + eig_tmp + N, + vec_tmp, + &n_vecs_l, + nullptr, + &n_vecs_r, + static_cast(work_buf.buffer.raw_ptr()), + &lwork, + &info); + for (int i = 0; i < N; ++i) { + eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]}; + } + if (vec_ptr) { + for (int i = 0; i < N; ++i) { + if (eig_ptr[i].imag() != 0) { + // This vector and the next are a pair + for (int j = 0; j < N; ++j) { + vec_ptr[i * N + j] = { + vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]}; + vec_ptr[(i + 1) * N + j] = { + vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]}; + } + i += 1; + } else { + for (int j = 0; j < N; ++j) { + vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0}; + } + } + } + vec_ptr += N * N; + } + a_ptr += N * N; + eig_ptr += N; + if (info != 0) { + std::stringstream msg; + msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " + << info; + throw std::runtime_error(msg.str()); + } + } + }); + encoder.add_temporary(a); +} + +} // namespace + +void Eig::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + const auto& a = inputs[0]; + auto& values = outputs[0]; + + auto vectors = compute_eigenvectors_ + ? outputs[1] + : array(a.shape(), complex64, nullptr, {}); + + auto a_copy = array(a.shape(), a.dtype(), nullptr, {}); + copy( + a, + a_copy, + a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General, + stream()); + + values.set_data(allocator::malloc(values.nbytes())); + + if (compute_eigenvectors_) { + // Set the strides and flags so the eigenvectors + // are in the columns of the output + auto flags = vectors.flags(); + auto strides = vectors.strides(); + auto ndim = a.ndim(); + std::swap(strides[ndim - 1], strides[ndim - 2]); + + if (a.size() > 1) { + flags.row_contiguous = false; + if (ndim > 2) { + flags.col_contiguous = false; + } else { + flags.col_contiguous = true; + } + } + vectors.set_data( + allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags); + } + switch (a.dtype()) { + case float32: + eig_impl(a_copy, vectors, values, compute_eigenvectors_, stream()); + break; + default: + throw std::runtime_error("[Eig::eval_cpu] only supports float32."); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index 2911c63f8..411742d56 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -45,6 +45,7 @@ INSTANTIATE_LAPACK_TYPES(geqrf) INSTANTIATE_LAPACK_TYPES(orgqr) INSTANTIATE_LAPACK_TYPES(syevd) +INSTANTIATE_LAPACK_TYPES(geev) INSTANTIATE_LAPACK_TYPES(potrf) INSTANTIATE_LAPACK_TYPES(gesvdx) INSTANTIATE_LAPACK_TYPES(getrf) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 860e9ddd7..6e42b29c9 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -378,10 +378,16 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } +void Eig::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI."); +} + void Eigh::eval_gpu( const std::vector& inputs, std::vector& outputs) { - throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); + throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); } void LUF::eval_gpu( diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 84372b096..1a180bfe0 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -55,6 +55,7 @@ NO_CPU(DynamicSlice) NO_CPU(DynamicSliceUpdate) NO_CPU(NumberOfElements) NO_CPU(Remainder) +NO_CPU_MULTI(Eig) NO_CPU_MULTI(Eigh) NO_CPU(Equal) NO_CPU(Erf) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 6826c97f6..676a6e550 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -126,6 +126,7 @@ NO_GPU(Unflatten) NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eigh) +NO_GPU_MULTI(Eig) NO_GPU(View) namespace fast { diff --git a/mlx/export.cpp b/mlx/export.cpp index c9139e156..bd2f24ba2 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -331,6 +331,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(SVD), SERIALIZE_PRIMITIVE(Inverse), SERIALIZE_PRIMITIVE(Cholesky), + SERIALIZE_PRIMITIVE(Eig), SERIALIZE_PRIMITIVE(Eigh), SERIALIZE_PRIMITIVE(AffineQuantize), SERIALIZE_PRIMITIVE(RMSNorm), diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 53f13486a..e0f4ec2e6 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -488,7 +488,7 @@ array cross( return concatenate(outputs, axis, s); } -void validate_eigh( +void validate_eig( const array& a, const StreamOrDevice& stream, const std::string fname) { @@ -511,7 +511,7 @@ array eigvalsh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, s, "[linalg::eigvalsh]"); + validate_eig(a, s, 
"[linalg::eigvalsh]"); Shape out_shape(a.shape().begin(), a.shape().end() - 1); return array( std::move(out_shape), @@ -524,7 +524,7 @@ std::pair eigh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, s, "[linalg::eigh]"); + validate_eig(a, s, "[linalg::eigh]"); auto out = array::make_arrays( {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, {a.dtype(), a.dtype()}, @@ -533,6 +533,26 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } +array eigvals(const array& a, StreamOrDevice s /* = {} */) { + validate_eig(a, s, "[linalg::eigvals]"); + Shape out_shape(a.shape().begin(), a.shape().end() - 1); + return array( + std::move(out_shape), + complex64, + std::make_shared(to_stream(s), false), + {a}); +} + +std::pair eig(const array& a, StreamOrDevice s /* = {} */) { + validate_eig(a, s, "[linalg::eig]"); + auto out = array::make_arrays( + {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, + {complex64, complex64}, + std::make_shared(to_stream(s), true), + {a}); + return std::make_pair(out[0], out[1]); +} + void validate_lu( const array& a, const StreamOrDevice& stream, diff --git a/mlx/linalg.h b/mlx/linalg.h index 8c3a2070a..0690fba95 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -99,6 +99,10 @@ array cross( int axis = -1, StreamOrDevice s = {}); +std::pair eig(const array& a, StreamOrDevice s = {}); + +array eigvals(const array& a, StreamOrDevice s = {}); + array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); std::pair diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index e1924e66c..87b2bc924 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -875,6 +875,43 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } +std::pair, std::vector> Eig::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + bool needs_move = axes[0] >= (inputs[0].ndim() - 2); + auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + auto ax = needs_move ? 
0 : axes[0]; + + std::vector outputs; + if (compute_eigenvectors_) { + auto [values, vectors] = linalg::eig(a, stream()); + outputs = {values, vectors}; + } else { + outputs = {linalg::eigvals(a, stream())}; + } + + return {outputs, std::vector(outputs.size(), ax)}; +} + +std::vector Eig::output_shapes(const std::vector& inputs) { + auto shape = inputs[0].shape(); + shape.pop_back(); // Remove last dimension for eigenvalues + if (compute_eigenvectors_) { + return { + std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors + } else { + return {std::move(shape)}; // Only eigenvalues + } +} + +bool Eig::is_equivalent(const Primitive& other) const { + auto& e_other = static_cast(other); + return compute_eigenvectors_ == e_other.compute_eigenvectors_; +} + std::pair, std::vector> Eigh::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 2caed8477..c0fbfc84d 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2381,6 +2381,29 @@ class Cholesky : public UnaryPrimitive { bool upper_; }; +class Eig : public Primitive { + public: + explicit Eig(Stream stream, bool compute_eigenvectors) + : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_PRINT(Eig) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return compute_eigenvectors_; + } + + private: + bool compute_eigenvectors_; +}; + class Eigh : public Primitive { public: explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors) diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 3bc0e5b1b..cc8e79db6 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -236,7 +236,7 @@ void init_linalg(nb::module_& parent_module) { Returns: Union[tuple(array, ...), array]: - If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that + If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``. )pbdoc"); m.def( @@ -407,6 +407,76 @@ void init_linalg(nb::module_& parent_module) { Returns: array: The cross product of ``a`` and ``b`` along the specified axis. )pbdoc"); + m.def( + "eigvals", + &mx::linalg::eigvals, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"pbdoc( + Compute the eigenvalues of a square matrix. + + This function differs from :func:`numpy.linalg.eigvals` in that the + return type is always complex even if the eigenvalues are all real. + + This function supports arrays with at least 2 dimensions. When the + input has more than two dimensions, the eigenvalues are computed for + each matrix in the last two dimensions. + + Args: + a (array): The input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The eigenvalues (not necessarily in order). 
+ + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu) + >>> eigenvalues + array([3+0j, -1+0j], dtype=complex64) + )pbdoc"); + m.def( + "eig", + [](const mx::array& a, mx::StreamOrDevice s) { + auto result = mx::linalg::eig(a, s); + return nb::make_tuple(result.first, result.second); + }, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"pbdoc( + Compute the eigenvalues and eigenvectors of a square matrix. + + This function differs from :func:`numpy.linalg.eig` in that the + return type is always complex even if the eigenvalues are all real. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the eigenvalues and eigenvectors are + computed for each matrix in the last two dimensions. + + Args: + a (array): The input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + Tuple[array, array]: + A tuple containing the eigenvalues and the normalized right + eigenvectors. The column ``v[:, i]`` is the eigenvector + corresponding to the i-th eigenvalue. + + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> w, v = mx.linalg.eig(A, stream=mx.cpu) + >>> w + array([3+0j, -1+0j], dtype=complex64) + >>> v + array([[0.707107+0j, 0.707107+0j], + [-0.707107+0j, 0.707107+0j]], dtype=complex64) + )pbdoc"); + m.def( "eigvalsh", &mx::linalg::eigvalsh, diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index a9fe572af..f65da1ff7 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -312,6 +312,83 @@ class TestLinalg(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.linalg.cross(a, b) + def test_eig(self): + tols = {"atol": 1e-5, "rtol": 1e-5} + + def check_eigs_and_vecs(A_np, kwargs={}): + A = mx.array(A_np) + eig_vals, eig_vecs = mx.linalg.eig(A, stream=mx.cpu, **kwargs) + self.assertTrue( + mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols) + ) + eig_vals_only = mx.linalg.eigvals(A, stream=mx.cpu, **kwargs) + self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols)) + + # Test a simple 2x2 matrix + A_np = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) + + # Test complex eigenvalues + A_np = np.array([[1.0, -1.0], [1.0, 1.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) + + # Test a larger random symmetric matrix + n = 5 + np.random.seed(1) + A_np = np.random.randn(n, n).astype(np.float32) + check_eigs_and_vecs(A_np) + + # Test with batched input + A_np = np.random.randn(3, n, n).astype(np.float32) + check_eigs_and_vecs(A_np) + + # Test error cases + with self.assertRaises(ValueError): + mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eig( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix + + with self.assertRaises(ValueError): + mx.linalg.eigvals(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eigvals( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix + + def test_lu(self): + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array(0.0), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu) + + # Test 3x3 matrix + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + P, L, 
U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + + # Test batch dimension + a = mx.broadcast_to(a, (5, 5, 3, 3)) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + L = mx.take_along_axis(L, P[..., None], axis=-2) + self.assertTrue(mx.allclose(L @ U, a)) + + # Test non-square matrix + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]]) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + + a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]]) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + def test_eigh(self): tols = {"atol": 1e-5, "rtol": 1e-5} From a2cadb8218a6b350557a1a06954b65834e6cd446 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 18:17:50 -0700 Subject: [PATCH 043/156] real and imag properties (#2189) --- docs/src/python/array.rst | 2 ++ python/src/array.cpp | 12 ++++++++++++ python/tests/test_array.py | 9 +++++++++ 3 files changed, 23 insertions(+) diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index 7e1c3339d..e68524d5a 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -19,6 +19,8 @@ Array array.ndim array.shape array.size + array.real + array.imag array.abs array.all array.any diff --git a/python/src/array.cpp b/python/src/array.cpp index 5f8dbe021..5ba0aaedc 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -319,6 +319,18 @@ void init_array(nb::module_& m) { R"pbdoc( The array's :class:`Dtype`. )pbdoc") + .def_prop_ro( + "real", + [](const mx::array& a) { return mx::real(a); }, + R"pbdoc( + The real part of a complex array. + )pbdoc") + .def_prop_ro( + "imag", + [](const mx::array& a) { return mx::imag(a); }, + R"pbdoc( + The imaginary part of a complex array. 
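A quick sketch of the new properties (illustrative, not part of the patch):

    import mlx.core as mx

    x = mx.array([1.0 + 1.0j, 2.0 - 3.0j])
    x.real  # float32 array holding the real parts: 1, 2
    x.imag  # float32 array holding the imaginary parts: 1, -3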
+ )pbdoc") .def( "item", &to_scalar, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 792e666d6..e63da17df 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2022,6 +2022,15 @@ class TestArray(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.add(y, x) + def test_real_imag(self): + x = mx.array([1.0]) + self.assertEqual(x.real.item(), 1.0) + self.assertEqual(x.imag.item(), 0.0) + + x = mx.array([1.0 + 1.0j]) + self.assertEqual(x.imag.item(), 1.0) + self.assertEqual(x.real.item(), 1.0) + if __name__ == "__main__": unittest.main() From 602f43e3d1f75a1036a3008024afa8f27c3140d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 19:20:36 -0700 Subject: [PATCH 044/156] fix conv grad (#2187) --- mlx/primitives.cpp | 18 +++++++++++------- python/tests/test_conv.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 87b2bc924..c2bb59c05 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1116,13 +1116,11 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_lo_(padding_lo.begin(), padding_lo.end()); - Shape padding_hi_(padding_hi.begin(), padding_hi.end()); auto in_padded = pad(in, padded_axes, - padding_lo_, - padding_hi_, + Shape(padding_lo), + Shape(padding_hi), array(0, in.dtype()), "constant", s); @@ -1274,8 +1272,14 @@ std::vector Convolution::vjp( in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_lo_; - std::vector padding_hi = padding_hi_; + auto padding_hi = padding_lo_; + + for (int i = 0; i < padding_hi.size(); ++i) { + int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); + int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); + int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); + padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1; + } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1284,7 +1288,7 @@ std::vector Convolution::vjp( /* const array& input = */ in_trans, /* const array& weight = */ cotan_trans, /* std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_lo = */ padding_lo_, /* std::vector padding_hi = */ padding_hi, /* std::vector kernel_dilation = */ kernel_strides_, /* std::vector input_dilation = */ input_dilation_, diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 35dcf42ac..7d63e4751 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1130,6 +1130,28 @@ class TestConv(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + def test_basic_grad_shapes(self): + def loss_fn(kernel, inputs, strides, groups): + return mx.sum( + mx.conv_general( + inputs, + kernel, + stride=strides, + groups=groups, + ) + ) + + for in_shape, k_shape, strides, groups in [ + ((3, 5, 4), (6, 2, 2), (2,), 2), + ((3, 5, 4), (24, 2, 1), (2,), 4), + ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2), + ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4), + ]: + grads = mx.grad(loss_fn)( + mx.zeros(k_shape), mx.zeros(in_shape), strides, groups + ) + self.assertEqual(grads.shape, k_shape) + if __name__ == "__main__": unittest.main() From 7ff5c41e061a27265e0fe793dfc5dda3f4b55e46 Mon Sep 17 00:00:00 2001 From: 
Jack Wind Date: Fri, 16 May 2025 03:28:03 -0400 Subject: [PATCH 045/156] Add set_threadgroup_memory_length to CommandEncoder (#2183) --- mlx/backend/metal/device.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 26c9a0a28..660ba65e2 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -95,6 +95,10 @@ struct CommandEncoder { return enc_->setBytes(&v, sizeof(T), idx); } + void set_threadgroup_memory_length(size_t length, int idx) { + enc_->setThreadgroupMemoryLength(length, idx); + } + ConcurrentContext start_concurrent() { return ConcurrentContext(*this); } From 7d4b378952489b5c19b8d3ca5c028bf46a6ae86c Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 16 May 2025 22:44:42 +0900 Subject: [PATCH 046/156] Include cuda_bf16.h for bfloat16 overloads (#2192) * Include cuda_bf16.h for bfloat16 overloads * Add NO_GPU_MULTI(Eig) in cuda backend --- mlx/backend/cuda/kernels/fp16_math.cuh | 33 +------------------------- mlx/backend/cuda/primitives.cu | 1 + 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index 931c55ff7..edbd953de 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -2,44 +2,13 @@ #pragma once +#include #include #include #include namespace mlx::core::cu { -/////////////////////////////////////////////////////////////////////////////// -// Missing C++ operator overrides for CUDA 7. -/////////////////////////////////////////////////////////////////////////////// - -#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - -#define MLX_DEFINE_BF16_OP(OP) \ - __forceinline__ __device__ __nv_bfloat16 operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -#define MLX_DEFINE_BF16_CMP(OP) \ - __forceinline__ __device__ bool operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -MLX_DEFINE_BF16_OP(+) -MLX_DEFINE_BF16_OP(-) -MLX_DEFINE_BF16_OP(*) -MLX_DEFINE_BF16_OP(/) -MLX_DEFINE_BF16_CMP(>) -MLX_DEFINE_BF16_CMP(<) -MLX_DEFINE_BF16_CMP(>=) -MLX_DEFINE_BF16_CMP(<=) - -#undef MLX_DEFINE_BF16_OP -#undef MLX_DEFINE_BF16_CMP - -#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. 
/////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index dc6edf606..defdc746a 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -140,6 +140,7 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Inverse) NO_GPU(Cholesky) +NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { From 48ef3e74e27a3ea620adc5fb5ae22be15613e67f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 16 May 2025 08:38:49 -0700 Subject: [PATCH 047/156] reduce vjp for all and any (#2193) --- mlx/primitives.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c2bb59c05..5f2bfdda4 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3548,7 +3548,7 @@ std::vector Reduce::vjp( } else { - throw std::runtime_error("Reduce type VJP not yet implemented."); + return {zeros_like(in, stream())}; } } From 0654543dcca1c69b4fa745eeee981fa8394dae89 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 18 May 2025 00:18:43 -0700 Subject: [PATCH 048/156] Add complex eigh (#2191) --- mlx/array.h | 4 + mlx/backend/cpu/eigh.cpp | 177 ++++++++++++++++++++++++++++-------- mlx/backend/cpu/lapack.h | 40 +++++--- mlx/linalg.cpp | 17 +++- python/tests/test_linalg.py | 7 ++ 5 files changed, 190 insertions(+), 55 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index d9fcfc58e..98eef2e33 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -224,6 +224,10 @@ class array { // Not copyable Data(const Data& d) = delete; Data& operator=(const Data& d) = delete; + Data(Data&& o) : buffer(o.buffer), d(o.d) { + o.buffer = allocator::Buffer(nullptr); + o.d = [](allocator::Buffer) {}; + } ~Data() { d(buffer); } diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp index b50f2c722..58d3634e8 100644 --- a/mlx/backend/cpu/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -12,6 +12,133 @@ namespace mlx::core { namespace { +template +struct EighWork {}; + +template +struct EighWork< + T, + typename std::enable_if::value>::type> { + using R = T; + + char jobz; + char uplo; + int N; + int lwork; + int liwork; + int info; + std::vector buffers; + + EighWork(char jobz_, char uplo_, int N_) + : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) { + T work; + int iwork; + syevd( + &jobz, + &uplo, + &N, + nullptr, + &N, + nullptr, + &work, + &lwork, + &iwork, + &liwork, + &info); + lwork = static_cast(work); + liwork = iwork; + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + buffers.emplace_back(allocator::malloc(sizeof(int) * liwork)); + } + + void run(T* vectors, T* values) { + syevd( + &jobz, + &uplo, + &N, + vectors, + &N, + values, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &liwork, + &info); + } +}; + +template <> +struct EighWork> { + using T = std::complex; + using R = float; + + char jobz; + char uplo; + int N; + int lwork; + int lrwork; + int liwork; + int info; + std::vector buffers; + + EighWork(char jobz_, char uplo_, int N_) + : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) { + T work; + R rwork; + int iwork; + heevd( + &jobz, + &uplo, + &N, + nullptr, + &N, + nullptr, + &work, + &lwork, + &rwork, + &lrwork, + &iwork, + &liwork, + &info); + lwork = static_cast(work.real()); + lrwork = static_cast(rwork); + liwork = iwork; + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork)); + 
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork)); + } + + void run(T* vectors, R* values) { + heevd( + &jobz, + &uplo, + &N, + vectors, + &N, + values, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &lrwork, + static_cast(buffers[2].buffer.raw_ptr()), + &liwork, + &info); + if (jobz == 'V') { + // We have pre-transposed the vectors but we also must conjugate them + // when they are complex. + // + // We could vectorize this but it is so fast in comparison to heevd that + // it doesn't really matter. + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + *vectors = std::conj(*vectors); + vectors++; + } + } + } + } +}; + template void eigh_impl( array& vectors, @@ -19,8 +146,10 @@ void eigh_impl( const std::string& uplo, bool compute_eigenvectors, Stream stream) { + using R = typename EighWork::R; + auto vec_ptr = vectors.data(); - auto eig_ptr = values.data(); + auto eig_ptr = values.data(); char jobz = compute_eigenvectors ? 'V' : 'N'; auto& encoder = cpu::get_command_encoder(stream); @@ -33,49 +162,17 @@ void eigh_impl( N = vectors.shape(-1), size = vectors.size()]() mutable { // Work query - int lwork = -1; - int liwork = -1; - int info; - { - T work; - int iwork; - syevd( - &jobz, - &uplo, - &N, - nullptr, - &N, - nullptr, - &work, - &lwork, - &iwork, - &liwork, - &info); - lwork = static_cast(work); - liwork = iwork; - } + EighWork work(jobz, uplo, N); - auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; - auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)}; + // Work loop for (size_t i = 0; i < size / (N * N); ++i) { - syevd( - &jobz, - &uplo, - &N, - vec_ptr, - &N, - eig_ptr, - static_cast(work_buf.buffer.raw_ptr()), - &lwork, - static_cast(iwork_buf.buffer.raw_ptr()), - &liwork, - &info); + work.run(vec_ptr, eig_ptr); vec_ptr += N * N; eig_ptr += N; - if (info != 0) { + if (work.info != 0) { std::stringstream msg; msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code " - << info; + << work.info; throw std::runtime_error(msg.str()); } } @@ -131,6 +228,10 @@ void Eigh::eval_cpu( eigh_impl( vectors, values, uplo_, compute_eigenvectors_, stream()); break; + case complex64: + eigh_impl>( + vectors, values, uplo_, compute_eigenvectors_, stream()); + break; default: throw std::runtime_error( "[Eigh::eval_cpu] only supports float32 or float64."); diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index 411742d56..b242093ff 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -2,14 +2,14 @@ #pragma once -// Required for Visual Studio. -// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md -#ifdef _MSC_VER #include #define LAPACK_COMPLEX_CUSTOM #define lapack_complex_float std::complex #define lapack_complex_double std::complex -#endif +#define lapack_complex_float_real(z) ((z).real()) +#define lapack_complex_float_imag(z) ((z).imag()) +#define lapack_complex_double_real(z) ((z).real()) +#define lapack_complex_double_imag(z) ((z).imag()) #ifdef MLX_USE_ACCELERATE #include @@ -32,7 +32,7 @@ #endif -#define INSTANTIATE_LAPACK_TYPES(FUNC) \ +#define INSTANTIATE_LAPACK_REAL(FUNC) \ template \ void FUNC(Args... 
args) { \ if constexpr (std::is_same_v) { \ @@ -42,12 +42,24 @@ } \ } -INSTANTIATE_LAPACK_TYPES(geqrf) -INSTANTIATE_LAPACK_TYPES(orgqr) -INSTANTIATE_LAPACK_TYPES(syevd) -INSTANTIATE_LAPACK_TYPES(geev) -INSTANTIATE_LAPACK_TYPES(potrf) -INSTANTIATE_LAPACK_TYPES(gesvdx) -INSTANTIATE_LAPACK_TYPES(getrf) -INSTANTIATE_LAPACK_TYPES(getri) -INSTANTIATE_LAPACK_TYPES(trtri) +INSTANTIATE_LAPACK_REAL(geqrf) +INSTANTIATE_LAPACK_REAL(orgqr) +INSTANTIATE_LAPACK_REAL(syevd) +INSTANTIATE_LAPACK_REAL(geev) +INSTANTIATE_LAPACK_REAL(potrf) +INSTANTIATE_LAPACK_REAL(gesvdx) +INSTANTIATE_LAPACK_REAL(getrf) +INSTANTIATE_LAPACK_REAL(getri) +INSTANTIATE_LAPACK_REAL(trtri) + +#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_COMPLEX(heevd) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index e0f4ec2e6..144f9a880 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) { } } +void check_float_or_complex(Dtype dtype, const std::string& prefix) { + if (dtype != float32 && dtype != float64 && dtype != complex64) { + std::ostringstream msg; + msg << prefix << " Arrays must have type float32, float64 or complex64. " + << "Received array with type " << dtype << "."; + throw std::invalid_argument(msg.str()); + } +} + Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } @@ -493,7 +502,7 @@ void validate_eig( const StreamOrDevice& stream, const std::string fname) { check_cpu_stream(stream, fname); - check_float(a.dtype(), fname); + check_float_or_complex(a.dtype(), fname); if (a.ndim() < 2) { std::ostringstream msg; @@ -513,9 +522,10 @@ array eigvalsh( StreamOrDevice s /* = {} */) { validate_eig(a, s, "[linalg::eigvalsh]"); Shape out_shape(a.shape().begin(), a.shape().end() - 1); + Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype(); return array( std::move(out_shape), - a.dtype(), + eigval_type, std::make_shared(to_stream(s), UPLO, false), {a}); } @@ -525,9 +535,10 @@ std::pair eigh( std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { validate_eig(a, s, "[linalg::eigh]"); + Dtype eigval_type = a.dtype() == complex64 ? 
float32 : a.dtype(); auto out = array::make_arrays( {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, - {a.dtype(), a.dtype()}, + {eigval_type, a.dtype()}, std::make_shared(to_stream(s), UPLO, true), {a}); return std::make_pair(out[0], out[1]); diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index f65da1ff7..f5eeda837 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -423,6 +423,13 @@ class TestLinalg(mlx_tests.MLXTestCase): A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2 check_eigs_and_vecs(A_np) + # Test with complex inputs + A_np = ( + np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1) + ) + A_np = A_np + A_np.T.conj() + check_eigs_and_vecs(A_np) + # Test error cases with self.assertRaises(ValueError): mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array From 8576e6fe3606bf5b805162fd5f4a7803a9a0d349 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 18 May 2025 06:05:11 -0700 Subject: [PATCH 049/156] fix conv2d bug + faster conv 1d (#2195) * fix conv2d bug + faster conv 1d * revert sort + flaky test --- mlx/backend/metal/conv.cpp | 268 +++++++++--------- .../steel/conv/loaders/loader_channel_l.h | 8 +- .../steel/conv/loaders/loader_channel_n.h | 4 +- mlx/ops.cpp | 6 +- python/tests/test_conv.py | 21 ++ python/tests/test_vmap.py | 1 + 6 files changed, 170 insertions(+), 138 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 35ed3d44e..6b4b70d47 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include #include @@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu( /*copies = */ copies); } -void conv_1D_gpu( - const Stream& s, - metal::Device& d, - const array& in, - const array& wt, - array out, - const std::vector& padding, - const std::vector& wt_strides, - const std::vector& wt_dilation, - const std::vector& in_dilation, - int groups, - bool flip) { - // Make conv params - MLXConvParams<1> conv_params{ - /* const int N = */ static_cast(in.shape(0)), - /* const int C = */ static_cast(in.shape(2)), - /* const int O = */ static_cast(wt.shape(0)), - /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, - /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, - /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, - /* const int str[NDIM] = */ {wt_strides[0]}, - /* const int pad[NDIM] = */ {padding[0]}, - /* const int kdil[NDIM] = */ {wt_dilation[0]}, - /* const int idil[NDIM] = */ {in_dilation[0]}, - /* const size_t in_strides[NDIM + 2] = */ - {in.strides()[0], in.strides()[1], in.strides()[2]}, - /* const size_t wt_strides[NDIM + 2] = */ - {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, - /* const size_t out_strides[NDIM + 2] = */ - {out.strides()[0], out.strides()[1], out.strides()[2]}, - /* const int groups = */ groups, - /* const bool flip = */ flip}; - - // Direct to explicit gemm conv - if (groups > 1) { - return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); - } else { - return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); - } -} - -void slow_conv_2D_gpu( - const Stream& s, - metal::Device& d, - const array& in, - const array& wt, - array out, - const MLXConvParams<2>& conv_params) { - int bm = 16, bn = 8; - int tm = 4, tn = 4; - - std::ostringstream kname; - kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn - << "_tm" << tm << "_tn" << tn; - - // Encode and dispatch kernel - auto& compute_encoder = 
d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; - - size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm); - size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn); - size_t grid_dim_z = conv_params.N; - - MTL::Size group_dims = MTL::Size(bm, bn, 1); - MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); - - compute_encoder.set_input_array(in, 0); - compute_encoder.set_input_array(wt, 1); - compute_encoder.set_output_array(out, 2); - - compute_encoder.set_bytes(conv_params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); -} - void implicit_gemm_conv_2D_gpu( const Stream& s, metal::Device& d, @@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void dispatch_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params, + std::vector& copies) { + bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; + bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; + bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; + + if (is_idil_one && conv_params.groups > 1) { + const int C_per_group = conv_params.C / conv_params.groups; + const int O_per_group = conv_params.O / conv_params.groups; + + if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && + conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && + conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && + conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && + conv_params.wt_strides[1] == conv_params.wS[1] && + conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { + return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + if ((C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } + } + + // Direct to winograd conv + bool inp_large = + (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; + bool channels_large = (conv_params.C + conv_params.O) >= 256; + if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one && + conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && + channels_large) { + return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + } + + // Direct to implicit gemm conv + if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && + (conv_params.O <= 16 || conv_params.O % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { + return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); + } + + // Direct to explicit gemm conv + else { + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); + } +} + +void conv_1D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + int groups, + bool flip, + std::vector& copies) { + bool is_idil_one = in_dilation[0] == 1; + int C = in.shape(2); + int O = wt.shape(0); + const int C_per_group 
= in.shape(2) / groups; + const int O_per_group = wt.shape(0) / groups; + + // Direct to implicit gemm conv + if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + MLXConvParams<2> conv_params{ + /* const int N = */ static_cast(in.shape(0)), + /* const int C = */ C, + /* const int O = */ O, + /* const int iS[NDIM] = */ {static_cast(in.shape(1)), 1}, + /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), 1}, + /* const int oS[NDIM] = */ {static_cast(out.shape(1)), 1}, + /* const int str[NDIM] = */ {wt_strides[0], 1}, + /* const int pad[NDIM] = */ {padding[0], 0}, + /* const int kdil[NDIM] = */ {wt_dilation[0], 1}, + /* const int idil[NDIM] = */ {in_dilation[0], 1}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], 0, in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], 0, out.strides()[2]}, + /* const int groups = */ groups, + /* const bool flip = */ flip}; + + dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + return; + } + + // Make conv params + MLXConvParams<1> conv_params{ + /* const int N = */ static_cast(in.shape(0)), + /* const int C = */ static_cast(in.shape(2)), + /* const int O = */ static_cast(wt.shape(0)), + /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, + /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, + /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, + /* const int str[NDIM] = */ {wt_strides[0]}, + /* const int pad[NDIM] = */ {padding[0]}, + /* const int kdil[NDIM] = */ {wt_dilation[0]}, + /* const int idil[NDIM] = */ {in_dilation[0]}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], out.strides()[2]}, + /* const int groups = */ groups, + /* const bool flip = */ flip}; + + // Direct to explicit gemm conv + if (groups > 1) { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); + } +} + void conv_2D_gpu( const Stream& s, metal::Device& d, @@ -808,57 +865,7 @@ void conv_2D_gpu( /* const int groups = */ groups, /* const bool flip = */ flip, }; - - bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; - bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; - bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - - if (is_idil_one && groups > 1) { - const int C_per_group = conv_params.C / groups; - const int O_per_group = conv_params.O / groups; - - if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && - conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && - conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && - conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && - conv_params.wt_strides[1] == conv_params.wS[1] && - conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { - return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - if ((C_per_group <= 4 || C_per_group % 16 == 0) && - (O_per_group <= 16 || O_per_group % 16 == 0)) { - return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } else { - return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); - } - } - - 
// Direct to winograd conv - bool inp_large = - (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; - bool channels_large = (conv_params.C + conv_params.O) >= 256; - if (!flip && is_stride_one && is_kdil_one && is_idil_one && - conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && - conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && - channels_large) { - return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); - } - - // Direct to implicit gemm conv - if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && - (conv_params.O <= 16 || conv_params.O % 16 == 0)) { - return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { - return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); - } - - // Direct to explicit gemm conv - else { - return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); - } + dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); } void conv_3D_gpu( @@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { kernel_dilation_, input_dilation_, groups_, - flip_); + flip_, + copies); } // Throw error else { diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h index dad496e81..d52642b73 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h @@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader { const constant MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; @@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} @@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader { /* Iteration helper */ METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h index 56027916e..b0b98d21a 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { STEEL_PRAGMA_UNROLL @@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0c18cccfe..a72c2bc85 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3584,21 +3584,21 @@ Shape conv_out_shape( if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for " + msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi 
<< " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (kernel_dilation.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for " + msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (input_dilation.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid input dilation " << input_dilation << "for " + msg << "[conv] Invalid input dilation " << input_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 7d63e4751..9fe11286d 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase): ) self.assertEqual(grads.shape, k_shape) + def test_1d_conv_with_2d(self): + x = mx.random.uniform(shape=(2, 10, 16)) + y = mx.random.normal(shape=(16, 3, 16)) + + out = mx.conv1d(x, y, padding=1) + out_2d = mx.conv2d( + mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0) + ) + + self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + + x = mx.random.uniform(shape=(2, 10, 4)) + y = mx.random.normal(shape=(4, 3, 4)) + + out = mx.conv1d(x, y, padding=1) + out_2d = mx.conv2d( + mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0) + ) + + self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index e571678d3..ddfceb0a1 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(fy.shape, (4, 5, 6, 7)) def test_leaks(self): + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: From 237f9e58a892798aa9a4bfd6e83a864fa3358904 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 19 May 2025 22:10:44 +0900 Subject: [PATCH 050/156] Fix BEFORE keyword in target_include_directories (#2204) --- mlx/backend/cuda/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 54d651005..f9695f66a 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -36,7 +36,7 @@ FetchContent_Declare( cccl URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") FetchContent_MakeAvailable(cccl) -target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include") +target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include") # Use fixed version of NVTX. 
FetchContent_Declare( From 0359bf02c99f4beff9a596431865d8211b654714 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 19 May 2025 11:23:38 -0700 Subject: [PATCH 051/156] Nearest upsample (#2202) --- python/mlx/nn/layers/upsample.py | 11 ++++++++++- python/tests/test_upsample.py | 11 ++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index 1f2ffd3da..e6bd282af 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims): def _nearest_indices(N, scale, dim, ndims): - return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32) + M = int(scale * N) + indices = mx.arange(M, dtype=mx.float32) + if M > N: + indices = (indices + 0.5) * (N / M) - 0.5 + indices = indices.round() + else: + indices = indices * (N / M) + shape = [1] * ndims + shape[dim] = -1 + return indices.astype(mx.uint32).reshape(shape) def _linear_indices(N, scale, align_corners, dim, ndims): diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 402c7b0ca..86f41b6e8 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase): align_corners=align_corner, )(in_mx) mode_pt = { + "nearest": "nearest", "linear": "bilinear", "cubic": "bicubic", }[mode] @@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase): in_pt, scale_factor=scale_factor, mode=mode_pt, - align_corners=align_corner, + align_corners=align_corner if mode != "nearest" else None, ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) self.assertEqual(out_pt.shape, out_mx.shape) @@ -76,14 +77,14 @@ class TestUpsample(mlx_tests.MLXTestCase): ((4, 4), (0.5, 0.5)), ((7, 7), (2.0, 2.0)), ((10, 10), (0.2, 0.2)), + ((10, 10), (0.3, 0.3)), ((11, 21), (3.0, 3.0)), ((11, 21), (3.0, 2.0)), ): - # only test linear and cubic interpolation - # there will be numerical difference in nearest - # due to different indices selection. - for mode in ("cubic", "linear"): + for mode in ("cubic", "linear", "nearest"): for align_corner in (False, True): + if mode == "nearest" and align_corner: + continue run_upsample( N, C, From eebe73001affcb424171e9d49657e508f70a9201 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 May 2025 13:10:44 -0700 Subject: [PATCH 052/156] fix large arg reduce (#2206) --- mlx/backend/metal/kernels/arg_reduce.metal | 20 +++++++++++--------- mlx/backend/metal/primitives.cpp | 4 ++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index 7f1075ad9..4a83d8e57 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -80,9 +80,10 @@ template const constant size_t& ndim [[buffer(5)]], const constant int64_t& axis_stride [[buffer(6)]], const constant size_t& axis_size [[buffer(7)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 gsize [[threads_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_size [[threads_per_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { @@ -104,17 +105,18 @@ template // Compute the input/output index. 
There is one beginning and one output for // the whole threadgroup. - auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z; + auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim); + auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim); IndexValPair best{0, Op::init}; threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { // Read the current value - uint32_t current_index = r * lsize * N_READS + lid * N_READS; + uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS; uint32_t offset = current_index; const device T* current_in = in + in_idx + current_index * axis_stride; T vals[N_READS]; @@ -144,7 +146,7 @@ template } // Read the appropriate value from local data and perform one simd reduction - uint simd_groups = ceildiv(lsize, simd_size); + uint simd_groups = ceildiv(lsize.x, simd_size); if (simd_lane_id < simd_groups) { best = local_data[simd_lane_id]; } @@ -154,7 +156,7 @@ template } // Finally write the output - if (lid == 0) { + if (lid.x == 0) { out[out_idx] = best.index; } } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6e42b29c9..705c3ea76 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -182,8 +182,8 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { (thread_group_size + simd_size - 1) / simd_size * simd_size; assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); - size_t n_threads = out.size() * thread_group_size; - MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + auto gd = get_2d_grid_dims(out.shape(), out.strides()); + MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); From ab8883dd55745a12d385b91ef26ea794d2a45bdb Mon Sep 17 00:00:00 2001 From: Clement Liaw Date: Tue, 20 May 2025 07:39:11 -0700 Subject: [PATCH 053/156] include mlx::core::version() symbols in the mlx static library (#2207) --- mlx/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 4ba9b33dd..ce921b276 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -21,7 +21,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) # Define MLX_VERSION only in the version.cpp file. 
-add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) +add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") target_link_libraries(mlx PRIVATE $) From 4cbe6052147420ee09ae855d57a81cef3467af15 Mon Sep 17 00:00:00 2001 From: Jack Wind Date: Tue, 20 May 2025 13:22:26 -0400 Subject: [PATCH 054/156] Feat: Allow per-target Metal debug flags (#2201) * feat: allow per-target Metal debug flags * formatting fix --- cmake/extension.cmake | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cmake/extension.cmake b/cmake/extension.cmake index 3270b0056..13db804a1 100644 --- a/cmake/extension.cmake +++ b/cmake/extension.cmake @@ -11,13 +11,14 @@ include(CMakeParseArguments) # Args: TARGET: Custom target to be added for the metal library TITLE: Name of # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency -# files (like headers) +# files (like headers) DEBUG: Boolean, if true, enables debug compile options +# for this specific library. If not provided, uses global MLX_METAL_DEBUG. # # clang format on macro(mlx_build_metallib) # Parse args - set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) + set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -26,6 +27,10 @@ macro(mlx_build_metallib) # Collect compile options set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) + if(MLX_METAL_DEBUG OR MTLLIB_DEBUG) + set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only + -frecord-sources) + endif() # Prepare metallib build command add_custom_command( From 35c87741cf2450c96c6e52afead61eec81c45e2a Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 11:42:48 +0900 Subject: [PATCH 055/156] Build for compute capability 70 instead of 75 (#2209) --- mlx/backend/cuda/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index f9695f66a..d62a69846 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -25,7 +25,7 @@ target_compile_options(mlx # Compute capability 7 is required for synchronization between CPU/GPU with # managed memory. TODO: Add more architectures for potential performance gain. set(MLX_CUDA_ARCHITECTURES - "75;80" + "70;80" CACHE STRING "CUDA architectures") message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES From 7774b87cbda51c5e34f1471d8e76767350368a05 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 23:25:03 +0900 Subject: [PATCH 056/156] Remove redundant simd_sum in logsumexp (#2210) --- mlx/backend/metal/kernels/logsumexp.h | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index b6898e31e..93744e15d 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -134,10 +134,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? 
T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } From 79071bfba4f012517859bcfdd9032123d16cc6b6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 23:25:16 +0900 Subject: [PATCH 057/156] Fix out-of-bounds default value in logsumexp/softmax (#2213) --- mlx/backend/metal/kernels/logsumexp.h | 4 ++-- mlx/backend/metal/kernels/softmax.h | 4 ++-- tests/ops_tests.cpp | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index 93744e15d..c746050b3 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index b36b73bd8..6ea4ac732 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 5e2bae5a0..8833424a6 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") { x = array({-inf, -inf}); CHECK_EQ(logsumexp(x).item(), -inf); + x = repeat(array(-inf), 5000); + CHECK_EQ(logsumexp(x).item(), -inf); + x = array({0.0f, -inf}); CHECK_EQ(logsumexp(x).item(), 0.0f); From 55b4062dd8c71d4499a430012b49f676da91818a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 21 May 2025 17:13:04 -0700 Subject: [PATCH 058/156] copyright in docs (#2214) --- docs/src/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/conf.py b/docs/src/conf.py index abc68c3a2..d9dd32ad1 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -10,7 +10,7 @@ import mlx.core as mx # -- Project information ----------------------------------------------------- project = "MLX" -copyright = "2023, MLX Contributors" +copyright = "2023, Apple" author = "MLX Contributors" version = ".".join(mx.__version__.split(".")[:3]) release = version From 54a71f270a671d2b31c493c98f27a49fe217a6f1 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 23 May 2025 22:14:58 +0900 Subject: [PATCH 059/156] Remove unused defines (#2217) --- CMakeLists.txt | 3 +++ mlx/backend/cuda/CMakeLists.txt | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ab8aea443..4bf8d2d3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,6 +231,9 @@ target_include_directories( mlx PUBLIC $ $) +# Do not add mlx_EXPORTS define for shared library. 
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "") + FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index d62a69846..7ebe68324 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -16,8 +16,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) -target_compile_definitions(mlx PUBLIC MLX_USE_CUDA) - # Enable defining device lambda functions. target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") From f76ee1ffd2f668c9de852eee28e623f068eb7d1f Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 29 May 2025 22:48:30 +0900 Subject: [PATCH 060/156] Move some dims utils to common (#2223) --- mlx/backend/common/utils.cpp | 108 ++++++++++++++++++ mlx/backend/common/utils.h | 25 ++++ mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/kernel_utils.cu | 26 +++++ .../{dtype_utils.cuh => kernel_utils.cuh} | 14 +++ mlx/backend/cuda/primitives.cu | 2 +- mlx/backend/cuda/utils.h | 2 + mlx/backend/metal/utils.cpp | 106 ++--------------- mlx/backend/metal/utils.h | 17 +-- 9 files changed, 186 insertions(+), 115 deletions(-) create mode 100644 mlx/backend/cuda/kernel_utils.cu rename mlx/backend/cuda/{dtype_utils.cuh => kernel_utils.cuh} (53%) diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 35bba9c63..08df53a8e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -1,9 +1,16 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/common/utils.h" +#include "mlx/primitives.h" namespace mlx::core { +std::string get_primitive_string(Primitive* primitive) { + std::ostringstream op_t; + primitive->print(op_t); + return op_t.str(); +} + std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, @@ -101,4 +108,105 @@ std::pair collapse_contiguous_dims( return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); } +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == pow2) { + break; + } + } + return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]); +} + +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) { + // Dims with strides of 0 are ignored as they + // correspond to broadcasted dimensions + size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return std::make_tuple( + static_cast(grid_x), static_cast(grid_y), 1); +} + +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor) { + // Compute the 2d grid dimensions such that the total size of the grid is + // divided by divisor. 
+ size_t grid_x = 1; + size_t grid_y = 1; + for (int i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + + // No need to add this shape we can just remove it from the divisor. + if (divisor % shape[i] == 0) { + divisor /= shape[i]; + continue; + } + + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + + if (divisor > 1) { + if (grid_x % divisor == 0) { + grid_x /= divisor; + divisor = 1; + } else if (grid_y % divisor == 0) { + grid_y /= divisor; + divisor = 1; + } + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return std::make_tuple( + static_cast(grid_x), static_cast(grid_y), 1); +} + } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index a4bdaa5ca..40bc3efe4 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -2,12 +2,15 @@ #pragma once +#include #include #include "mlx/array.h" namespace mlx::core { +std::string get_primitive_string(Primitive* primitive); + inline int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) { int64_t loc = 0; @@ -70,6 +73,28 @@ std::pair collapse_contiguous_dims( const array& a, int64_t size_cap = std::numeric_limits::max()); +// Compute the thread block dimensions which fit the given +// input dimensions. +// - The thread block dimensions will be powers of two +// - The thread block size will be less than 2^pow2 +using Dims = std::tuple; +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10); + +// Computes a 2D grid where each element is < UINT_MAX +// Assumes: +// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 +// - shape and strides correspond to a contiguous (no holes) but +// possibly broadcasted array +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides); + +// Same as above but we do an implicit division with divisor. +// Basically, equivalent to factorizing +// Prod(s \forall s in shape if strides[s] > 0) / divisor. +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor); + struct ContiguousIterator { inline void step() { int dims = shape_.size(); diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7ebe68324..2a8ef9963 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -11,6 +11,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu new file mode 100644 index 000000000..575af7cf6 --- /dev/null +++ b/mlx/backend/cuda/kernel_utils.cu @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/kernel_utils.cuh" + +namespace mlx::core { + +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) { + Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { + Dims dims = get_2d_grid_dims_common(shape, strides); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +dim3 get_2d_grid_dims( + const Shape& shape, + const Strides& strides, + size_t divisor) { + Dims dims = get_2d_grid_dims_common(shape, strides, divisor); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/dtype_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh similarity index 53% rename from mlx/backend/cuda/dtype_utils.cuh rename to mlx/backend/cuda/kernel_utils.cuh index 9b7f8ba65..67ac47449 100644 --- a/mlx/backend/cuda/dtype_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -1,7 +1,13 @@ // Copyright © 2025 Apple Inc. +// This file includes host-only utilies for writing CUDA kernels, the difference +// from backend/cuda/kernels/utils.cuh is that the latter file only include +// device-only code. + #pragma once +#include "mlx/array.h" + #include #include #include @@ -32,4 +38,12 @@ struct CTypeToCudaType { template using cuda_type_t = typename CTypeToCudaType::type; +// Compute the grid and block dimensions, check backend/common/utils.h for docs. +dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); +dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides); +dim3 get_2d_grid_dims( + const Shape& shape, + const Strides& strides, + size_t divisor); + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index defdc746a..d105a242b 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/dtype_utils.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernels/arange.cuh" #include "mlx/backend/cuda/kernels/fp16_math.cuh" #include "mlx/distributed/primitives.h" diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 58d508765..6eaec8984 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +// This file include utilies that are used by C++ code (i.e. .cpp files). + #pragma once #include diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index 329d250df..978501835 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/utils.h" - -using namespace mlx; +#include "mlx/backend/common/utils.h" namespace mlx::core { @@ -59,109 +58,20 @@ std::string type_to_name(const array& a) { return type_to_name(a.dtype()); } -MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { - int pows[3] = {0, 0, 0}; - int sum = 0; - while (true) { - int presum = sum; - // Check all the pows - if (dim0 >= (1 << (pows[0] + 1))) { - pows[0]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim1 >= (1 << (pows[1] + 1))) { - pows[1]++; - sum++; - } - if (sum == 10) { - break; - } - if (dim2 >= (1 << (pows[2] + 1))) { - pows[2]++; - sum++; - } - if (sum == presum || sum == pow2) { - break; - } - } - return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) { + Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) { - // Dims with strides of 0 are ignored as they - // correspond to broadcasted dimensions - size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { - throw std::runtime_error("Unable to safely factor shape."); - } - if (grid_y > grid_x) { - std::swap(grid_x, grid_y); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); + Dims dims = get_2d_grid_dims_common(shape, strides); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) { - // Compute the 2d grid dimensions such that the total size of the grid is - // divided by divisor. - size_t grid_x = 1; - size_t grid_y = 1; - for (int i = 0; i < shape.size(); ++i) { - if (strides[i] == 0) { - continue; - } - - // No need to add this shape we can just remove it from the divisor. - if (divisor % shape[i] == 0) { - divisor /= shape[i]; - continue; - } - - if (grid_x * shape[i] < UINT32_MAX) { - grid_x *= shape[i]; - } else { - grid_y *= shape[i]; - } - - if (divisor > 1) { - if (grid_x % divisor == 0) { - grid_x /= divisor; - divisor = 1; - } else if (grid_y % divisor == 0) { - grid_y /= divisor; - divisor = 1; - } - } - } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { - throw std::runtime_error("Unable to safely factor shape."); - } - if (grid_y > grid_x) { - std::swap(grid_x, grid_y); - } - return MTL::Size( - static_cast(grid_x), static_cast(grid_y), 1); -} - -std::string get_primitive_string(Primitive* primitive) { - std::ostringstream op_t; - primitive->print(op_t); - return op_t.str(); + Dims dims = get_2d_grid_dims_common(shape, strides, divisor); + return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } } // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index f9245a6d6..576fb9107 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -13,22 +13,9 @@ namespace mlx::core { std::string type_to_name(const Dtype& t); std::string type_to_name(const array& a); -// Compute the thread block dimensions which fit the given -// input dimensions. 
-// - The thread block dimensions will be powers of two -// - The thread block size will be less than 2^pow2 +// Compute the grid and block dimensions, check backend/common/utils.h for docs. MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); - -// Computes a 2D grid where each element is < UINT_MAX -// Assumes: -// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 -// - shape and strides correspond to a contiguous (no holes) but -// possibly broadcasted array MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); - -// Same as above but we do an implicit division with divisor. -// Basically, equivalent to factorizing -// Prod(s \forall s in shape if strides[s] > 0) / divisor. MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); @@ -58,8 +45,6 @@ inline void debug_set_primitive_buffer_label( #endif } -std::string get_primitive_string(Primitive* primitive); - template constexpr bool is_numeric_except_char = std::is_arithmetic_v && !std::is_same_v && !std::is_same_v && From 6ef2f67e7f97a1f3c6eb5f0ef26787b966a1200d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 30 May 2025 12:12:10 -0700 Subject: [PATCH 061/156] 5bit quants (#2226) * 5bit quants * 5bit quants --- mlx/backend/cpu/quantized.cpp | 53 +++-- mlx/backend/metal/kernels/quantized.h | 246 +++++++++++++++++----- mlx/backend/metal/kernels/quantized.metal | 1 + mlx/backend/metal/quantized.cpp | 4 +- mlx/fast.cpp | 6 +- python/tests/test_quantized.py | 8 +- python/tests/test_vmap.py | 2 + 7 files changed, 248 insertions(+), 72 deletions(-) diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index f0ac9d57f..ee8e56cc0 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -13,9 +13,18 @@ namespace mlx::core { namespace { +inline constexpr short get_pack_factor(int bits, int wsize = 8) { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) { + auto power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + template void extract_bits(const uint8_t* w_in, T* w_out) { - assert(bits == 3 || bits == 6); + static_assert(bits == 3 || bits == 5 || bits == 6); if (bits == 3) { w_out[0] = static_cast(w_in[0] & 0x7); w_out[1] = static_cast((w_in[0] & 0x38) >> 3); @@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) { w_out[5] = static_cast(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1)); w_out[6] = static_cast((w_in[2] & 0x1c) >> 2); w_out[7] = static_cast((w_in[2] & 0xe0) >> 5); + } else if (bits == 5) { + w_out[0] = static_cast(w_in[0] & 0x1f); + w_out[1] = static_cast(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3)); + w_out[2] = static_cast((w_in[1] & 0x7c) >> 2); + w_out[3] = static_cast(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1)); + w_out[4] = static_cast(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4)); + w_out[5] = static_cast((w_in[3] & 0x3e) >> 1); + w_out[6] = static_cast(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2)); + w_out[7] = static_cast((w_in[4] & 0xf8) >> 3); + } else if (bits == 6) { w_out[0] = static_cast(w_in[0] & 0x3f); w_out[1] = @@ -46,8 +65,8 @@ void _qmm( int N, int K) { constexpr int bitmask = (1 << bits) - 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 
3 : 1; + constexpr int pack_factor = get_pack_factor(bits, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(bits); constexpr int packs_in_group = group_size / pack_factor; for (int m = 0; m < M; m++) { @@ -65,7 +84,7 @@ void _qmm( T scale = *scales_local++; T bias = *biases_local++; for (int ng = 0; ng < packs_in_group; ng++) { - if (bits == 3 || bits == 6) { + if constexpr (bits == 3 || bits == 5 || bits == 6) { T wl[pack_factor]; extract_bits(w_local, wl); #pragma clang loop unroll(full) @@ -104,8 +123,9 @@ void _qmm_t( int N, int K) { constexpr int bitmask = (1 << bits) - 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + + constexpr int pack_factor = get_pack_factor(bits, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(bits); constexpr int packs_in_group = group_size / pack_factor; for (int m = 0; m < M; m++) { @@ -121,7 +141,7 @@ void _qmm_t( T bias = *biases_local++; for (int kw = 0; kw < packs_in_group; kw++) { - if (bits == 3 || bits == 6) { + if constexpr (bits == 3 || bits == 5 || bits == 6) { T wl[pack_factor]; extract_bits(w_local, wl); #pragma clang loop unroll(full) @@ -304,6 +324,10 @@ void _qmm_dispatch_typed( _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); break; + case 5: + _qmm_dispatch_group( + result, x, w, scales, biases, M, N, K, group_size, transposed_w); + break; case 6: _qmm_dispatch_group( result, x, w, scales, biases, M, N, K, group_size, transposed_w); @@ -613,9 +637,8 @@ void quantize( float eps = 1e-7; bool power_of_2_bits = is_power_of_2(bits); - int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32 - int bytes_per_pack = power_of_2_bits ? 1 : 3; + int el_per_int = get_pack_factor(bits, 32); + int bytes_per_pack = get_bytes_per_pack(bits); int int_per_group = group_size * bytes_per_pack / el_per_int; size_t n_groups = w_size / group_size; @@ -640,15 +663,21 @@ void quantize( } size_t out_idx = i * int_per_group; for (int j = 0; j < int_per_group / bytes_per_pack; ++j) { - uint32_t out_el = 0; + uint64_t out_el = 0; for (int k = 0; k < el_per_int; ++k) { float w_el = w[w_idx + j * el_per_int + k]; w_el = std::rint((w_el - bias) / scale); w_el = std::min(std::max(w_el, 0.0f), n_bins); - out_el |= static_cast(w_el) << (k * bits); + out_el |= static_cast(w_el) << (k * bits); } if (power_of_2_bits) { out[out_idx + j] = out_el; + } else if (bits == 5) { + out[out_idx + bytes_per_pack * j] = out_el & 0xff; + out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; + out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16; + out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24; + out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32; } else { out[out_idx + bytes_per_pack * j] = out_el & 0xff; out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8; diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index ba4fb2426..fea6f1460 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -14,11 +14,23 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 
4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -153,8 +196,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -199,6 +243,26 @@ inline U qdot( } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -234,8 +298,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || 
bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -280,6 +345,26 @@ inline U qdot_safe( } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -310,8 +395,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + } - } else if (bits == 6) { + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -375,8 +484,9 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { @@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = 
(((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; @@ -452,11 +577,12 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; @@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -2120,11 +2247,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel::BlockMMA< T, @@ -2305,13 +2431,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; static_assert( group_size % simd_size == 0, @@ -2354,8 +2480,8 @@ template biases[gindex] = static_cast(bias); } - // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { @@ -2363,27 +2489,35 @@ template if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; @@ -2399,12 +2533,11 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2421,7 +2554,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2431,7 +2573,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 11cd8421b..de83cb657 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -136,6 +136,7 @@ instantiate_quantized_groups(2) \ instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ + instantiate_quantized_groups(5) \ instantiate_quantized_groups(6) \ instantiate_quantized_groups(8) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 11a2355cc..b6dc8db30 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -976,7 +976,9 @@ void fast::AffineQuantize::eval_gpu( // Treat uint32 as uint8 in kernel constexpr int uint8_per_uint32 = 4; constexpr int simd_size = 32; - int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_; + int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8 + : bits_ == 6 ? 4 + : 8 / bits_; int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size; size_t nthreads = dequantize_ ? out.size() / packs_per_int : w.size() / per_thread; diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 77210f713..c77b97de5 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -839,14 +839,14 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { if (group_size != 32 && group_size != 64 && group_size != 128) { std::ostringstream msg; msg << "[quantize] The requested group size " << group_size - << " is not supported. The supported group sizes are 64 and 128."; + << " is not supported. The supported group sizes are 32, 64, and 128."; throw std::invalid_argument(msg.str()); } - if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) { + if (bits < 2 || bits > 8 || bits == 7) { std::ostringstream msg; msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 3, 4, 6 and 8."; + << " is not supported. 
The supported bits are 2, 3, 4, 5, 6 and 8."; throw std::invalid_argument(msg.str()); } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 60ab421c6..3c4f03e4d 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 6, 4, 8]: + for b in [2, 3, 5, 6, 4, 8]: with self.subTest(gs=gs, b=b): w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) w_hat = mx.dequantize(w_q, scales, biases, gs, b) @@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase): # test quantize/dequantize 0s a = mx.zeros((256, 512)) for gs in [32, 64, 128]: - for b in [2, 3, 4, 6, 8]: + for b in [2, 3, 4, 5, 6, 8]: w_q, scales, biases = mx.quantize(a, gs, b) a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) @@ -146,7 +146,7 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 3, 4, 6, 8], # bits + [2, 3, 4, 5, 6, 8], # bits [256, 512, 67], # M [64, 128], # N [0, 1, 3, 8], # B @@ -173,7 +173,7 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 3, 4, 6, 8], # bits + [2, 3, 4, 5, 6, 8], # bits [32, 128, 256], # M [128, 256, 67], # N [0, 1, 3, 8], # B diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index ddfceb0a1..52f1a49ad 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(fy.shape, (4, 5, 6, 7)) def test_leaks(self): + gc.collect() mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() @@ -653,6 +654,7 @@ class TestVmap(mlx_tests.MLXTestCase): outer() gc.collect() + mx.synchronize() if mx.metal.is_available(): mem_post = mx.get_active_memory() else: From db5a7c6192af90eed81ff7eac8213e7fe7b7a0c8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 31 May 2025 04:12:54 +0900 Subject: [PATCH 062/156] Add memory cache to CUDA backend (#2221) * Move BufferCache out of allocator * Add memory cache to cuda backend allocator * Simplify BufferCache assuming buf can not be null --- mlx/backend/common/buffer_cache.h | 157 ++++++++++++++++++++++++++++++ mlx/backend/cuda/allocator.cpp | 115 +++++++++++++++------- mlx/backend/cuda/allocator.h | 8 ++ mlx/backend/metal/allocator.cpp | 142 ++------------------------- mlx/backend/metal/allocator.h | 40 +------- 5 files changed, 259 insertions(+), 203 deletions(-) create mode 100644 mlx/backend/common/buffer_cache.h diff --git a/mlx/backend/common/buffer_cache.h b/mlx/backend/common/buffer_cache.h new file mode 100644 index 000000000..92b20f222 --- /dev/null +++ b/mlx/backend/common/buffer_cache.h @@ -0,0 +1,157 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +template +class BufferCache { + public: + BufferCache( + size_t page_size, + std::function get_size, + std::function free) + : page_size_(page_size), + get_size_(std::move(get_size)), + free_(std::move(free)) {} + + ~BufferCache() { + clear(); + } + + BufferCache(const BufferCache&) = delete; + BufferCache& operator=(const BufferCache&) = delete; + + T* reuse_from_cache(size_t size) { + // Find the closest buffer in pool. 
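+    // The pool is keyed by buffer size: reuse the smallest cached buffer that
+    // fits the request, but give up when even that buffer is at least
+    // min(2 * size, size + 2 * page_size_) bytes, so small requests do not
+    // pin much larger allocations.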
+ auto it = buffer_pool_.lower_bound(size); + if (it == buffer_pool_.end() || + it->first >= std::min(2 * size, size + 2 * page_size_)) { + return nullptr; + } + + // Collect from the cache. + T* buf = it->second->buf; + pool_size_ -= it->first; + + // Remove from record. + remove_from_list(it->second); + buffer_pool_.erase(it); + return buf; + } + + void recycle_to_cache(T* buf) { + assert(buf); + // Add to cache. + BufferHolder* bh = new BufferHolder(buf); + add_at_head(bh); + size_t size = get_size_(buf); + pool_size_ += size; + buffer_pool_.emplace(size, bh); + } + + int release_cached_buffers(size_t min_bytes_to_free) { + if (min_bytes_to_free >= 0.9 * pool_size_) { + return clear(); + } else { + int n_release = 0; + size_t total_bytes_freed = 0; + + while (tail_ && (total_bytes_freed < min_bytes_to_free)) { + // Release buffer. + size_t size = get_size_(tail_->buf); + total_bytes_freed += size; + free_(tail_->buf); + n_release++; + + // Remove from record. + auto its = buffer_pool_.equal_range(size); + auto it = std::find_if(its.first, its.second, [this](const auto& el) { + return el.second == tail_; + }); + assert(it != buffer_pool_.end()); + buffer_pool_.erase(it); + remove_from_list(tail_); + } + + pool_size_ -= total_bytes_freed; + return n_release; + } + } + + int clear() { + int n_release = 0; + for (auto& [size, holder] : buffer_pool_) { + free_(holder->buf); + n_release++; + delete holder; + } + buffer_pool_.clear(); + pool_size_ = 0; + head_ = nullptr; + tail_ = nullptr; + return n_release; + } + + size_t cache_size() const { + return pool_size_; + } + + size_t page_size() const { + return page_size_; + } + + private: + struct BufferHolder { + public: + explicit BufferHolder(T* buf_) : buf(buf_) {} + + BufferHolder* prev{nullptr}; + BufferHolder* next{nullptr}; + T* buf; + }; + + void add_at_head(BufferHolder* to_add) { + if (!head_) { + head_ = to_add; + tail_ = to_add; + } else { + head_->prev = to_add; + to_add->next = head_; + head_ = to_add; + } + } + + void remove_from_list(BufferHolder* to_remove) { + if (to_remove->prev && to_remove->next) { // if middle + to_remove->prev->next = to_remove->next; + to_remove->next->prev = to_remove->prev; + } else if (to_remove->prev && to_remove == tail_) { // if tail + tail_ = to_remove->prev; + tail_->next = nullptr; + } else if (to_remove == head_ && to_remove->next) { // if head + head_ = to_remove->next; + head_->prev = nullptr; + } else if (to_remove == head_ && to_remove == tail_) { // if only element + head_ = nullptr; + tail_ = nullptr; + } + + delete to_remove; + } + + std::multimap buffer_pool_; + BufferHolder* head_{nullptr}; + BufferHolder* tail_{nullptr}; + size_t pool_size_{0}; + + const size_t page_size_; + std::function get_size_; + std::function free_; +}; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 203534e21..86af3a774 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -6,6 +6,7 @@ #include #include +#include #include @@ -13,24 +14,47 @@ namespace mlx::core { namespace cu { -CudaAllocator::CudaAllocator() { +CudaAllocator::CudaAllocator() + : buffer_cache_( + getpagesize(), + [](CudaBuffer* buf) { return buf->size; }, + [this](CudaBuffer* buf) { cuda_free(buf); }) { // TODO: Set memory limit for multi-device. size_t free, total; CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); memory_limit_ = total * 0.8; + max_pool_size_ = memory_limit_; } Buffer CudaAllocator::malloc(size_t size) { - // TODO: Check memory limit. 
- auto* buf = new CudaBuffer{nullptr, size}; - cudaError_t err = cudaMallocManaged(&buf->data, size); - if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { - throw std::runtime_error( - fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + // Find available buffer from cache. + std::unique_lock lock(mutex_); + CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); + if (!buf) { + // If we have a lot of memory pressure or are over the maximum cache size, + // try to reclaim memory from the cache. + size_t mem_required = get_active_memory() + get_cache_memory() + size; + if (mem_required >= memory_limit_) { + buffer_cache_.release_cached_buffers(mem_required - memory_limit_); + } + + lock.unlock(); + buf = new CudaBuffer{nullptr, size}; + cudaError_t err = cudaMallocManaged(&buf->data, size); + if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { + throw std::runtime_error(fmt::format( + "cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + } + lock.lock(); } - std::lock_guard lock(mutex_); active_memory_ += size; peak_memory_ = std::max(active_memory_, peak_memory_); + + // Maintain the cache below the requested limit. + if (get_cache_memory() > max_pool_size_) { + buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); + } + return Buffer{buf}; } @@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) { return; } - // If free() is called from a unregistered thread, reschedule the call to - // worker. - { - std::lock_guard lock(worker_mutex_); - if (allowed_threads_.count(std::this_thread::get_id()) == 0) { - if (!worker_) { - worker_.reset(new Worker); - } - worker_->add_task([buffer]() { allocator().free(buffer); }); - worker_->end_batch(); - worker_->commit(); - return; - } + std::unique_lock lock(mutex_); + active_memory_ -= buf->size; + if (get_cache_memory() < max_pool_size_) { + buffer_cache_.recycle_to_cache(buf); + } else { + lock.unlock(); + cuda_free(buf); } - - size_t size = buf->size; - cudaFree(buf->data); - delete buf; - std::lock_guard lock(mutex_); - active_memory_ -= size; } size_t CudaAllocator::size(Buffer buffer) const { @@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) { return limit; } +size_t CudaAllocator::get_cache_memory() const { + return buffer_cache_.cache_size(); +} + +size_t CudaAllocator::set_cache_limit(size_t limit) { + std::lock_guard lk(mutex_); + std::swap(limit, max_pool_size_); + return limit; +} + +void CudaAllocator::clear_cache() { + std::lock_guard lk(mutex_); + buffer_cache_.clear(); +} + +void CudaAllocator::cuda_free(CudaBuffer* buf) { + // If cuda_free() is called from a unregistered thread, reschedule the call to + // worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([this, buf]() { this->cuda_free(buf); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + cudaFree(buf->data); + delete buf; +} + CudaAllocator& allocator() { // By creating the |allocator_| on heap, the destructor of CudaAllocator // will not be called on exit and buffers in the cache will be leaked. This @@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) { size_t get_memory_limit() { return cu::allocator().get_memory_limit(); } - -// TODO: Implement buffer cache. 
size_t get_cache_memory() { - return 0; + return cu::allocator().get_cache_memory(); } -size_t set_cache_limit(size_t) { - return 0; +size_t set_cache_limit(size_t limit) { + return cu::allocator().set_cache_limit(limit); } +void clear_cache() { + cu::allocator().clear_cache(); +} + +// Not supported in CUDA. size_t set_wired_limit(size_t) { return 0; } -void clear_cache() {} } // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h index 6c418ee7e..fe3755121 100644 --- a/mlx/backend/cuda/allocator.h +++ b/mlx/backend/cuda/allocator.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" #include #include @@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator { void reset_peak_memory(); size_t get_memory_limit(); size_t set_memory_limit(size_t limit); + size_t get_cache_memory() const; + size_t set_cache_limit(size_t limit); + void clear_cache(); private: CudaAllocator(); friend CudaAllocator& allocator(); + void cuda_free(CudaBuffer* buf); + std::mutex worker_mutex_; std::unique_ptr worker_; std::set allowed_threads_; std::mutex mutex_; size_t memory_limit_; + size_t max_pool_size_; + BufferCache buffer_cache_; size_t active_memory_{0}; size_t peak_memory_{0}; }; diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 5d8bd90d5..dd6189732 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -30,141 +30,18 @@ void* Buffer::raw_ptr() { namespace metal { -namespace { - -BufferCache::BufferCache(ResidencySet& residency_set) - : head_(nullptr), - tail_(nullptr), - pool_size_(0), - residency_set_(residency_set) {} - -BufferCache::~BufferCache() { - auto pool = metal::new_scoped_memory_pool(); - clear(); -} - -int BufferCache::clear() { - int n_release = 0; - for (auto& [size, holder] : buffer_pool_) { - if (holder->buf) { - if (!holder->buf->heap()) { - residency_set_.erase(holder->buf); - } - holder->buf->release(); - n_release++; - } - delete holder; - } - buffer_pool_.clear(); - pool_size_ = 0; - head_ = nullptr; - tail_ = nullptr; - return n_release; -} - -MTL::Buffer* BufferCache::reuse_from_cache(size_t size) { - // Find the closest buffer in pool - MTL::Buffer* pbuf = nullptr; - - auto it = buffer_pool_.lower_bound(size); - - // Make sure we use most of the available memory - while (!pbuf && it != buffer_pool_.end() && - it->first < std::min(2 * size, size + 2 * vm_page_size)) { - // Collect from the cache - pbuf = it->second->buf; - - // Remove from cache - remove_from_list(it->second); - delete it->second; - it = buffer_pool_.erase(it); - } - - if (pbuf) { - pool_size_ -= pbuf->length(); - } - - return pbuf; -} - -void BufferCache::recycle_to_cache(MTL::Buffer* buf) { - // Add to cache - if (buf) { - BufferHolder* bh = new BufferHolder(buf); - add_at_head(bh); - pool_size_ += buf->length(); - buffer_pool_.insert({buf->length(), bh}); - } -} - -int BufferCache::release_cached_buffers(size_t min_bytes_to_free) { - if (min_bytes_to_free >= 0.9 * pool_size_) { - return clear(); - } else { - int n_release = 0; - size_t total_bytes_freed = 0; - - while (tail_ && (total_bytes_freed < min_bytes_to_free)) { - if (tail_->buf) { - total_bytes_freed += tail_->buf->length(); - if (!tail_->buf->heap()) { - residency_set_.erase(tail_->buf); - } - tail_->buf->release(); - tail_->buf = nullptr; - n_release++; - } - remove_from_list(tail_); - } - pool_size_ -= total_bytes_freed; - return n_release; - } -} - -void 
BufferCache::add_at_head(BufferCache::BufferHolder* to_add) { - if (!to_add) - return; - - if (!head_) { - head_ = to_add; - tail_ = to_add; - } else { - head_->prev = to_add; - to_add->next = head_; - head_ = to_add; - } -} - -void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { - if (!to_remove) { - return; - } - - // If in the middle - if (to_remove->prev && to_remove->next) { - to_remove->prev->next = to_remove->next; - to_remove->next->prev = to_remove->prev; - } else if (to_remove->prev && to_remove == tail_) { // If tail - tail_ = to_remove->prev; - tail_->next = nullptr; - } else if (to_remove == head_ && to_remove->next) { // If head - head_ = to_remove->next; - head_->prev = nullptr; - } else if (to_remove == head_ && to_remove == tail_) { // If only element - head_ = nullptr; - tail_ = nullptr; - } - - to_remove->prev = nullptr; - to_remove->next = nullptr; -} - -} // namespace - MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), residency_set_(device_), - buffer_cache_(residency_set_) { + buffer_cache_( + vm_page_size, + [](MTL::Buffer* buf) { return buf->length(); }, + [this](MTL::Buffer* buf) { + if (!buf->heap()) { + residency_set_.erase(buf); + } + buf->release(); + }) { auto pool = metal::new_scoped_memory_pool(); auto memsize = std::get(device_info().at("memory_size")); auto max_rec_size = @@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() { if (heap_) { heap_->release(); } + buffer_cache_.clear(); } size_t MetalAllocator::set_cache_limit(size_t limit) { diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 227b09e91..691317916 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -7,6 +7,7 @@ #include #include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/resident.h" @@ -14,43 +15,6 @@ namespace mlx::core::metal { using allocator::Buffer; -namespace { - -class BufferCache { - public: - BufferCache(ResidencySet& residency_set); - ~BufferCache(); - - MTL::Buffer* reuse_from_cache(size_t size); - void recycle_to_cache(MTL::Buffer* buf); - int release_cached_buffers(size_t min_bytes_to_free); - size_t cache_size() { - return pool_size_; - } - int clear(); - - private: - struct BufferHolder { - public: - BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {} - - BufferHolder* prev; - BufferHolder* next; - MTL::Buffer* buf; - }; - - void add_at_head(BufferHolder* to_add); - void remove_from_list(BufferHolder* to_remove); - - std::multimap buffer_pool_; - BufferHolder* head_; - BufferHolder* tail_; - size_t pool_size_; - ResidencySet& residency_set_; -}; - -} // namespace - class MetalAllocator : public allocator::Allocator { /** Allocator for Metal GPUs. 
*/ public: @@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator { friend MetalAllocator& allocator(); // Caching allocator - BufferCache buffer_cache_; + BufferCache buffer_cache_; ResidencySet residency_set_; From 95b7551d65c24d65e0a3de2a656f11577499fd9f Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 3 Jun 2025 05:23:34 +0900 Subject: [PATCH 063/156] Do not check event.is_signaled() in eval_impl (#2230) --- mlx/transforms.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 2d9942eda..3a02f39cb 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -208,9 +208,7 @@ array eval_impl(std::vector outputs, bool async) { // output arrays stream fences[it->second].wait(stream, in); } else if (in.event().valid()) { - if (in.event().is_signaled()) { - in.detach_event(); - } else if (in.event().stream() != stream) { + if (in.event().stream() != stream) { // Use event to wait across async eval in.event().wait(stream); } From 1b021f6984d24a59f39a71f0fe2540d825cc181f Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 3 Jun 2025 05:26:37 +0900 Subject: [PATCH 064/156] Fast primitives decide when to use the fallback (#2216) --- mlx/backend/cuda/primitives.cu | 23 ++++++++-- mlx/backend/metal/normalization.cpp | 8 ++++ mlx/backend/metal/rope.cpp | 4 ++ .../metal/scaled_dot_product_attention.cpp | 42 ++++++++++++++++++- mlx/backend/no_gpu/primitives.cpp | 23 ++++++++-- mlx/fast.cpp | 38 ++++------------- mlx/fast_primitives.h | 22 ++++++---- 7 files changed, 115 insertions(+), 45 deletions(-) diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index d105a242b..11de02c8e 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { }); } +bool fast::ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + return true; +} + #define NO_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ throw std::runtime_error(#func " has no CUDA implementation."); \ } +#define NO_GPU_USE_FALLBACK(func) \ + bool func::use_fallback(Stream s) { \ + return true; \ + } \ + NO_GPU_MULTI(func) + #define NO_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ throw std::runtime_error(#func " has no CUDA implementation."); \ @@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { -NO_GPU_MULTI(LayerNorm) +NO_GPU_USE_FALLBACK(LayerNorm) NO_GPU_MULTI(LayerNormVJP) -NO_GPU_MULTI(RMSNorm) +NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) -NO_GPU_MULTI(RoPE) +NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 21142183e..c0901ccec 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -10,6 +10,10 @@ namespace mlx::core::fast { +bool RMSNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + void RMSNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu( } } +bool LayerNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + void LayerNorm::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/metal/rope.cpp 
b/mlx/backend/metal/rope.cpp index d8201afe6..e141df630 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -7,6 +7,10 @@ namespace mlx::core::fast { constexpr int n_per_thread = 4; +bool RoPE::use_fallback(Stream s) { + return s.device == Device::cpu; +} + void RoPE::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 3c7b7ff19..aad1a0018 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -4,10 +4,10 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" - #include "mlx/backend/metal/kernels/steel/attn/params.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" +#include "mlx/transforms_impl.h" #include "mlx/utils.h" namespace mlx::core::fast { @@ -339,6 +339,46 @@ void sdpa_vector_2pass( } // namespace +bool ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + if (detail::in_grad_tracing()) { + return true; + } + if (s.device == Device::cpu) { + return true; + } + + const int value_head_dim = v.shape(-1); + const int query_head_dim = q.shape(-1); + const int query_sequence_length = q.shape(2); + const int key_sequence_length = k.shape(2); + + const bool sdpa_vector_supported_head_dim = + query_head_dim == value_head_dim && + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || + query_head_dim == 256); + const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && + (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + + const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || + (query_sequence_length <= key_sequence_length && do_causal); + + const bool supports_sdpa_full = + sdpa_full_supported_mask && sdpa_full_supported_head_dim; + + const bool supports_sdpa_vector = (query_sequence_length <= 8) && + (query_sequence_length <= key_sequence_length) && + sdpa_vector_supported_head_dim; + + return !(supports_sdpa_full || supports_sdpa_vector); +} + void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, array& out) { diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 676a6e550..409aa2c89 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -10,6 +10,12 @@ throw std::runtime_error(#func " has no GPU implementation."); \ } +#define NO_GPU_USE_FALLBACK(func) \ + bool func::use_fallback(Stream s) { \ + return true; \ + } \ + NO_GPU_MULTI(func) + #define NO_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ throw std::runtime_error(#func " has no GPU implementation."); \ @@ -17,6 +23,17 @@ namespace mlx::core { +bool fast::ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + return true; +} + NO_GPU(Abs) NO_GPU(Add) NO_GPU(AddMM) @@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig) NO_GPU(View) namespace fast { -NO_GPU_MULTI(LayerNorm) +NO_GPU_USE_FALLBACK(LayerNorm) NO_GPU_MULTI(LayerNormVJP) -NO_GPU_MULTI(RMSNorm) +NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) -NO_GPU_MULTI(RoPE) +NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) 
NO_GPU_MULTI(CustomKernel) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index c77b97de5..eab22f14d 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -9,7 +9,6 @@ #include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/transforms.h" -#include "mlx/transforms_impl.h" namespace mlx::core::fast { @@ -112,7 +111,8 @@ array rms_norm( auto passed_weight = (has_weight) ? astype(*weight, out_type, s) : array(1, out_type); - if (s.device == Device::gpu) { + + if (!RMSNorm::use_fallback(s)) { return array( x.shape(), out_type, @@ -256,7 +256,7 @@ array layer_norm( auto passed_bias = (has_bias) ? astype(*bias, out_type, s) : array(0, out_type); - if (s.device == Device::gpu) { + if (!LayerNorm::use_fallback(s)) { return array( x.shape(), out_type, @@ -470,7 +470,7 @@ array rope( } }; auto stream = to_stream(s); - if (stream.device == Device::gpu) { + if (!RoPE::use_fallback(stream)) { return array( x.shape(), x.dtype(), @@ -727,31 +727,6 @@ array scaled_dot_product_attention( }; auto stream = to_stream(s); - const int value_head_dim = v.shape(-1); - const int query_head_dim = q.shape(-1); - const int query_sequence_length = q.shape(2); - const int key_sequence_length = k.shape(2); - - const bool sdpa_vector_supported_head_dim = - query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || - query_head_dim == 256); - const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); - - const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || - (query_sequence_length <= key_sequence_length && do_causal); - - const bool supports_sdpa_full = sdpa_full_supported_mask && - sdpa_full_supported_head_dim && stream.device == Device::gpu; - - const bool supports_sdpa_vector = (query_sequence_length <= 8) && - (query_sequence_length <= key_sequence_length) && - sdpa_vector_supported_head_dim && stream.device == Device::gpu; - - const bool implementation_supports_use_case = - supports_sdpa_full || supports_sdpa_vector; - std::vector inputs = {q, k, v}; if (has_arr_mask) { // Check type @@ -770,7 +745,8 @@ array scaled_dot_product_attention( mask_shape.back() = keys.shape(-2); inputs.push_back(broadcast_to(mask_arr, mask_shape, stream)); } - if (!detail::in_grad_tracing() && implementation_supports_use_case) { + if (!ScaledDotProductAttention::use_fallback( + q, k, v, has_mask, has_arr_mask, do_causal, stream)) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; return array( std::move(out_shape), @@ -779,7 +755,7 @@ array scaled_dot_product_attention( stream, fallback, scale, do_causal), std::move(inputs)); } - return fallback(inputs)[0]; + return fallback(std::move(inputs))[0]; } bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4d9e505ee..51050ea50 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -43,6 +43,8 @@ class RMSNorm : public Custom { float eps) : Custom(stream, fallback), eps_(eps) {} + static bool use_fallback(Stream stream); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -65,7 +67,6 @@ class RMSNorm : public Custom { } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -91,7 +92,6 @@ class RMSNormVJP : public Custom { } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -103,6 +103,8 @@ class LayerNorm : public Custom { 
float eps) : Custom(stream, fallback), eps_(eps) {} + static bool use_fallback(Stream s); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -124,7 +126,6 @@ class LayerNorm : public Custom { } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -150,7 +151,6 @@ class LayerNormVJP : public Custom { } private: - std::function(std::vector)> fallback_; float eps_; }; @@ -171,6 +171,8 @@ class RoPE : public Custom { scale_(scale), forward_(forward) {} + static bool use_fallback(Stream s); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -193,7 +195,6 @@ class RoPE : public Custom { } private: - std::function(std::vector)> fallback_; int dims_; bool traditional_; float base_; @@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom { const bool do_causal) : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {} + static bool use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s); + void eval_cpu(const std::vector& inputs, std::vector& outputs) override { throw std::runtime_error("NYI"); @@ -230,7 +240,6 @@ class ScaledDotProductAttention : public Custom { } private: - std::function(std::vector)> fallback_; float scale_; bool do_causal_; }; @@ -263,7 +272,6 @@ class AffineQuantize : public Custom { } private: - std::function(std::vector)> fallback_; int group_size_; int bits_; bool dequantize_; From cbad6c3093d7207f878702b6eeda29816beafd87 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 2 Jun 2025 15:58:33 -0700 Subject: [PATCH 065/156] version (#2237) --- mlx/version.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/version.h b/mlx/version.h index c573c45c9..45ccdf3a7 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -3,8 +3,8 @@ #pragma once #define MLX_VERSION_MAJOR 0 -#define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 2 +#define MLX_VERSION_MINOR 26 +#define MLX_VERSION_PATCH 0 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From 0408ba0a768a3493fc3e12262162eca2e55346f0 Mon Sep 17 00:00:00 2001 From: Suryash Malviya <71389351+thesuryash@users.noreply.github.com> Date: Mon, 2 Jun 2025 18:58:46 -0400 Subject: [PATCH 066/156] =?UTF-8?q?Optimizing=20Complex=20Matrix=20Multipl?= =?UTF-8?q?ication=20using=20Karatsuba=E2=80=99s=20Algorithm=20=20(#2220)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implementing Complex Matmul using Karatsuba Algorithm * Implemented Karatsuba's Algorithm for complex matmul and pre-commit them * fix --------- Co-authored-by: Awni Hannun --- mlx/ops.cpp | 25 +++++++++++++++++-------- python/tests/test_blas.py | 14 +++++++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a72c2bc85..9602f667a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2862,21 +2862,30 @@ array matmul( << " second input with shape " << b.shape() << "."; throw std::invalid_argument(msg.str()); } - // Type promotion - auto out_type = promote_types(a.dtype(), b.dtype()); - // Complex matmul in terms of real matmuls - if (out_type == complex64) { + + // complex matmul using Karatsuba's Algorithm + if (a.dtype() == complex64 || b.dtype() == complex64) { + // Extract real and imaginary parts auto a_real = real(a, s); - auto b_real = real(b, s); auto a_imag = imag(a, s); + 
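+    // With m1 = a_real*b_real, m2 = a_imag*b_imag and
+    // m3 = (a_real + a_imag)*(b_real + b_imag), the complex product is
+    // (m1 - m2) + i*(m3 - m1 - m2), so three real matmuls suffice instead of
+    // the four used by the naive expansion.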
auto b_real = real(b, s); auto b_imag = imag(b, s); - auto c_real = - subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s); - auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s); + + // Compute real and imaginary components of the result + auto m1 = matmul(a_real, b_real, s); + auto m2 = matmul(a_imag, b_imag, s); + auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s); + + auto c_real = subtract(m1, m2, s); + auto c_imag = subtract(m3, add(m1, m2, s), s); + return add( c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); } + // Type promotion + auto out_type = promote_types(a.dtype(), b.dtype()); + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[matmul] Only real floating point types are supported but " diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index df459eadc..8c7a97ba8 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1210,13 +1210,6 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertTrue(np.allclose(c, c_np)) # Test addmm - M = 16 - K = 50 - N = 32 - - def rand(shape): - return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) - a = rand((M, K)) b = rand((K, N)) c = rand((M, N)) @@ -1224,6 +1217,13 @@ class TestBlas(mlx_tests.MLXTestCase): out_np = 2.0 * np.matmul(a, b) + 2.0 * c self.assertTrue(np.allclose(out, out_np)) + # complex with real + a = rand((M, K)).real + b = rand((K, N)) + c = mx.matmul(a, b) + c_np = np.matmul(a, b) + self.assertTrue(np.allclose(out, out_np)) + if __name__ == "__main__": unittest.main() From 5685ceb3c79618fcda983b2f3657bf9528c64220 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 4 Jun 2025 08:48:40 +0900 Subject: [PATCH 067/156] Avoid invoking allocator::malloc when creating CUDA event (#2232) --- mlx/backend/cuda/allocator.cpp | 47 ++++++++++++++++++---------------- mlx/backend/cuda/allocator.h | 5 ++-- mlx/backend/cuda/event.cu | 9 ++++--- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 86af3a774..00f78fd4f 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator() : buffer_cache_( getpagesize(), [](CudaBuffer* buf) { return buf->size; }, - [this](CudaBuffer* buf) { cuda_free(buf); }) { + [this](CudaBuffer* buf) { + cuda_free(buf->data); + delete buf; + }) { // TODO: Set memory limit for multi-device. size_t free, total; CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); @@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) { buffer_cache_.recycle_to_cache(buf); } else { lock.unlock(); - cuda_free(buf); + cuda_free(buf->data); + delete buf; } } @@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() { allowed_threads_.insert(std::this_thread::get_id()); } +void CudaAllocator::cuda_free(void* buf) { + // If cuda_free() is called from a unregistered thread, reschedule the call to + // worker. 
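+  // Only threads registered via register_this_thread() free directly; any
+  // other caller enqueues the cudaFree on the worker so that it runs on an
+  // allowed thread.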
+ { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([this, buf]() { this->cuda_free(buf); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + cudaFree(buf); +} + size_t CudaAllocator::get_active_memory() const { return active_memory_; } @@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() { buffer_cache_.clear(); } -void CudaAllocator::cuda_free(CudaBuffer* buf) { - // If cuda_free() is called from a unregistered thread, reschedule the call to - // worker. - { - std::lock_guard lock(worker_mutex_); - if (allowed_threads_.count(std::this_thread::get_id()) == 0) { - if (!worker_) { - worker_.reset(new Worker); - } - worker_->add_task([this, buf]() { this->cuda_free(buf); }); - worker_->end_batch(); - worker_->commit(); - return; - } - } - - cudaFree(buf->data); - delete buf; -} - CudaAllocator& allocator() { // By creating the |allocator_| on heap, the destructor of CudaAllocator // will not be called on exit and buffers in the cache will be leaked. This diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h index fe3755121..e268c6334 100644 --- a/mlx/backend/cuda/allocator.h +++ b/mlx/backend/cuda/allocator.h @@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator { // buffers there would result in dead lock. void register_this_thread(); + // Call cudaFree in the safe thread. + void cuda_free(void* buf); + size_t get_active_memory() const; size_t get_peak_memory() const; void reset_peak_memory(); @@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator { CudaAllocator(); friend CudaAllocator& allocator(); - void cuda_free(CudaBuffer* buf); - std::mutex worker_mutex_; std::unique_ptr worker_; std::set allowed_threads_; diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index a487f45b4..f462720a9 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. +#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/utils.h" @@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { SharedEvent::SharedEvent() { // Allocate cuda::atomic on managed memory. 
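+  // The atomic is allocated with cudaMallocManaged directly instead of going
+  // through allocator::malloc; the shared_ptr deleter destroys it and returns
+  // the memory via allocator().cuda_free, which defers to the worker when
+  // called from an unregistered thread.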
- allocator::Buffer buffer = allocator::malloc(sizeof(Atomic)); - Atomic* ac = static_cast(buffer.raw_ptr()); + Atomic* ac; + CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic))); new (ac) Atomic(0); - ac_ = std::shared_ptr(ac, [buffer](Atomic* ptr) { + ac_ = std::shared_ptr(ac, [](Atomic* ptr) { ptr->~Atomic(); - allocator::free(buffer); + allocator().cuda_free(ptr); }); } From 0bb89e9e5fec1ff2154112cfbdf46ba0a22aa0f6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 4 Jun 2025 08:48:50 +0900 Subject: [PATCH 068/156] Share more common code in Compiled (#2240) * Share more common code in Compiled * Remove build_lib_name --- mlx/backend/common/compiled.cpp | 130 ++++++++++++++++++-------------- mlx/backend/common/compiled.h | 24 +++--- mlx/backend/cpu/compiled.cpp | 104 ++++++++----------------- mlx/backend/metal/compiled.cpp | 114 +++++----------------------- mlx/compile.cpp | 53 ++++++++++++- mlx/primitives.h | 1 + 6 files changed, 193 insertions(+), 233 deletions(-) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index f7b5598ab..98c48cca9 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -1,8 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/common/compiled.h" -#include "mlx/graph_utils.h" -#include "mlx/primitives.h" +#include "mlx/backend/common/utils.h" #include "mlx/utils.h" namespace mlx::core { @@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) { } } -std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids) { - NodeNamer namer; - std::ostringstream os; - std::ostringstream constant_hasher; - - // Fill the input names. This is not really necessary, I just like having A, - // B, C, ... as the inputs. - for (auto& x : inputs) { - namer.get_name(x); - } - - // The primitives describing the tape. For unary and binary primitives this - // must be enough to describe the full computation. - for (auto& a : tape) { - // name and type of output - os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); - // computation performed - a.primitive().print(os); - // name of inputs to the function - for (auto& inp : a.inputs()) { - os << namer.get_name(inp); - } - } - os << "_"; - - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - os << "C"; - print_constant(constant_hasher, x); - } else { - os << (is_scalar(x) ? 
"S" : "V"); - } - } - os << "_"; - for (auto& x : inputs) { - if (constant_ids.find(x.id()) != constant_ids.end()) { - continue; - } - os << kindof(x.dtype()) << x.itemsize(); - } - os << "_" << std::hash{}(constant_hasher.str()); - - return os.str(); -} - bool compiled_check_contiguity( const std::vector& inputs, const Shape& shape) { @@ -159,8 +109,7 @@ bool compiled_check_contiguity( void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, - const std::vector& inputs_, - const std::unordered_set& constant_ids_, + const std::function& is_constant, bool contiguous) { if (contiguous) { int o = 0; @@ -175,8 +124,7 @@ void compiled_allocate_outputs( // - Donatable // - Not a constant if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && - in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + in.is_donatable() && is_constant(i)) { outputs[o++].copy_shared_buffer(in); } // Get representative input flags to properly set non-donated outputs @@ -204,7 +152,7 @@ void compiled_allocate_outputs( // - Not a constant if (in.flags().row_contiguous && in.size() == outputs[o].size() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() && - constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { + is_constant(i)) { outputs[o].copy_shared_buffer( in, outputs[o].strides(), in.flags(), in.data_size()); o++; @@ -216,4 +164,74 @@ void compiled_allocate_outputs( } } +std::tuple> compiled_collapse_contiguous_dims( + const std::vector& inputs, + const array& out, + const std::function& is_constant) { + const Shape& shape = out.shape(); + bool contiguous = compiled_check_contiguity(inputs, shape); + if (contiguous) { + return {true, shape, {}}; + } + + std::vector strides_vec{out.strides()}; + for (size_t i = 0; i < inputs.size(); ++i) { + // Skip constants. + if (is_constant(i)) { + continue; + } + + // Skip scalar inputs. + const auto& x = inputs[i]; + if (is_scalar(x)) { + continue; + } + + // Broadcast the inputs to the output shape. + Strides xstrides; + size_t j = 0; + for (; j < shape.size() - x.ndim(); ++j) { + if (shape[j] == 1) { + xstrides.push_back(out.strides()[j]); + } else { + xstrides.push_back(0); + } + } + for (size_t i = 0; i < x.ndim(); ++i, ++j) { + if (x.shape(i) == 1) { + if (shape[j] == 1) { + xstrides.push_back(out.strides()[j]); + } else { + xstrides.push_back(0); + } + } else { + xstrides.push_back(x.strides()[i]); + } + } + strides_vec.push_back(std::move(xstrides)); + } + + auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX); + return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))}; +} + +bool compiled_use_large_index( + const std::vector& inputs, + const std::vector& outputs, + bool contiguous) { + if (contiguous) { + size_t max_size = 0; + for (const auto& in : inputs) { + max_size = std::max(max_size, in.data_size()); + } + return max_size > UINT32_MAX; + } else { + size_t max_size = 0; + for (const auto& o : outputs) { + max_size = std::max(max_size, o.size()); + } + return max_size > UINT32_MAX; + } +} + } // namespace mlx::core diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index f4d28d6ab..6fccaacd6 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -1,9 +1,8 @@ // Copyright © 2023-2024 Apple Inc. 
#pragma once +#include #include -#include -#include #include "mlx/array.h" #include "mlx/primitives.h" @@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) { return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); } -std::string build_lib_name( - const std::vector& inputs, - const std::vector& outputs, - const std::vector& tape, - const std::unordered_set& constant_ids); - std::string get_type_string(Dtype d); template @@ -60,8 +53,19 @@ bool compiled_check_contiguity( void compiled_allocate_outputs( const std::vector& inputs, std::vector& outputs, - const std::vector& inputs_, - const std::unordered_set& constant_ids_, + const std::function& is_constant, + bool contiguous); + +// Collapse contiguous dims ignoring scalars and constants. +std::tuple> compiled_collapse_contiguous_dims( + const std::vector& inputs, + const array& out, + const std::function& is_constant); + +// Return whether the kernel should use large index. +bool compiled_use_large_index( + const std::vector& inputs, + const std::vector& outputs, bool contiguous); } // namespace mlx::core diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index e389e0df5..d0bfb4f45 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -146,18 +146,9 @@ inline void build_kernel( const std::vector& inputs, const std::vector& outputs, const std::vector& tape, - const std::unordered_set& constant_ids, + const std::function& is_constant, bool contiguous, int ndim) { - // All outputs should have the exact same shape and will be row contiguous - auto output_shape = outputs[0].shape(); - auto output_strides = outputs[0].strides(); - - // Constants are scalars that are captured by value and cannot change - auto is_constant = [&constant_ids](const array& x) { - return constant_ids.find(x.id()) != constant_ids.end(); - }; - NodeNamer namer; #ifdef _MSC_VER @@ -170,14 +161,15 @@ inline void build_kernel( // Add the input arguments int cnt = 0; - for (auto& x : inputs) { - auto& xname = namer.get_name(x); - + for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants from the input list - if (is_constant(x)) { + if (is_constant(i)) { continue; } + const auto& x = inputs[i]; + auto& xname = namer.get_name(x); + auto tstr = get_type_string(x.dtype()); os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ << "];" << std::endl; @@ -211,10 +203,11 @@ inline void build_kernel( } // Read the inputs in tmps - for (auto& x : inputs) { + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; auto& xname = namer.get_name(x); - if (is_constant(x)) { + if (is_constant(i)) { os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; print_constant(os, x); os << ";" << std::endl; @@ -264,8 +257,9 @@ inline void build_kernel( } else { for (int d = ndim - 1; d >= 0; --d) { // Update pointers - for (auto& x : inputs) { - if (is_constant(x) || is_scalar(x)) { + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + if (is_constant(i) || is_scalar(x)) { continue; } auto& xname = namer.get_name(x); @@ -287,65 +281,37 @@ inline void build_kernel( void Compiled::eval_cpu( const std::vector& inputs, std::vector& outputs) { - if (kernel_lib_.empty()) { - kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); - } - - // Figure out which kernel we are using - auto& shape = outputs[0].shape(); - auto contiguous = compiled_check_contiguity(inputs, shape); auto& encoder = cpu::get_command_encoder(stream()); - // 
Handle all broadcasting and collect function input arguments + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + auto [contiguous, shape, strides] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); + + // Collect function input arguments. std::vector args; - std::vector> strides; - for (int i = 0; i < inputs.size(); i++) { - // Skip constants. - if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + int strides_index = 1; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant_(i)) { continue; } - auto& x = inputs[i]; + const auto& x = inputs[i]; encoder.set_input_array(x); args.push_back((void*)x.data()); - - if (contiguous || is_scalar(x)) { - continue; + if (!contiguous && !is_scalar(x)) { + args.push_back(strides[strides_index++].data()); } - - // Broadcast the input to the output shape. - std::vector xstrides; - int j = 0; - for (; j < shape.size() - x.ndim(); j++) { - if (shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } - for (int i = 0; i < x.ndim(); i++, j++) { - if (x.shape(i) == 1) { - if (shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } else { - xstrides.push_back(x.strides()[i]); - } - } - strides.push_back(std::move(xstrides)); - args.push_back(strides.back().data()); } // Get the kernel name from the lib int ndim = shape.size(); auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); if (!contiguous) { - kernel_name += std::to_string(shape.size()); + kernel_name += std::to_string(ndim); } // Get the function - auto fn_ptr = compile(kernel_name, [&]() { + auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() { std::ostringstream kernel; kernel << get_kernel_preamble() << std::endl; kernel << "extern \"C\" {" << std::endl; @@ -355,7 +321,7 @@ void Compiled::eval_cpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, contiguous, ndim); // Close extern "C" @@ -363,26 +329,22 @@ void Compiled::eval_cpu( return kernel.str(); }); - compiled_allocate_outputs( - inputs, outputs, inputs_, constant_ids_, contiguous); + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); for (auto& x : outputs) { args.push_back(x.data()); encoder.set_output_array(x); } - Shape out_shape; if (!contiguous) { - out_shape = outputs[0].shape(); - args.push_back((void*)out_shape.data()); + args.push_back((void*)shape.data()); } else { args.push_back((void*)outputs[0].data_size()); } auto fun = (void (*)(void**))fn_ptr; - encoder.dispatch( - [fun, - args = std::move(args), - strides = std::move(strides), - out_shape = std::move(out_shape)]() mutable { fun(args.data()); }); + encoder.dispatch([fun, + args = std::move(args), + strides = std::move(strides), + shape = std::move(shape)]() mutable { fun(args.data()); }); } } // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index db20f938c..6a67b4f57 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -11,8 +11,6 @@ #include "mlx/primitives.h" #include "mlx/utils.h" -using namespace fmt::literals; - namespace mlx::core { inline void build_kernel( @@ -21,21 +19,12 @@ inline void build_kernel( const std::vector& inputs, const std::vector& outputs, const std::vector& tape, - const std::unordered_set& constant_ids, + const std::function& is_constant, bool contiguous, int ndim, bool dynamic_dims, bool use_big_index = 
false, int work_per_thread = 1) { - // All outputs should have the exact same shape and will be row contiguous - auto output_shape = outputs[0].shape(); - auto output_strides = outputs[0].strides(); - - // Constants are scalars that are captured by value and cannot change - auto is_constant = [&constant_ids](const array& x) { - return constant_ids.find(x.id()) != constant_ids.end(); - }; - NodeNamer namer; bool add_indices = false; int cnt = 0; @@ -45,14 +34,15 @@ inline void build_kernel( "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name); // Add the input arguments - for (auto& x : inputs) { - auto& xname = namer.get_name(x); - + for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants from the input list - if (is_constant(x)) { + if (is_constant(i)) { continue; } + const auto& x = inputs[i]; + auto& xname = namer.get_name(x); + // Scalars and contiguous need no strides if (!is_scalar(x) && !contiguous) { add_indices = true; @@ -80,8 +70,6 @@ inline void build_kernel( } // Add output strides and shape to extract the indices. if (!contiguous) { - os += fmt::format( - " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); } else { @@ -125,7 +113,7 @@ inline void build_kernel( auto& x = inputs[i]; auto& xname = namer.get_name(x); - if (is_constant(x)) { + if (is_constant(i)) { auto type_str = get_type_string(x.dtype()); std::ostringstream ss; print_constant(ss, x); @@ -271,11 +259,6 @@ inline void build_kernel( void Compiled::eval_gpu( const std::vector& inputs, std::vector& outputs) { - // Make the name for the kernel library - if (kernel_lib_.empty()) { - kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_); - } - // Get the kernel if someone else built it already auto& s = stream(); auto& d = metal::device(s.device); @@ -290,7 +273,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, @@ -302,7 +285,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, @@ -315,7 +298,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ i, /* dynamic_dims = */ false, @@ -328,7 +311,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ i, /* dynamic_dims = */ false, @@ -342,7 +325,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ 0, /* dynamic_dims = */ true, @@ -354,7 +337,7 @@ void Compiled::eval_gpu( inputs_, outputs_, tape_, - constant_ids_, + is_constant_, /* contiguous = */ false, /* ndim = */ 0, /* dynamic_dims = */ true, @@ -363,70 +346,13 @@ void Compiled::eval_gpu( return kernel; }); - // Figure out which kernel we are using - auto& output_shape = outputs[0].shape(); - auto contiguous = compiled_check_contiguity(inputs, output_shape); - // Collapse contiguous dims to route to a faster kernel if possible. Also // handle all broadcasting. - std::vector initial_strides; - initial_strides.push_back(outputs[0].strides()); - Shape shape; - std::vector strides; - if (!contiguous) { - for (int i = 0; i < inputs.size(); i++) { - // Skip constants. 
- if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { - continue; - } - auto& x = inputs[i]; + auto [contiguous, shape, strides] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); - // Skip scalar inputs. - if (is_scalar(x)) { - continue; - } - - // Broadcast the inputs to the output shape. - Strides xstrides; - int j = 0; - for (; j < output_shape.size() - x.ndim(); j++) { - if (output_shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } - for (int i = 0; i < x.ndim(); i++, j++) { - if (x.shape(i) == 1) { - if (output_shape[j] == 1) { - xstrides.push_back(outputs[0].strides()[j]); - } else { - xstrides.push_back(0); - } - } else { - xstrides.push_back(x.strides()[i]); - } - } - initial_strides.push_back(std::move(xstrides)); - } - std::tie(shape, strides) = - collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX); - } - - bool large; - if (contiguous) { - size_t max_size = 0; - for (auto& in : inputs) { - max_size = std::max(max_size, in.data_size()); - } - large = (max_size > UINT32_MAX); - } else { - size_t max_size = 0; - for (auto& o : outputs) { - max_size = std::max(max_size, o.size()); - } - large = (max_size > UINT32_MAX); - } + // Whether to use large index. + bool large = compiled_use_large_index(inputs, outputs, contiguous); // Get the kernel from the lib int ndim = shape.size(); @@ -451,7 +377,7 @@ void Compiled::eval_gpu( int stride_idx = 1; // idx 0 is the output strides Strides in_strides; for (int i = 0; i < inputs.size(); i++) { - if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { + if (is_constant_(i)) { continue; } auto& x = inputs[i]; @@ -468,8 +394,7 @@ void Compiled::eval_gpu( compute_encoder.set_vector_bytes(in_strides, cnt++); } - compiled_allocate_outputs( - inputs, outputs, inputs_, constant_ids_, contiguous); + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); // Put the outputs in for (auto& x : outputs) { @@ -478,7 +403,6 @@ void Compiled::eval_gpu( // Put the output shape and strides in if (!contiguous) { - compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); } else { auto size = outputs[0].data_size(); diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 2baeb6fcf..79a55ba8f 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,16 +1,20 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include #include #include "mlx/allocator.h" +#include "mlx/backend/common/compiled.h" #include "mlx/compile.h" #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" +#include "mlx/utils.h" namespace mlx::core { @@ -82,7 +86,54 @@ Compiled::Compiled( inputs_(std::move(inputs)), outputs_(std::move(outputs)), tape_(std::move(tape)), - constant_ids_(std::move(constant_ids)) {} + constant_ids_(std::move(constant_ids)), + is_constant_([this](size_t i) { + return constant_ids_.find(inputs_[i].id()) != constant_ids_.end(); + }) { + // Build the kernel name. + NodeNamer namer; + std::ostringstream os; + std::ostringstream constant_hasher; + + // Fill the input names. This is not really necessary, I just like having A, + // B, C, ... as the inputs. + for (const auto& x : inputs_) { + namer.get_name(x); + } + + // The primitives describing the tape. For unary and binary primitives this + // must be enough to describe the full computation. 
+ for (const auto& a : tape_) { + // name and type of output + os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); + // computation performed + a.primitive().print(os); + // name of inputs to the function + for (auto& inp : a.inputs()) { + os << namer.get_name(inp); + } + } + os << "_"; + + for (const auto& x : inputs_) { + if (constant_ids_.find(x.id()) != constant_ids_.end()) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << (is_scalar(x) ? "S" : "V"); + } + } + os << "_"; + for (const auto& x : inputs) { + if (constant_ids.find(x.id()) != constant_ids.end()) { + continue; + } + os << kindof(x.dtype()) << x.itemsize(); + } + os << "_" << std::hash{}(constant_hasher.str()); + + kernel_lib_ = os.str(); +} std::vector Compiled::vjp( const std::vector&, diff --git a/mlx/primitives.h b/mlx/primitives.h index c0fbfc84d..cc60bcfb9 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -627,6 +627,7 @@ class Compiled : public Primitive { const std::vector outputs_; const std::vector tape_; const std::unordered_set constant_ids_; + const std::function is_constant_; std::string kernel_lib_; }; From 85a8beb5e4f7d1426f79dc5e55953795c6039fe8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 4 Jun 2025 08:49:06 +0900 Subject: [PATCH 069/156] Avoid atomic updates across CPU/GPU in CUDA event (#2231) --- mlx/backend/cuda/CMakeLists.txt | 2 +- mlx/backend/cuda/event.cu | 5 ++- mlx/backend/cuda/fence.cpp | 29 ++++++++++++++ mlx/backend/cuda/fence.cu | 70 --------------------------------- 4 files changed, 34 insertions(+), 72 deletions(-) create mode 100644 mlx/backend/cuda/fence.cpp delete mode 100644 mlx/backend/cuda/fence.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 2a8ef9963..8c9a40d03 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -10,7 +10,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu - ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index f462720a9..9fc5c641b 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -156,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) { void SharedEvent::signal(Stream s, uint64_t value) { nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + // Signal through a GPU stream so the atomic is updated in GPU - updating + // the atomic in CPU sometimes does not get GPU notified. + static CudaStream stream(device(mlx::core::Device::gpu)); + scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); } else { auto& encoder = get_command_encoder(s); encoder.launch_kernel( diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp new file mode 100644 index 000000000..f399c4ebb --- /dev/null +++ b/mlx/backend/cuda/fence.cpp @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/fence.h" +#include "mlx/backend/cuda/event.h" + +namespace mlx::core { + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->event.wait(fence->count); +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/fence.cu b/mlx/backend/cuda/fence.cu deleted file mode 100644 index 091b252c1..000000000 --- a/mlx/backend/cuda/fence.cu +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/event.h" -#include "mlx/fence.h" -#include "mlx/scheduler.h" - -#include - -namespace mlx::core { - -namespace { - -__host__ __device__ void busy_wait(cuda::atomic* ac, uint64_t value) { - while (true) { - // In theory the atomic_thread_fence is not needed, but for CUDA 11 without - // it the load() may never return new value. - cuda::atomic_thread_fence(cuda::memory_order_seq_cst); - uint64_t current = ac->load(); - if (current >= value) { - break; - } - } -} - -__global__ void busy_wait_kernel(cuda::atomic* ac, uint64_t value) { - busy_wait(ac, value); -} - -} // namespace - -struct FenceImpl { - uint32_t count; - cu::SharedEvent event; -}; - -Fence::Fence(Stream s) { - fence_ = std::shared_ptr( - new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); -} - -void Fence::wait(Stream s, const array&) { - auto* fence = static_cast(fence_.get()); - // We can't use SharedEvent::wait because it could hang in CUDA 11, see also: - // https://github.com/ml-explore/mlx/issues/2137 - const auto& ac = fence->event.atomic(); - if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [ac, count = fence->count]() { - nvtx3::scoped_range r("Fence::wait()"); - busy_wait(ac.get(), count); - }); - } else { - nvtx3::scoped_range r("Fence::wait(s)"); - auto& encoder = cu::get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) { - busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count); - }); - encoder.add_completed_handler([ac]() {}); - encoder.end_encoding(); - } -} - -void Fence::update(Stream s, const array&) { - auto* fence = static_cast(fence_.get()); - fence->count++; - fence->event.signal(s, fence->count); -} - -} // namespace mlx::core From aede70e81d02c4de8c593c6dcf82591131c29677 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 3 Jun 2025 17:55:12 -0700 Subject: [PATCH 070/156] Perf regression fix (#2243) --- mlx/transforms.cpp | 4 +++- mlx/version.h | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 3a02f39cb..2d9942eda 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -208,7 +208,9 @@ array eval_impl(std::vector outputs, bool async) { // output arrays stream fences[it->second].wait(stream, in); } else if (in.event().valid()) { - if (in.event().stream() != stream) { + if (in.event().is_signaled()) { + in.detach_event(); + } else if (in.event().stream() != stream) { // Use event to wait across async eval in.event().wait(stream); } diff --git a/mlx/version.h b/mlx/version.h index 45ccdf3a7..530d0620d 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 
#define MLX_VERSION_MINOR 26 -#define MLX_VERSION_PATCH 0 +#define MLX_VERSION_PATCH 1 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From 52dc8c8cd58cd55b21c8e33486b6516061ab3f61 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 5 Jun 2025 11:55:12 +0900 Subject: [PATCH 071/156] Add profiler annotations in common primitives for CUDA backend (#2244) --- mlx/backend/cuda/CMakeLists.txt | 2 ++ mlx/backend/gpu/primitives.cpp | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 8c9a40d03..c991c2094 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -17,6 +17,8 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) +target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) + # Enable defining device lambda functions. target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index cd9296075..938923977 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -5,9 +5,17 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" +#if defined(MLX_USE_CUDA) +#include +#endif + #include +#if defined(MLX_USE_CUDA) +#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message) +#else #define MLX_PROFILER_RANGE(message) +#endif namespace mlx::core { From c763fe1be0f1158e0f53f6f6a28f56d69b1c7fe8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 5 Jun 2025 15:27:02 -0700 Subject: [PATCH 072/156] default strict mode for module update and update_modules (#2239) --- python/mlx/nn/layers/base.py | 50 +++++++++++++++++++++++++++--------- python/tests/test_nn.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index b35c58478..783ef446d 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -193,7 +193,7 @@ class Module(dict): ) if len(weights) != 0: - self.update(tree_unflatten(weights)) + self.update(tree_unflatten(weights), strict=False) return self def save_weights(self, file: str): @@ -291,7 +291,7 @@ class Module(dict): return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) - def update(self, parameters: dict) -> Module: + def update(self, parameters: dict, strict: bool = True) -> Module: """Replace the parameters of this Module with the provided ones in the dict of dicts and lists. @@ -305,7 +305,9 @@ class Module(dict): Args: parameters (dict): A complete or partial dictionary of the modules - parameters. + parameters. + strict (bool): If ``True`` checks that ``parameters`` is a + subset of the module's parameters. Default: ``True``. Returns: The module instance after updating the parameters. """ @@ -317,21 +319,29 @@ class Module(dict): current_value = dst[k] new_value = parameters[k] if isinstance(current_value, mx.array): + if strict and not isinstance(new_value, mx.array): + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." 
+ ) dst[k] = new_value - elif isinstance(current_value, Module): - current_value.update(new_value) - elif isinstance(current_value, (dict, list)): + else: apply(current_value, new_value) + elif strict: + raise ValueError(f'Module does not have parameter named "{k}".') elif isinstance(parameters, list): for i in range(len(parameters)): current_value = dst[i] new_value = parameters[i] if isinstance(current_value, mx.array): + if strict and not isinstance(new_value, mx.array): + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) dst[i] = new_value - elif isinstance(current_value, Module): - current_value.update(new_value) - elif isinstance(current_value, (dict, list)): + else: apply(current_value, new_value) + elif strict: + raise ValueError(f"Received invalid type: {type(parameters).__name__}.") apply(self, parameters) return self @@ -359,7 +369,7 @@ class Module(dict): self.update(self.filter_and_map(filter_fn, map_fn)) return self - def update_modules(self, modules: dict) -> Module: + def update_modules(self, modules: dict, strict: bool = True) -> Module: """Replace the child modules of this :class:`Module` instance with the provided ones in the dict of dicts and lists. @@ -368,12 +378,14 @@ class Module(dict): programmatically swapping layers. The passed in parameters dictionary need not be a full dictionary - similar to :meth:`parameters`. Only the provided locations will be + similar to :meth:`modules`. Only the provided locations will be updated. Args: - modules (dict): A complete or partial dictionary of the modules + modules (dict): A complete or partial dictionary of the module's submodules. + strict (bool): If ``True`` checks that ``modules`` is a + subset of the child modules of this instance. Default: ``True``. Returns: The module instance after updating the submodules. """ @@ -388,6 +400,14 @@ class Module(dict): dst[k] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) + elif strict: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." + ) + elif strict: + raise ValueError( + f'Module does not have sub-module named "{k}".' + ) elif isinstance(modules, list): for i in range(len(dst)): current_value = dst[i] @@ -396,6 +416,12 @@ class Module(dict): dst[i] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) + elif strict: + raise ValueError( + f"Received invalid type: {type(new_value).__name__}." 
+ ) + elif strict: + raise ValueError(f"Received invalid type: {type(modules).__name__}.") apply(self, modules) return self diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 826d53d96..13e31ad96 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase): x = mx.zeros((3,)) mx.grad(loss_fn)(model) + def test_update(self): + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + + # Updating non-existent parameters + with self.assertRaises(ValueError): + updates = {"layers": [{"value": 0}]} + m.update(updates) + + with self.assertRaises(ValueError): + updates = {"layers": ["hello"]} + m.update(updates) + + # Wronge type + with self.assertRaises(ValueError): + updates = {"layers": [{"weight": "hi"}]} + m.update(updates) + + def test_update_modules(self): + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + + # Updating non-existent modules should not be allowed by default + with self.assertRaises(ValueError): + m = m.update_modules({"values": [0, 1]}) + + # Update wrong types + with self.assertRaises(ValueError): + m = m.update_modules({"layers": [0, 1]}) + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.test = mx.array(1.0) + self.list = [mx.array(1.0), mx.array(2.0)] + + m = MyModule() + with self.assertRaises(ValueError): + m = m.update_modules({"test": "hi"}) + with self.assertRaises(ValueError): + m = m.update_modules({"list": ["hi"]}) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): From a5ac9244c4bd71774b1ca9bc222cd7031965de37 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 10:41:51 -0700 Subject: [PATCH 073/156] fix linux linking error (#2248) --- python/src/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 7ea302cf9..29beca859 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -54,5 +54,9 @@ target_link_libraries(core PRIVATE mlx) target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) if(BUILD_SHARED_LIBS) - target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) + else() + target_link_options(core PRIVATE -Wl,-rpath,\$ORIGIN/lib) + endif() endif() From c6a20b427ac624f4b9ac9d6f8e4c5847f1dbb672 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 11:37:40 -0700 Subject: [PATCH 074/156] Improve metal elementwise kernels (#2247) * improve metal elementwise kernels * compile and copy * fix jit --- mlx/backend/metal/binary.cpp | 15 ++- mlx/backend/metal/compiled.cpp | 26 +++++- mlx/backend/metal/copy.cpp | 15 +-- mlx/backend/metal/jit_kernels.cpp | 52 ++++++++--- mlx/backend/metal/kernels/binary.h | 60 +++++++++--- mlx/backend/metal/kernels/binary.metal | 47 ++++++---- mlx/backend/metal/kernels/binary_two.h | 102 +++++++++++++++------ mlx/backend/metal/kernels/binary_two.metal | 39 +++++--- mlx/backend/metal/kernels/copy.h | 52 ++++++++--- mlx/backend/metal/kernels/copy.metal | 20 ++-- mlx/backend/metal/kernels/ternary.h | 22 ++++- mlx/backend/metal/kernels/ternary.metal | 14 ++- mlx/backend/metal/kernels/unary.h | 22 ++++- mlx/backend/metal/kernels/unary.metal | 88 ++++++++++-------- mlx/backend/metal/ternary.cpp | 4 +- mlx/backend/metal/unary.cpp | 4 +- mlx/backend/metal/utils.h | 4 + 17 files changed, 412 insertions(+), 174 deletions(-) diff --git a/mlx/backend/metal/binary.cpp 
b/mlx/backend/metal/binary.cpp index c3c67e4d5..54aaf153c 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -31,13 +31,13 @@ std::string get_kernel_name( kname = "ss"; break; case BinaryOpType::ScalarVector: - kname = (large ? "sv2" : "sv"); + kname = "sv"; break; case BinaryOpType::VectorScalar: - kname = (large ? "vs2" : "vs"); + kname = "vs"; break; case BinaryOpType::VectorVector: - kname = (large ? "vv2" : "vv"); + kname = "vv"; break; case BinaryOpType::General: kname = "g"; @@ -51,6 +51,13 @@ std::string get_kernel_name( } break; } + if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) { + if (large) { + kname += "2"; + } else if (work_per_thread > 1) { + kname += "n"; + } + } concatenate(kname, "_", op, type_to_name(a)); return kname; } @@ -90,7 +97,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = get_work_per_thread(a.dtype()); + work_per_thread = get_work_per_thread(a.dtype(), out.data_size()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 6a67b4f57..88edc6baa 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -278,7 +278,21 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ false, /* use_big_index = */ false, - /* work_per_thread = */ work_per_thread); + /* work_per_thread = */ 1); + if (work_per_thread > 1) { + build_kernel( + kernel, + kernel_lib_ + "_contiguous_n", + inputs_, + outputs_, + tape_, + is_constant_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); + } build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -358,12 +372,20 @@ void Compiled::eval_gpu( int ndim = shape.size(); bool dynamic = ndim >= 8; auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + int work_per_thread = 1; if (!contiguous) { if (dynamic) { kernel_name += "dynamic"; } else { kernel_name += std::to_string(shape.size()); } + work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; + } else { + work_per_thread = + get_work_per_thread(outputs[0].dtype(), outputs[0].data_size()); + if (work_per_thread > 1 && !large) { + kernel_name += "_n"; + } } if (large) { kernel_name += "_large"; @@ -420,7 +442,6 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - int work_per_thread = get_work_per_thread(outputs[0].dtype()); size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); @@ -433,7 +454,6 @@ void Compiled::eval_gpu( size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t rest = outputs[0].size() / (dim0 * dim1); - int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; dim0 = (dim0 + work_per_thread - 1) / work_per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); int pow2; diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 8dfe15c11..8123b793e 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -55,10 +55,10 @@ void copy_gpu_inplace( std::string kernel_name; switch (ctype) { case CopyType::Scalar: - kernel_name = (large ? "s2" : "s"); + kernel_name = large ? "s2" : "s"; break; case CopyType::Vector: - kernel_name = (large ? 
"v2" : "v"); + kernel_name = large ? "v2" : "v"; break; case CopyType::General: kernel_name = "g"; @@ -85,7 +85,10 @@ void copy_gpu_inplace( } } } else { - work_per_thread = get_work_per_thread(in.dtype()); + work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); + if (work_per_thread > 1) { + kernel_name += "n"; + } } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) { } out.set_data(allocator::malloc(out.nbytes())); bool large = out.data_size() > UINT32_MAX; + int work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); auto& d = metal::device(s.device); - std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + - type_to_name(val) + type_to_name(out); + std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s"); + concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); - int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5206c9b54..15e21af6c 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel( std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += - get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); + get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); + if (get_work_per_thread(in_type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); + } kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( @@ -59,11 +63,8 @@ void append_binary_kernels( Dtype out_type, const std::string op, std::string& kernel_source) { - const std::array, 10> kernel_types = {{ + const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, - {"vs", "binary_vs"}, - {"sv", "binary_sv"}, - {"vv", "binary_vv"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, @@ -78,6 +79,22 @@ void append_binary_kernels( kernel_source += get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } + kernel_source += get_template_definition( + "vs_" + lib_name, "binary_vs", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); + kernel_source += get_template_definition( + "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); + + if (get_work_per_thread(in_type) > 1) { + kernel_source += get_template_definition( + "vsn_" + lib_name, "binary_vs", in_t, out_t, op); + kernel_source += get_template_definition( + "svn_" + lib_name, "binary_sv", in_t, out_t, op); + kernel_source += get_template_definition( + "vvn_" + lib_name, "binary_vv", in_t, out_t, op); + } + kernel_source += get_template_definition( 
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( @@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel( auto t_str = get_type_string(type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); - const std::array, 5> kernel_types = {{ - {"v", "ternary_v"}, + const std::array, 4> kernel_types = {{ {"v2", "ternary_v2"}, {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, @@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel( kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } + if (get_work_per_thread(type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "ternary_v", t_str, op); + } + + kernel_source += + get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); kernel_source += get_template_definition( "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( @@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel( kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); - kernel_source += - get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "s_" + lib_name, "copy_s", in_type, out_type, 1); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); - kernel_source += - get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); + kernel_source += get_template_definition( + "v_" + lib_name, "copy_v", in_type, out_type, 1); kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); + if (get_work_per_thread(out.dtype()) > 1) { + kernel_source += get_template_definition( + "sn_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "vn_" + lib_name, "copy_v", in_type, out_type); + } + kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index ffc33ad82..f1df88535 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -17,8 +17,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[0], b[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } } @@ -30,8 +36,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[index + i], b[0]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } } @@ -43,8 +55,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - c[index + i] = Op()(a[index + i], b[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for 
(int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } } @@ -57,8 +75,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[0], b[offset + i]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } } } @@ -71,8 +95,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[0]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } } @@ -85,8 +115,14 @@ template ::n> uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[offset + i]); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 1d555fefa..17ed13c57 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,11 +9,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \ + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -26,15 +31,19 @@ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) -#define instantiate_binary_integer(op) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, 
int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + +#define instantiate_binary_integer(op) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ @@ -44,7 +53,7 @@ #define instantiate_binary_types(op) \ instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_integer(op) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t)\ instantiate_binary_float(op) #define instantiate_binary_types_bool(op) \ @@ -52,15 +61,15 @@ instantiate_binary_all(op, uint8, uint8_t, bool) \ instantiate_binary_all(op, uint16, uint16_t, bool) \ instantiate_binary_all(op, uint32, uint32_t, bool) \ - instantiate_binary_all(op, uint64, uint64_t, bool) \ + instantiate_binary_base(op, uint64, uint64_t, bool) \ instantiate_binary_all(op, int8, int8_t, bool) \ instantiate_binary_all(op, int16, int16_t, bool) \ instantiate_binary_all(op, int32, int32_t, bool) \ - instantiate_binary_all(op, int64, int64_t, bool) \ + instantiate_binary_base(op, int64, int64_t, bool) \ instantiate_binary_all(op, float16, half, bool) \ instantiate_binary_all(op, float32, float, bool) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ - instantiate_binary_all(op, complex64, complex64_t, bool) + instantiate_binary_base(op, complex64, complex64_t, bool) instantiate_binary_types(Add) instantiate_binary_types(Divide) @@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less) instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(NotEqual) instantiate_binary_float(LogAddExp) -instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) +instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_types(Maximum) instantiate_binary_types(Minimum) instantiate_binary_types(Multiply) @@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2) instantiate_binary_all(NaNEqual, float16, half, bool) instantiate_binary_all(NaNEqual, float32, float, bool) instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) -instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) +instantiate_binary_base(NaNEqual, complex64, complex64_t, bool) instantiate_binary_all(LogicalOr, bool_, bool, bool) instantiate_binary_all(LogicalAnd, bool_, bool, bool) diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index e261d33c4..4455e4ca9 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -21,10 +21,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[0], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + 
for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -37,10 +45,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[index + i], b[0]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -53,10 +69,18 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - auto out = Op()(a[index + i], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } } @@ -69,11 +93,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[0], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } @@ -86,11 +118,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[offset + i], b[0]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } @@ -103,11 +143,19 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - auto out = Op()(a[offset + i], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out 
= Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } } diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 984a28320..c7d3ecdf0 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,11 +7,16 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ +#define instantiate_binary_work_per_thread(op, tname, itype, otype) \ + instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) + +#define instantiate_binary_base(op, tname, itype, otype) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ @@ -24,22 +29,26 @@ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_binary_base(op, tname, itype, otype) \ + instantiate_binary_work_per_thread(op, tname, itype, otype) + #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float32, float, float) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) -#define instantiate_binary_types(op) \ - instantiate_binary_all(op, bool_, bool, bool) \ - instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ - instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ - instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ - instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ - instantiate_binary_all(op, int8, int8_t, int8_t) \ - instantiate_binary_all(op, int16, int16_t, int16_t) \ - instantiate_binary_all(op, int32, int32_t, int32_t) \ - instantiate_binary_all(op, int64, int64_t, int64_t) \ - instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ +#define instantiate_binary_types(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_base(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_base(op, int64, int64_t, int64_t) \ + instantiate_binary_base(op, complex64, complex64_t, complex64_t) \ instantiate_binary_float(op) instantiate_binary_types(DivMod) // clang-format on diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index 2469d1f3d..cf22347ee 100644 --- 
a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,52 +1,76 @@ // Copyright © 2024 Apple Inc. -template ::n> +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - dst[index + i] = static_cast(src[0]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } } } -template ::n> +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - dst[index + i] = static_cast(src[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } } -template ::n> +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - dst[offset + i] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } } } -template ::n> +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - dst[offset + i] = static_cast(src[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index bbf268158..fcf8884f8 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -4,9 +4,13 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ - instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ +#define instantiate_copy_work_per_thread(tname, itype, otype) \ + instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \ + instantiate_kernel("vn_copy" #tname, copy_v, itype, otype) + +#define instantiate_copy_base(tname, itype, otype) \ + instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \ + instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, 
otype, int) \ @@ -18,6 +22,10 @@ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_copy_base(tname, itype, otype) \ + instantiate_copy_work_per_thread(tname, itype, otype) + #define instantiate_copy_same(tname, type) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ @@ -42,15 +50,15 @@ instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \ - instantiate_copy_all(itname ##uint64, itype, uint64_t) \ + instantiate_copy_base(itname ##uint64, itype, uint64_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int32, itype, int32_t) \ - instantiate_copy_all(itname ##int64, itype, int64_t) \ + instantiate_copy_base(itname ##int64, itype, int64_t) \ instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ - instantiate_copy_all(itname ##complex64, itype, complex64_t) + instantiate_copy_base(itname ##complex64, itype, complex64_t) instantiate_copy_itype(bool_, bool) instantiate_copy_itype(uint8, uint8_t) diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 5251dc7e9..570f5e4d6 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -9,8 +9,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } } @@ -23,9 +29,15 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index cceb53061..6da258b6f 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -8,8 +8,8 @@ #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ +#define instantiate_ternary_base(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ instantiate_kernel("g1_" #op 
#tname, ternary_g_nd1, type, op, int) \ @@ -20,19 +20,23 @@ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \ + instantiate_ternary_base(op, tname, type) + #define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, uint8, uint8_t) \ instantiate_ternary_all(op, uint16, uint16_t) \ instantiate_ternary_all(op, uint32, uint32_t) \ - instantiate_ternary_all(op, uint64, uint64_t) \ + instantiate_ternary_base(op, uint64, uint64_t) \ instantiate_ternary_all(op, int8, int8_t) \ instantiate_ternary_all(op, int16, int16_t) \ instantiate_ternary_all(op, int32, int32_t) \ - instantiate_ternary_all(op, int64, int64_t) \ + instantiate_ternary_base(op, int64, int64_t) \ instantiate_ternary_all(op, float16, half) \ instantiate_ternary_all(op, float32, float) \ instantiate_ternary_all(op, bfloat16, bfloat16_t) \ - instantiate_ternary_all(op, complex64, complex64_t) // clang-format on + instantiate_ternary_base(op, complex64, complex64_t) // clang-format on instantiate_ternary_types(Select) diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index b5eaab2e9..649ba7f2c 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -7,8 +7,14 @@ template ::n> constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - out[index + i] = Op()(in[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } } } @@ -19,9 +25,15 @@ template ::n> constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - out[offset + i] = Op()(in[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index afced7eb7..160ef4af1 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,31 +5,41 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ - instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ - instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ - instantiate_kernel( \ - "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ - instantiate_kernel( \ +#define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) + +#define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \ + instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, 
out_type, op) \ + instantiate_kernel( \ + "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ + instantiate_kernel( \ "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) +#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \ + instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) + #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) +#define instantiate_unary_base_same(op, tname, type) \ + instantiate_unary_base(op, tname, tname, type, type) + #define instantiate_unary_float(op) \ instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, bfloat16, bfloat16_t) -#define instantiate_unary_int(op) \ - instantiate_unary_all_same(op, uint8, uint8_t) \ - instantiate_unary_all_same(op, uint16, uint16_t) \ - instantiate_unary_all_same(op, uint32, uint32_t) \ - instantiate_unary_all_same(op, uint64, uint64_t) \ - instantiate_unary_all_same(op, int8, int8_t) \ - instantiate_unary_all_same(op, int16, int16_t) \ - instantiate_unary_all_same(op, int32, int32_t) \ - instantiate_unary_all_same(op, int64, int64_t) +#define instantiate_unary_int(op) \ + instantiate_unary_all_same(op, uint8, uint8_t) \ + instantiate_unary_all_same(op, uint16, uint16_t) \ + instantiate_unary_all_same(op, uint32, uint32_t) \ + instantiate_unary_base_same(op, uint64, uint64_t) \ + instantiate_unary_all_same(op, int8, int8_t) \ + instantiate_unary_all_same(op, int16, int16_t) \ + instantiate_unary_all_same(op, int32, int32_t) \ + instantiate_unary_base_same(op, int64, int64_t) #define instantiate_unary_types(op) \ instantiate_unary_all_same(op, bool_, bool) \ @@ -68,29 +78,29 @@ instantiate_unary_float(Tanh) instantiate_unary_float(Round) instantiate_unary_int(BitwiseInvert) -instantiate_unary_all_same(Abs, complex64, complex64_t) -instantiate_unary_all_same(ArcCos, complex64, complex64_t) -instantiate_unary_all_same(ArcSin, complex64, complex64_t) -instantiate_unary_all_same(ArcTan, complex64, complex64_t) -instantiate_unary_all_same(Conjugate, complex64, complex64_t) -instantiate_unary_all_same(Cos, complex64, complex64_t) -instantiate_unary_all_same(Cosh, complex64, complex64_t) -instantiate_unary_all_same(Exp, complex64, complex64_t) -instantiate_unary_all_same(Log, complex64, complex64_t) -instantiate_unary_all_same(Log1p, complex64, complex64_t) -instantiate_unary_all_same(Log2, complex64, complex64_t) -instantiate_unary_all_same(Log10, complex64, complex64_t) -instantiate_unary_all_same(Negative, complex64, complex64_t) -instantiate_unary_all_same(Sign, complex64, complex64_t) -instantiate_unary_all_same(Sin, complex64, complex64_t) -instantiate_unary_all_same(Sinh, complex64, complex64_t) -instantiate_unary_all_same(Square, complex64, complex64_t) -instantiate_unary_all_same(Sqrt, complex64, complex64_t) -instantiate_unary_all_same(Rsqrt, complex64, complex64_t) -instantiate_unary_all_same(Tan, complex64, complex64_t) -instantiate_unary_all_same(Tanh, complex64, complex64_t) -instantiate_unary_all_same(Round, complex64, complex64_t) -instantiate_unary_all(Real, complex64, float32, complex64_t, float) -instantiate_unary_all(Imag, complex64, float32, complex64_t, float) +instantiate_unary_base_same(Abs, complex64, complex64_t) +instantiate_unary_base_same(ArcCos, complex64, complex64_t) +instantiate_unary_base_same(ArcSin, complex64, 
complex64_t) +instantiate_unary_base_same(ArcTan, complex64, complex64_t) +instantiate_unary_base_same(Conjugate, complex64, complex64_t) +instantiate_unary_base_same(Cos, complex64, complex64_t) +instantiate_unary_base_same(Cosh, complex64, complex64_t) +instantiate_unary_base_same(Exp, complex64, complex64_t) +instantiate_unary_base_same(Log, complex64, complex64_t) +instantiate_unary_base_same(Log1p, complex64, complex64_t) +instantiate_unary_base_same(Log2, complex64, complex64_t) +instantiate_unary_base_same(Log10, complex64, complex64_t) +instantiate_unary_base_same(Negative, complex64, complex64_t) +instantiate_unary_base_same(Sign, complex64, complex64_t) +instantiate_unary_base_same(Sin, complex64, complex64_t) +instantiate_unary_base_same(Sinh, complex64, complex64_t) +instantiate_unary_base_same(Square, complex64, complex64_t) +instantiate_unary_base_same(Sqrt, complex64, complex64_t) +instantiate_unary_base_same(Rsqrt, complex64, complex64_t) +instantiate_unary_base_same(Tan, complex64, complex64_t) +instantiate_unary_base_same(Tanh, complex64, complex64_t) +instantiate_unary_base_same(Round, complex64, complex64_t) +instantiate_unary_base(Real, complex64, float32, complex64_t, float) +instantiate_unary_base(Imag, complex64, float32, complex64_t, float) instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 0b821151e..22f2a1985 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = get_work_per_thread(b.dtype()); + work_per_thread = get_work_per_thread(b.dtype(), out.data_size()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -60,6 +60,8 @@ void ternary_op_gpu_inplace( } } else if (large) { kernel_name = "v2"; + } else if (work_per_thread > 1) { + kernel_name = "vn"; } else { kernel_name = "v"; } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 368e693a9..850c17376 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -43,8 +43,8 @@ void unary_op_gpu_inplace( int work_per_thread; std::string kernel_name; if (contig) { - work_per_thread = get_work_per_thread(in.dtype()); - kernel_name = (large ? "v2" : "v"); + work_per_thread = get_work_per_thread(in.dtype(), in.data_size()); + kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v")); } else { work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 576fb9107..a491521a0 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) { inline int get_work_per_thread(Dtype dtype) { return std::max(1, 8 / dtype.size()); } +inline int get_work_per_thread(Dtype dtype, size_t size) { + constexpr size_t wpt_threshold = 1 << 16; + return size < wpt_threshold ? 
1 : std::max(1, 8 / dtype.size()); +} inline size_t ceildiv(size_t n, size_t m) { return (n + m - 1) / m; From 24f89173d1d765bd906103e7526d510aca371db8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 7 Jun 2025 04:24:04 +0900 Subject: [PATCH 075/156] CUDA backend: matmul (#2241) --- mlx/backend/common/matmul.h | 78 ++++++ mlx/backend/cuda/CMakeLists.txt | 4 + mlx/backend/cuda/device.cpp | 14 +- mlx/backend/cuda/device.h | 14 + mlx/backend/cuda/matmul.cpp | 474 ++++++++++++++++++++++++++++++++ mlx/backend/cuda/primitives.cu | 2 - mlx/backend/metal/matmul.cpp | 65 +---- 7 files changed, 584 insertions(+), 67 deletions(-) create mode 100644 mlx/backend/common/matmul.h create mode 100644 mlx/backend/cuda/matmul.cpp diff --git a/mlx/backend/common/matmul.h b/mlx/backend/common/matmul.h new file mode 100644 index 000000000..2e0261a30 --- /dev/null +++ b/mlx/backend/common/matmul.h @@ -0,0 +1,78 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include + +namespace mlx::core { + +inline std::tuple collapse_batches( + const array& a, + const array& b) { + // Get and check the shape for the batched dims + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; + if (A_bshape != B_bshape) { + std::ostringstream msg; + msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " + << a.shape() << ", B " << b.shape() << "."; + throw std::runtime_error(msg.str()); + } + + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + + auto [batch_shape, batch_strides] = + collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); + + auto a_batch_strides = batch_strides[0]; + auto b_batch_strides = batch_strides[1]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + a_batch_strides.push_back(0); + b_batch_strides.push_back(0); + } + + return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides); +} + +inline std::tuple +collapse_batches(const array& a, const array& b, const array& c) { + // Get and check the shape for the batched dims + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; + Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; + if (A_bshape != B_bshape || A_bshape != C_bshape) { + std::ostringstream msg; + msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A " + << a.shape() << ", B " << b.shape() << ", B " << c.shape() << "."; + throw std::runtime_error(msg.str()); + } + + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; + + auto [batch_shape, batch_strides] = collapse_contiguous_dims( + A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); + + auto A_batch_stride = batch_strides[0]; + auto B_batch_stride = batch_strides[1]; + auto C_batch_stride = batch_strides[2]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + A_batch_stride.push_back(0); + B_batch_stride.push_back(0); + C_batch_stride.push_back(0); + } + + return std::make_tuple( + batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index c991c2094..9eaf2a6c7 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ 
b/mlx/backend/cuda/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp @@ -53,6 +54,9 @@ target_link_libraries(mlx PUBLIC $) find_package(CUDAToolkit REQUIRED) target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) +# Use cublasLt. +target_link_libraries(mlx PRIVATE CUDA::cublasLt) + # Suppress nvcc warnings on MLX headers. target_compile_options(mlx PRIVATE $<$:-Xcudafe --diag_suppress=997>) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index a28ffa35e..8a3d66c8e 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -34,14 +34,26 @@ CommandEncoder& DeviceStream::get_encoder() { } Device::Device(int device) : device_(device) { + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_)); + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_)); // Validate the requirements of device. int attr = 0; - cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_); + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &attr, cudaDevAttrConcurrentManagedAccess, device_)); if (attr != 1) { throw std::runtime_error(fmt::format( "Device {} does not support synchronization in managed memory.", device_)); } + // The cublasLt handle is used by matmul. + make_current(); + cublasLtCreate(<_); +} + +Device::~Device() { + cublasLtDestroy(lt_); } void Device::make_current() { diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index a65a87d54..5b2cc0607 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -6,6 +6,7 @@ #include "mlx/backend/cuda/worker.h" #include "mlx/stream.h" +#include #include #include @@ -46,6 +47,7 @@ class DeviceStream { class Device { public: explicit Device(int device); + ~Device(); Device(const Device&) = delete; Device& operator=(const Device&) = delete; @@ -58,9 +60,21 @@ class Device { int cuda_device() const { return device_; } + int compute_capability_major() const { + return compute_capability_major_; + } + int compute_capability_minor() const { + return compute_capability_minor_; + } + cublasLtHandle_t lt_handle() const { + return lt_; + } private: int device_; + int compute_capability_major_; + int compute_capability_minor_; + cublasLtHandle_t lt_; std::unordered_map streams_; }; diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp new file mode 100644 index 000000000..89247fd3e --- /dev/null +++ b/mlx/backend/cuda/matmul.cpp @@ -0,0 +1,474 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/matmul.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) + +void check_cublas_error(const char* name, cublasStatus_t err) { + if (err != CUBLAS_STATUS_SUCCESS) { + // TODO: Use cublasGetStatusString when it is widely available. 
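The TODO above notes that cublasGetStatusString would give a readable message instead of a raw status code. On toolkits that declare it (recent CUDA releases do, via the cuBLAS headers), the check could look roughly like the sketch below; this assumes such a toolkit and is not part of the patch.

#include <cublasLt.h>
#include <fmt/format.h>
#include <stdexcept>

// Hedged sketch: assumes cublasGetStatusString is provided by the toolkit.
inline void check_cublas_error_verbose(const char* name, cublasStatus_t err) {
  if (err != CUBLAS_STATUS_SUCCESS) {
    throw std::runtime_error(
        fmt::format("{} failed: {}.", name, cublasGetStatusString(err)));
  }
}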
+ throw std::runtime_error( + fmt::format("{} failed with code: {}.", name, static_cast(err))); + } +} + +class MatMul { + public: + MatMul( + Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride) { + heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; + + auto type = dtype_to_cuda_type(dtype); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( + &matmul_desc_, dtype_to_compute_type(dtype), type)); + int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(int32_t))); + cublasOperation_t op = CUBLAS_OP_N; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &op, + sizeof(cublasOperation_t))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, + &op, + sizeof(cublasOperation_t))); + + a_desc_ = create_matrix_layout( + type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); + b_desc_ = create_matrix_layout( + type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); + out_desc_ = create_matrix_layout( + type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); + + // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB + // for Hopper+: + // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace + uint64_t MiB = 1024 * 1024; + uint64_t workspace_size = + device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; + + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); + CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( + pref_, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(uint64_t))); + } + + MatMul( + Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + bool c_transposed, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride) + : MatMul( + device, + dtype, + a_transposed, + a_rows, + a_cols, + lda, + b_transposed, + b_rows, + b_cols, + ldb, + batch_count, + a_batch_stride, + b_batch_stride) { + auto type = dtype_to_cuda_type(dtype); + c_desc_ = create_matrix_layout( + type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride); + } + + ~MatMul() { + cublasLtMatrixLayoutDestroy(a_desc_); + cublasLtMatrixLayoutDestroy(b_desc_); + cublasLtMatrixLayoutDestroy(c_desc_); + cublasLtMatrixLayoutDestroy(out_desc_); + cublasLtMatmulDescDestroy(matmul_desc_); + } + + void run( + cu::CommandEncoder& encoder, + void* out, + void* a, + void* b, + void* c = nullptr, + float alpha = 1, + float beta = 0) { + if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { + int ret = 0; + CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( + encoder.device().lt_handle(), + matmul_desc_, + a_desc_, + b_desc_, + out_desc_, + out_desc_, + pref_, + 1, + &heuristic_, + &ret)); + if (ret == 0) { + throw std::runtime_error("Can not find algorithm for matmul."); + } + } + + array workspace( + allocator::malloc(heuristic_.workspaceSize), + {static_cast(heuristic_.workspaceSize)}, + int8); + encoder.add_temporary(workspace); + + encoder.launch_kernel([&](cudaStream_t stream) { + CHECK_CUBLAS_ERROR(cublasLtMatmul( + 
encoder.device().lt_handle(), + matmul_desc_, + &alpha, + a, + a_desc_, + b, + b_desc_, + &beta, + c ? c : out, + c ? c_desc_ : out_desc_, + out, + out_desc_, + &heuristic_.algo, + workspace.data(), + workspace.nbytes(), + stream)); + }); + } + + private: + cublasComputeType_t dtype_to_compute_type(Dtype dtype) { + switch (dtype) { + case uint8: + case uint16: + case int8: + case int16: + case int32: + return CUBLAS_COMPUTE_32I; + case float16: + case bfloat16: + return CUBLAS_COMPUTE_16F; + case float32: + return CUBLAS_COMPUTE_32F; + case float64: + case complex64: + return CUBLAS_COMPUTE_64F; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); + } + } + + cudaDataType_t dtype_to_cuda_type(Dtype dtype) { + switch (dtype) { + case uint8: + return CUDA_R_8U; + case uint16: + return CUDA_R_16U; + case int8: + return CUDA_R_8I; + case int16: + return CUDA_R_16I; + case int32: + return CUDA_R_32I; + case float16: + return CUDA_R_16F; + case bfloat16: + return CUDA_R_16BF; + case float32: + return CUDA_R_32F; + case float64: + return CUDA_R_64F; + case complex64: + return CUDA_C_32F; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); + } + } + + cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride) { + cublasLtMatrixLayout_t desc; + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); + cublasLtOrder_t order = + transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); + if (batch_count > 1) { + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_count, + sizeof(int32_t))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_stride, + sizeof(int64_t))); + } + return desc; + } + + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulHeuristicResult_t heuristic_; +}; + +} // namespace cu + +namespace { + +std::tuple +check_transpose(std::vector& copies, const Stream& s, const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (sty == 1 && stx == arr.shape(-1)) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + return std::make_tuple(false, arr.shape(-1), arr_copy); + } +} + +} // namespace + +void Matmul::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Matmul::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + assert(inputs.size() == 2); + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + // Return 0s if either input is empty. 
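check_transpose, defined just above, hands a matrix to the GEMM unchanged only when its last two strides describe a dense row-major or dense column-major layout, and copies it otherwise. The sketch below restates that decision in isolation; the enum and function names are illustrative only.

#include <cstdint>

enum class GemmLayout { RowMajor, ColMajor, NeedsCopy };

// stride_row / stride_col correspond to stx / sty in check_transpose, and
// rows / cols to shape(-2) / shape(-1).
inline GemmLayout classify_matrix(int64_t stride_row, int64_t stride_col,
                                  int64_t rows, int64_t cols) {
  if (stride_col == 1 && stride_row == cols) {
    return GemmLayout::RowMajor;  // use directly, not transposed
  }
  if (stride_row == 1 && stride_col == rows) {
    return GemmLayout::ColMajor;  // use directly, flagged as transposed
  }
  return GemmLayout::NeedsCopy;   // copy to a dense row-major buffer first
}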
+ if (a_pre.size() == 0 || b_pre.size() == 0) { + array zero(0, a_pre.dtype()); + encoder.add_temporary(zero); + fill_gpu(zero, out, s); + return; + } + + out.set_data(allocator::malloc(out.nbytes())); + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector copies; + auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); + + for (auto& temp : copies) { + encoder.add_temporary(temp); + } + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); + + auto batch_count = out.size() / (M * N); + + // Collapse batches into M if needed + if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && + b_batch_strides.back() == 0) { + M *= batch_shape.back(); + batch_count = 1; + + a_batch_strides = {0}; + b_batch_strides = {0}; + batch_shape = {1}; + } + + ///////////////////////////////////////////////////////////////////////////// + // Invoke cublasLt + + cu::MatMul matmul( + encoder.device(), + a.dtype(), + a_transposed, + M, + K, + lda, + b_transposed, + K, + N, + ldb, + batch_shape.back(), + a_batch_strides.back(), + b_batch_strides.back()); + + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); + ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { + matmul.run( + encoder, + out.data() + out.itemsize() * i * batch_shape.back() * M * N, + a.data() + a.itemsize() * a_it.loc, + b.data() + b.itemsize() * b_it.loc); + a_it.step(); + b_it.step(); + } +} + +void AddMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("AddMM::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + assert(inputs.size() == 3); + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + auto& c_pre = inputs[2]; + + out.set_data(allocator::malloc(out.nbytes())); + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector copies; + auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); + auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre); + + for (auto& temp : copies) { + encoder.add_temporary(temp); + } + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] = + collapse_batches(a, b, c); + + auto batch_count = out.size() / (M * N); + + // Collapse batches into M if needed + if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && + c_batch_strides.back() == M * c.strides()[c.ndim() - 2] && + b_batch_strides.back() == 0) { + M *= batch_shape.back(); + batch_count = 1; + + 
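Both Matmul and AddMM apply the same shortcut before building the cublasLt descriptors: when A's batch entries are packed back to back in memory and B is broadcast across the batch (AddMM also requires C's batch stride to line up), the batched GEMM collapses into a single GEMM with batch*M rows. The sketch below restates only that core condition; GemmDims and fold_batch_into_m are illustrative names, not part of the patch.

#include <cstdint>

struct GemmDims {
  int64_t M, N, K;
  int64_t batch;
};

// A densely batched, non-transposed A together with a broadcast B lets the
// batch dimension be folded into M, so one GEMM replaces the per-batch loop.
inline GemmDims fold_batch_into_m(GemmDims d, bool a_transposed,
                                  int64_t a_row_stride, int64_t a_batch_stride,
                                  int64_t b_batch_stride) {
  if (d.batch > 1 && !a_transposed && a_row_stride == d.K &&
      a_batch_stride == d.M * d.K && b_batch_stride == 0) {
    d.M *= d.batch;
    d.batch = 1;
  }
  return d;
}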
a_batch_strides = {0}; + b_batch_strides = {0}; + c_batch_strides = {0}; + batch_shape = {1}; + } + + ///////////////////////////////////////////////////////////////////////////// + // Invoke cublasLt + + cu::MatMul matmul( + encoder.device(), + a.dtype(), + a_transposed, + M, + K, + lda, + b_transposed, + K, + N, + ldb, + c_transposed, + ldc, + batch_shape.back(), + a_batch_strides.back(), + b_batch_strides.back(), + c_batch_strides.back()); + + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); + ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); + for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { + matmul.run( + encoder, + out.data() + out.itemsize() * i * batch_shape.back() * M * N, + a.data() + a.itemsize() * a_it.loc, + b.data() + b.itemsize() * b_it.loc, + c.data() + c.itemsize() * c_it.loc, + alpha_, + beta_); + a_it.step(); + b_it.step(); + c_it.step(); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 11de02c8e..fad2d76d3 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback( NO_GPU(Abs) NO_GPU(Add) -NO_GPU(AddMM) NO_GPU(ArcCos) NO_GPU(ArcCosh) NO_GPU(ArcSin) @@ -124,7 +123,6 @@ NO_GPU(LogicalOr) NO_GPU(LogAddExp) NO_GPU(LogSumExp) NO_GPU_MULTI(LUF) -NO_GPU(Matmul) NO_GPU(Maximum) NO_GPU(Minimum) NO_GPU(Multiply) diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index e0ff44200..ed96d37ea 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -6,7 +6,7 @@ #include #include "mlx/backend/common/broadcasting.h" -#include "mlx/backend/common/utils.h" +#include "mlx/backend/common/matmul.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" @@ -21,69 +21,6 @@ namespace mlx::core { namespace { -inline auto collapse_batches(const array& a, const array& b) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - if (A_bshape != B_bshape) { - std::ostringstream msg; - msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << "."; - throw std::runtime_error(msg.str()); - } - - Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; - Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; - - auto [batch_shape, batch_strides] = - collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); - - auto A_batch_stride = batch_strides[0]; - auto B_batch_stride = batch_strides[1]; - - if (batch_shape.empty()) { - batch_shape.push_back(1); - A_batch_stride.push_back(0); - B_batch_stride.push_back(0); - } - - return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride); -} - -inline auto collapse_batches(const array& a, const array& b, const array& c) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; - if (A_bshape != B_bshape || A_bshape != C_bshape) { - std::ostringstream msg; - msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << ", B " << 
c.shape() << "."; - throw std::runtime_error(msg.str()); - } - - Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; - Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; - Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; - - auto [batch_shape, batch_strides] = collapse_contiguous_dims( - A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); - - auto A_batch_stride = batch_strides[0]; - auto B_batch_stride = batch_strides[1]; - auto C_batch_stride = batch_strides[2]; - - if (batch_shape.empty()) { - batch_shape.push_back(1); - A_batch_stride.push_back(0); - B_batch_stride.push_back(0); - C_batch_stride.push_back(0); - } - - return std::make_tuple( - batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); -} - std::tuple check_transpose( std::vector& copies, const Stream& s, From 2e8cf0b4506c200a5c2d199ecbbf655fdf4c2ce2 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 6 Jun 2025 13:34:56 -0700 Subject: [PATCH 076/156] Change layernorms to two pass algorithm (#2246) --- benchmarks/python/layer_norm_bench.py | 54 +- mlx/backend/metal/kernels/layer_norm.metal | 472 ++++++++---------- mlx/backend/metal/normalization.cpp | 30 +- .../metal/scaled_dot_product_attention.cpp | 2 +- mlx/fast.cpp | 8 +- 5 files changed, 260 insertions(+), 306 deletions(-) diff --git a/benchmarks/python/layer_norm_bench.py b/benchmarks/python/layer_norm_bench.py index 69263835a..29925de0b 100644 --- a/benchmarks/python/layer_norm_bench.py +++ b/benchmarks/python/layer_norm_bench.py @@ -1,5 +1,7 @@ # Copyright © 2023-2024 Apple Inc. +from functools import partial + import mlx.core as mx import mlx.nn as nn from time_utils import time_fn @@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps): return y -def time_layer_norm(): +def time_layer_norm(N, dt): + L = 1024 f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum() f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0, 1, 2)) g2 = mx.grad(f2, argnums=(0, 1, 2)) - x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) - w = mx.random.uniform(shape=(4096,)).astype(mx.float16) - b = mx.random.uniform(shape=(4096,)).astype(mx.float16) - y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(8, L, N)).astype(dt) + w = mx.random.uniform(shape=(N,)).astype(dt) + b = mx.random.uniform(shape=(N,)).astype(dt) + y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) - def layer_norm_loop(g, x, w, b): + def layer_norm_loop(f, x, w, b): + for _ in range(32): + x = f(x, w, b) + return x + + time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b) + time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b) + + def layer_norm_grad_loop(g, x, w, b): gx, gw, gb = x, w, b for _ in range(32): gx, gw, gb = g(gx, gw, gb, y) return gx, gw, gb - time_fn(layer_norm_loop, g1, x, w, b) - time_fn(layer_norm_loop, g2, x, w, b) - time_fn(layer_norm_loop, mx.compile(g1), x, w, b) - time_fn(layer_norm_loop, mx.compile(g2), x, w, b) + time_fn(layer_norm_grad_loop, g1, x, w, b) + time_fn(layer_norm_grad_loop, g2, x, w, b) + time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b) + time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b) f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum() f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum() g1 = mx.grad(f1, argnums=(0,)) g2 = mx.grad(f2, argnums=(0,)) - x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) - w = 
mx.random.uniform(shape=(4096,)).astype(mx.float16) - b = mx.random.uniform(shape=(4096,)).astype(mx.float16) - y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16) + x = mx.random.uniform(shape=(8, L, N)).astype(dt) + w = mx.random.uniform(shape=(N,)).astype(dt) + b = mx.random.uniform(shape=(N,)).astype(dt) + y = mx.random.uniform(shape=(8, L, N)).astype(dt) mx.eval(x, w, b, y) - def layer_norm_loop(g, x): + def layer_norm_grad_x_loop(g, x): gx = x for _ in range(32): gx = g(gx, y) return gx - time_fn(layer_norm_loop, g1, x) - time_fn(layer_norm_loop, g2, x) - time_fn(layer_norm_loop, mx.compile(g1), x) - time_fn(layer_norm_loop, mx.compile(g2), x) + time_fn(layer_norm_grad_x_loop, g1, x) + time_fn(layer_norm_grad_x_loop, g2, x) + time_fn(layer_norm_grad_x_loop, mx.compile(g1), x) + time_fn(layer_norm_grad_x_loop, mx.compile(g2), x) if __name__ == "__main__": - time_layer_norm() + for dt in [mx.float32, mx.float16, mx.bfloat16]: + for n in [1024, 2048, 4096, 8192, 8192 + 1024]: + print(dt, n) + time_layer_norm(n, dt) diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 51570e48d..06b8be55f 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -9,7 +9,41 @@ using namespace metal; constant bool has_w [[function_constant(20)]]; -template +template +inline void initialize_buffer( + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + if (simd_group_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_lane_id + i] = 0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline void threadgroup_sum( + thread float* x, + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + for (int i = 0; i < N; i++) { + x[i] = simd_sum(x[i]); + } + if (simd_lane_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_group_id + i] = x[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N; i++) { + x[i] = xs[N * simd_lane_id + i]; + x[i] = simd_sum(x[i]); + } +} + +template [[kernel]] void layer_norm_single_row( const device T* x, const device T* w, @@ -23,90 +57,71 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - float thread_x[N_READS]; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + threadgroup float local_buffer[SIMD_SIZE] = {0}; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); + // Advance the pointers x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; + + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + 
i) < axis_size) { - thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + normalizer += thread_x[i] * thread_x[i]; + } + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - out[i] = - w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } -template +template [[kernel]] void layer_norm_looped( const device T* x, const device T* w, @@ -121,71 +136,52 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + threadgroup float local_buffer[SIMD_SIZE]; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + 
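The rewritten kernels compute the row mean in a first pass and the variance of the centered values in a second pass, rather than deriving the variance from E[x^2] - E[x]^2 in a single pass; the two-pass form avoids the cancellation that makes the one-pass estimate unreliable for float16/bfloat16 rows. Below is a reference CPU version of the statistic the kernels now compute; the helper name is illustrative.

#include <cmath>
#include <cstddef>

// Two-pass mean / inverse standard deviation: pass 1 accumulates the mean,
// pass 2 accumulates the sum of squared deviations from that mean.
inline float two_pass_rstd(const float* x, size_t n, float eps, float* mean_out) {
  float mean = 0.f;
  for (size_t i = 0; i < n; ++i) {
    mean += x[i];
  }
  mean /= static_cast<float>(n);

  float var = 0.f;
  for (size_t i = 0; i < n; ++i) {
    float t = x[i] - mean;
    var += t * t;
  }
  *mean_out = mean;
  return 1.f / std::sqrt(var / static_cast<float>(n) + eps);
}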
mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs out += gid * size_t(axis_size) + lid * N_READS; @@ -208,7 +204,7 @@ template } } -template +template [[kernel]] void vjp_layer_norm_single_row( const device T* x, const device T* w, @@ -222,133 +218,96 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the computation and accumulators - float thread_x[N_READS]; - float thread_w[N_READS]; - float thread_g[N_READS]; - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + float thread_w[N_READS] = {0}; + float thread_g[N_READS] = {0}; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - constexpr int SIMD_SIZE = 32; + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; - - if (lid * N_READS + N_READS <= axis_size) { + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; + thread_w[i] = w[i * w_stride]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - 
thread_w[i] = w[i * w_stride]; - thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; + thread_g[i] = g[i]; + thread_w[i] = w[i * w_stride]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + factors[meanwg] += thread_w[i] * thread_g[i]; + factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; + factors[normalizer2] += thread_x[i] * thread_x[i]; + } + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); - if 
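For reference, the meanwg, meanwgxc and normalizer2 factors accumulated above combine into the standard layer-norm input gradient, gx = rstd * (w*g - mean(w*g) - x_hat * mean(w*g*x_hat)) with x_hat = (x - mean) * rstd, while the per-row weight-gradient contribution is simply g * x_hat. The CPU sketch below restates that formula; it is illustrative only and not part of the patch.

#include <cmath>
#include <cstddef>
#include <vector>

// Plain CPU restatement of the input gradient written by the VJP kernels.
inline std::vector<float> layer_norm_vjp_x(const std::vector<float>& x,
                                           const std::vector<float>& w,
                                           const std::vector<float>& g,
                                           float eps) {
  const size_t n = x.size();
  float mean = 0.f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);
  float var = 0.f;
  for (float v : x) var += (v - mean) * (v - mean);
  const float rstd = 1.f / std::sqrt(var / static_cast<float>(n) + eps);

  // Analogues of the meanwg / meanwgxc reductions; taking them over x_hat
  // instead of the centered x simply folds one factor of rstd in early.
  float mean_wg = 0.f, mean_wgxh = 0.f;
  for (size_t i = 0; i < n; ++i) {
    const float xh = (x[i] - mean) * rstd;
    mean_wg += w[i] * g[i];
    mean_wgxh += w[i] * g[i] * xh;
  }
  mean_wg /= static_cast<float>(n);
  mean_wgxh /= static_cast<float>(n);

  std::vector<float> gx(n);
  for (size_t i = 0; i < n; ++i) {
    const float xh = (x[i] - mean) * rstd;
    gx[i] = rstd * (w[i] * g[i] - mean_wg - xh * mean_wgxh);
  }
  return gx;
}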
(has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } } -template +template [[kernel]] void vjp_layer_norm_looped( const device T* x, const device T* w, @@ -363,102 +322,69 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the accumulators - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; - - constexpr int SIMD_SIZE = 32; - - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (uint r = 0; r < axis_size; r += lsize * 
N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + float wi = w[(i + r) * w_stride]; + float gi = g[i + r]; + float wg = wi * gi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; @@ -470,7 +396,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } @@ -482,7 +409,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c0901ccec..c53289828 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -255,12 +255,13 @@ void LayerNorm::eval_gpu( auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; - const int simd_size = 32; - const int n_reads = RMS_N_READS; - const int looped_limit = RMS_LOOPED_LIMIT; + int simd_size = 32; + int n_reads = 8; + int looped_limit = 6656; std::string op_name = "layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; + n_reads = 4; } op_name += type_to_name(out); auto& compute_encoder = d.get_command_encoder(s.index); @@ -272,7 +273,13 @@ void LayerNorm::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); @@ -372,12 +379,13 @@ void LayerNormVJP::eval_gpu( g, gb, "sum", plan, {0}, compute_encoder, d, s); } - const int simd_size = 32; - const int n_reads = RMS_N_READS; - const int looped_limit = RMS_LOOPED_LIMIT; + int simd_size = 32; + int n_reads = 
8; + int looped_limit = 8192; std::string op_name = "vjp_layer_norm"; if (axis_size > looped_limit) { op_name += "_looped"; + n_reads = 4; } op_name += type_to_name(gx); @@ -394,7 +402,13 @@ void LayerNormVJP::eval_gpu( size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; size_t threadgroup_size = simd_size * simds_needed; - assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); + if (threadgroup_size > kernel->maxTotalThreadsPerThreadgroup()) { + std::ostringstream msg; + msg << "[vjp_layer_norm] Threadgroup size " << threadgroup_size + << " is larger than the maximum allowed threadgroup size " + << kernel->maxTotalThreadsPerThreadgroup(); + throw std::runtime_error(msg.str()); + } size_t n_threads = n_rows * threadgroup_size; grid_dims = MTL::Size(n_threads, 1, 1); group_dims = MTL::Size(threadgroup_size, 1, 1); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index aad1a0018..096d6b906 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -369,7 +369,7 @@ bool ScaledDotProductAttention::use_fallback( const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); - const bool supports_sdpa_full = + const bool supports_sdpa_full = query_sequence_length > 8 && sdpa_full_supported_mask && sdpa_full_supported_head_dim; const bool supports_sdpa_vector = (query_sequence_length <= 8) && diff --git a/mlx/fast.cpp b/mlx/fast.cpp index eab22f14d..657c0aba8 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -231,13 +231,11 @@ array layer_norm( const std::vector& inputs) { auto x = astype(inputs[0], float32, s); - // Should I not be smart here and leave the double mean to simplify()? 
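For context on the hunk below: the CPU/fallback ``layer_norm`` previously computed the variance as E[x^2] - E[x]^2, which cancels catastrophically in float32 once the mean is large, and the change switches to the centered form E[(x - mu)^2]. A minimal sketch, not from the patch, of why the centered form is preferred; the offset of 4096 and the shapes are illustrative, chosen only to make the cancellation visible:

.. code-block:: python

    import mlx.core as mx

    # Uniform noise around a large offset: the true variance is ~1/12.
    x = 4096.0 + mx.random.uniform(shape=(8, 512))

    mu = mx.mean(x, axis=-1, keepdims=True)

    # Old form: E[x^2] - E[x]^2, dominated by float32 rounding of the ~1.7e7 terms
    v_old = mx.mean(mx.square(x), axis=-1, keepdims=True) - mx.square(mu)

    # New form: E[(x - mu)^2], stays close to the true variance
    v_new = mx.mean(mx.square(x - mu), axis=-1, keepdims=True)

    print(v_old)  # noisy, can even come out negative
    print(v_new)  # ~0.0833
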
auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s); - auto mu2 = square(mu, s); - auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s); - auto v = subtract(x2, mu2, s); + auto xc = subtract(x, mu, s); + auto v = mean(square(xc, s), /* axis= */ -1, /* keepdims= */ true, s); - x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s)); + x = multiply(xc, rsqrt(add(v, array(eps, float32), s), s)); x = astype(x, out_type, s); // If the LN is affine then transform x according to the weight and bias From 1ca616844bc7434cb0186a302ed1afc6167970b3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 20:08:15 -0700 Subject: [PATCH 077/156] Fix unintuitive metal kernel caching (#2242) * Fix unintuitive metal kernel caching * alternative solution --- docs/src/dev/custom_metal_kernels.rst | 498 +++++++++--------- docs/src/dev/extensions.rst | 6 +- examples/extensions/axpby/axpby.cpp | 6 +- mlx/backend/metal/conv.cpp | 2 +- mlx/backend/metal/custom_kernel.cpp | 344 +++++++++++- mlx/backend/metal/device.cpp | 66 ++- mlx/backend/metal/device.h | 16 +- mlx/backend/metal/nojit_kernels.cpp | 8 +- mlx/backend/metal/normalization.cpp | 4 +- .../metal/scaled_dot_product_attention.cpp | 6 +- mlx/backend/no_gpu/primitives.cpp | 13 + mlx/fast.cpp | 302 ----------- python/tests/test_fast.py | 35 ++ 13 files changed, 713 insertions(+), 593 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 3e92f2814..873b1e544 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs. Simple Example -------------- +.. currentmodule:: mlx.core + Let's write a custom kernel that computes ``exp`` elementwise: .. code-block:: python - def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - T tmp = inp[elem]; - out[elem] = metal::exp(tmp); - """ + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ - kernel = mx.fast.metal_kernel( - name="myexp", - input_names=["inp"], - output_names=["out"], - source=source, - ) + kernel = mx.fast.metal_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source, + ) + + def exp_elementwise(a: mx.array): outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise: b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) +Every time you make a kernel, a new Metal library is created and possibly +JIT compiled. To reduce the overhead from that, build the kernel once with +:func:`fast.metal_kernel` and then use it many times. + .. note:: - We are only required to pass the body of the Metal kernel in ``source``. + Only pass the body of the Metal kernel in ``source``. The function + signature is generated automatically. The full function signature will be generated using: @@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; -Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function. -This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. 
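The paragraph added above recommends constructing the kernel once with ``mx.fast.metal_kernel`` and reusing it across calls. A minimal sketch of that usage pattern; the ``plus_one`` kernel, the shapes, and the timing loop are illustrative and not part of the patch:

.. code-block:: python

    import time
    import mlx.core as mx

    source = """
        uint elem = thread_position_in_grid.x;
        out[elem] = inp[elem] + 1.0f;
    """

    # Built once; every call below reuses the same compiled Metal library.
    kernel = mx.fast.metal_kernel(
        name="plus_one",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    def plus_one(a):
        return kernel(
            inputs=[a],
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            stream=mx.gpu,
        )[0]

    a = mx.random.normal(shape=(4096,))
    mx.eval(plus_one(a))      # the first call may pay a one-time JIT cost
    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(plus_one(a))  # subsequent calls reuse the cached kernel
    print(f"100 calls: {time.perf_counter() - tic:.4f} s")
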
-For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. +Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads +`_ +function. This means we will launch ``mx.prod(grid)`` threads, subdivided into +``threadgroup`` size threadgroups. For optimal performance, each thread group +dimension should be less than or equal to the corresponding grid dimension. -Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. +Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the +generated code for debugging purposes. Using Shape/Strides ------------------- -``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. -This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. -Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims -when indexing. +:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which +is ``True`` by default. This will copy the array inputs if needed +before the kernel is launched to ensure that the memory layout is row +contiguous. Generally this makes writing the kernel easier, since we don't +have to worry about gaps or the ordering of the dims when indexing. -If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each -input array ``a`` if any are present in ``source``. -We can then use MLX's built in indexing utils to fetch the right elements for each thread. +If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes +``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are +present in ``source``. We can then use MLX's built in indexing utils to fetch +the right elements for each thread. -Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: +Let's convert ``myexp`` above to support arbitrarily strided arrays without +relying on a copy from ``ensure_row_contiguous``: .. code-block:: python + + source = """ + uint elem = thread_position_in_grid.x; + // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + // Output arrays are always row contiguous + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp_strided", + input_names=["inp"], + output_names=["out"], + source=source + ) def exp_elementwise(a: mx.array): - source = """ - uint elem = thread_position_in_grid.x; - // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included - uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); - T tmp = inp[loc]; - // Output arrays are always row contiguous - out[elem] = metal::exp(tmp); - """ - - kernel = mx.fast.metal_kernel( - name="myexp_strided", - input_names=["inp"], - output_names=["out"], - source=source - ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], @@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops: .. 
code-block:: python - def grid_sample_ref(x, grid): - N, H_in, W_in, _ = x.shape - ix = ((grid[..., 0] + 1) * W_in - 1) / 2 - iy = ((grid[..., 1] + 1) * H_in - 1) / 2 + def grid_sample_ref(x, grid): + N, H_in, W_in, _ = x.shape + ix = ((grid[..., 0] + 1) * W_in - 1) / 2 + iy = ((grid[..., 1] + 1) * H_in - 1) / 2 - ix_nw = mx.floor(ix).astype(mx.int32) - iy_nw = mx.floor(iy).astype(mx.int32) + ix_nw = mx.floor(ix).astype(mx.int32) + iy_nw = mx.floor(iy).astype(mx.int32) - ix_ne = ix_nw + 1 - iy_ne = iy_nw + ix_ne = ix_nw + 1 + iy_ne = iy_nw - ix_sw = ix_nw - iy_sw = iy_nw + 1 + ix_sw = ix_nw + iy_sw = iy_nw + 1 - ix_se = ix_nw + 1 - iy_se = iy_nw + 1 + ix_se = ix_nw + 1 + iy_se = iy_nw + 1 - nw = (ix_se - ix) * (iy_se - iy) - ne = (ix - ix_sw) * (iy_sw - iy) - sw = (ix_ne - ix) * (iy - iy_ne) - se = (ix - ix_nw) * (iy - iy_nw) + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) - I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] - I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] - I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] - I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] + I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] + I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] + I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] + I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] - mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) - mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) - mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) - mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) + mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) + mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) + mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) + mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) - I_nw *= mask_nw[..., None] - I_ne *= mask_ne[..., None] - I_sw *= mask_sw[..., None] - I_se *= mask_se[..., None] + I_nw *= mask_nw[..., None] + I_ne *= mask_ne[..., None] + I_sw *= mask_sw[..., None] + I_se *= mask_se[..., None] - output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se + output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se - return output + return output -Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` +Now let's use :func:`custom_function` together with :func:`fast.metal_kernel` to write a fast GPU kernel for both the forward and backward passes. First we'll implement the forward pass as a fused kernel: .. code-block:: python - @mx.custom_function - def grid_sample(x, grid): + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + int gH = grid_shape[1]; + int gW = grid_shape[2]; - assert x.ndim == 4, "`x` must be 4D." - assert grid.ndim == 4, "`grid` must be 4D." + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - B, _, _, C = x.shape - _, gN, gM, D = grid.shape - out_shape = (B, gN, gM, C) + uint grid_idx = elem / C * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - assert D == 2, "Last dim of `grid` must be size 2." 
+ int ix_nw = floor(ix); + int iy_nw = floor(iy); - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - uint grid_idx = elem / C * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int batch_idx = elem / C / gH / gW * b_stride; + int channel_idx = elem % C; + int base_idx = batch_idx + channel_idx; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; + T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; + T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; + T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; + I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; + I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; + I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; + """ - int batch_idx = elem / C / gH / gW * b_stride; - int channel_idx = elem % C; - int base_idx = batch_idx + channel_idx; + kernel = mx.fast.metal_kernel( + name="grid_sample", + input_names=["x", "grid"], + output_names=["out"], + source=source, + ) - T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; - T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; - T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride]; - T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride]; + @mx.custom_function + def grid_sample(x, grid): - I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; - I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; - I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0; - I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0; + assert x.ndim == 4, "`x` must be 4D." + assert grid.ndim == 4, "`grid` must be 4D." - out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; - """ - kernel = mx.fast.metal_kernel( - name="grid_sample", - input_names=["x", "grid"], - output_names=["out"], - source=source, - ) - outputs = kernel( - inputs=[x, grid], - template=[("T", x.dtype)], - output_shapes=[out_shape], - output_dtypes=[x.dtype], - grid=(np.prod(out_shape), 1, 1), - threadgroup=(256, 1, 1), - ) - return outputs[0] + B, _, _, C = x.shape + _, gN, gM, D = grid.shape + out_shape = (B, gN, gM, C) + + assert D == 2, "Last dim of `grid` must be size 2." 
+ + outputs = kernel( + inputs=[x, grid], + template=[("T", x.dtype)], + output_shapes=[out_shape], + output_dtypes=[x.dtype], + grid=(np.prod(out_shape), 1, 1), + threadgroup=(256, 1, 1), + ) + return outputs[0] For a reasonably sized input such as: .. code-block:: python - x.shape = (8, 1024, 1024, 64) - grid.shape = (8, 256, 256, 2) + x.shape = (8, 1024, 1024, 64) + grid.shape = (8, 256, 256, 2) On an M1 Max, we see a big performance improvement: @@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement: Grid Sample VJP --------------- -Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define -its custom vjp transform so MLX can differentiate it. +Since we decorated ``grid_sample`` with :func:`custom_function`, we can now +define its custom vjp transform so MLX can differentiate it. The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so -requires a few extra ``mx.fast.metal_kernel`` features: +requires a few extra :func:`fast.metal_kernel` features: * ``init_value=0`` Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. @@ -299,128 +316,129 @@ We can then implement the backwards pass as follows: .. code-block:: python - @grid_sample.vjp - def grid_sample_vjp(primals, cotangent, _): - x, grid = primals - B, _, _, C = x.shape - _, gN, gM, D = grid.shape + source = """ + uint elem = thread_position_in_grid.x; + int H = x_shape[1]; + int W = x_shape[2]; + int C = x_shape[3]; + // Pad C to the nearest larger simdgroup size multiple + int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; - assert D == 2, "Last dim of `grid` must be size 2." + int gH = grid_shape[1]; + int gW = grid_shape[2]; - source = """ - uint elem = thread_position_in_grid.x; - int H = x_shape[1]; - int W = x_shape[2]; - int C = x_shape[3]; - // Pad C to the nearest larger simdgroup size multiple - int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup; + int w_stride = C; + int h_stride = W * w_stride; + int b_stride = H * h_stride; - int gH = grid_shape[1]; - int gW = grid_shape[2]; + uint grid_idx = elem / C_padded * 2; + float ix = ((grid[grid_idx] + 1) * W - 1) / 2; + float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; - int w_stride = C; - int h_stride = W * w_stride; - int b_stride = H * h_stride; + int ix_nw = floor(ix); + int iy_nw = floor(iy); - uint grid_idx = elem / C_padded * 2; - float ix = ((grid[grid_idx] + 1) * W - 1) / 2; - float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2; + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; - int ix_nw = floor(ix); - int iy_nw = floor(iy); + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; + int batch_idx = elem / C_padded / gH / gW * b_stride; + int channel_idx = elem % C_padded; + int base_idx = batch_idx + channel_idx; - T nw = (ix_se - ix) * (iy_se - iy); - T ne = (ix - ix_sw) * (iy_sw - iy); - T sw = (ix_ne - ix) * (iy - iy_ne); - T se = (ix - ix_nw) * (iy - iy_nw); + T gix = T(0); + T giy = T(0); + if (channel_idx < C) { + int cot_index = elem / C_padded * C + channel_idx; + T cot = cotangent[cot_index]; + if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 
0 && ix_nw <= W - 1) { + int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); - int batch_idx = elem / C_padded / gH / gW * b_stride; - int channel_idx = elem % C_padded; - int base_idx = batch_idx + channel_idx; + T I_nw = x[offset]; + gix -= I_nw * (iy_se - iy) * cot; + giy -= I_nw * (ix_se - ix) * cot; + } + if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { + int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); - T gix = T(0); - T giy = T(0); - if (channel_idx < C) { - int cot_index = elem / C_padded * C + channel_idx; - T cot = cotangent[cot_index]; - if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { - int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed); + T I_ne = x[offset]; + gix += I_ne * (iy_sw - iy) * cot; + giy -= I_ne * (ix - ix_sw) * cot; + } + if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { + int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); - T I_nw = x[offset]; - gix -= I_nw * (iy_se - iy) * cot; - giy -= I_nw * (ix_se - ix) * cot; - } - if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { - int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); + T I_sw = x[offset]; + gix -= I_sw * (iy - iy_ne) * cot; + giy += I_sw * (ix_ne - ix) * cot; + } + if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { + int offset = base_idx + iy_se * h_stride + ix_se * w_stride; + atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); - T I_ne = x[offset]; - gix += I_ne * (iy_sw - iy) * cot; - giy -= I_ne * (ix - ix_sw) * cot; - } - if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { - int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed); + T I_se = x[offset]; + gix += I_se * (iy - iy_nw) * cot; + giy += I_se * (ix - ix_nw) * cot; + } + } - T I_sw = x[offset]; - gix -= I_sw * (iy - iy_ne) * cot; - giy += I_sw * (ix_ne - ix) * cot; - } - if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) { - int offset = base_idx + iy_se * h_stride + ix_se * w_stride; - atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed); + T gix_mult = W / 2; + T giy_mult = H / 2; - T I_se = x[offset]; - gix += I_se * (iy - iy_nw) * cot; - giy += I_se * (ix - ix_nw) * cot; - } - } + // Reduce across each simdgroup first. + // This is much faster than relying purely on atomics. + gix = simd_sum(gix); + giy = simd_sum(giy); - T gix_mult = W / 2; - T giy_mult = H / 2; + if (thread_index_in_simdgroup == 0) { + atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); + atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); + } + """ + kernel = mx.fast.metal_kernel( + name="grid_sample_grad", + input_names=["x", "grid", "cotangent"], + output_names=["x_grad", "grid_grad"], + source=source, + atomic_outputs=True, + ) - // Reduce across each simdgroup first. - // This is much faster than relying purely on atomics. 
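As an aside on the comment above: reducing with ``simd_sum`` first and issuing one atomic add per simdgroup, rather than one per thread, is a pattern worth knowing on its own. A stand-alone sketch, not part of the docs, that sums a vector this way; the kernel name and shapes are illustrative:

.. code-block:: python

    import mlx.core as mx

    source = """
        uint elem = thread_position_in_grid.x;
        float v = elem < (uint)inp_shape[0] ? (float)inp[elem] : 0.0f;
        // One 32-wide reduction in registers...
        v = simd_sum(v);
        // ...then a single atomic add per simdgroup instead of per thread.
        if (thread_index_in_simdgroup == 0) {
            atomic_fetch_add_explicit(&out[0], v, memory_order_relaxed);
        }
    """

    kernel = mx.fast.metal_kernel(
        name="simd_then_atomic_sum",
        input_names=["inp"],
        output_names=["out"],
        source=source,
        atomic_outputs=True,
    )

    a = mx.random.uniform(shape=(4096,))
    out = kernel(
        inputs=[a],
        output_shapes=[(1,)],
        output_dtypes=[mx.float32],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        init_value=0,
        stream=mx.gpu,
    )[0]
    print(out.item(), a.sum().item())  # approximately equal
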
- gix = simd_sum(gix); - giy = simd_sum(giy); + @grid_sample.vjp + def grid_sample_vjp(primals, cotangent, _): + x, grid = primals + B, _, _, C = x.shape + _, gN, gM, D = grid.shape - if (thread_index_in_simdgroup == 0) { - atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed); - atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); - } - """ - kernel = mx.fast.metal_kernel( - name="grid_sample_grad", - input_names=["x", "grid", "cotangent"], - output_names=["x_grad", "grid_grad"], - source=source, - atomic_outputs=True, - ) - # pad the output channels to simd group size - # so that our `simd_sum`s don't overlap. - simdgroup_size = 32 - C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size - grid_size = B * gN * gM * C_padded - outputs = kernel( - inputs=[x, grid, cotangent], - template=[("T", x.dtype)], - output_shapes=[x.shape, grid.shape], - output_dtypes=[x.dtype, x.dtype], - grid=(grid_size, 1, 1), - threadgroup=(256, 1, 1), - init_value=0, - ) - return outputs[0], outputs[1] + assert D == 2, "Last dim of `grid` must be size 2." + + # pad the output channels to simd group size + # so that our `simd_sum`s don't overlap. + simdgroup_size = 32 + C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size + grid_size = B * gN * gM * C_padded + outputs = kernel( + inputs=[x, grid, cotangent], + template=[("T", x.dtype)], + output_shapes=[x.shape, grid.shape], + output_dtypes=[x.dtype, x.dtype], + grid=(grid_size, 1, 1), + threadgroup=(256, 1, 1), + init_value=0, + ) + return outputs[0], outputs[1] There's an even larger speed up for the vjp: diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 2aef28f99..03f1c2163 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -397,11 +397,11 @@ below. std::ostringstream kname; kname << "axpby_" << "general_" << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 291246617..9ba933483 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -172,11 +172,11 @@ void Axpby::eval_gpu( kname << (contiguous_kernel ? 
"contiguous_" : "general_"); kname << type_to_name(out); - // Make sure the metal library is available - d.register_library("mlx_ext"); + // Load the metal library + auto lib = d.get_library("mlx_ext"); // Make a kernel from this metal library - auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto kernel = d.get_kernel(kname.str(), lib); // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 6b4b70d47..593b79384 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index ea4f258cc..161503a0e 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,12 +1,326 @@ // Copyright © 2024 Apple Inc. +#include +#include + +#include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" +#include "mlx/utils.h" namespace mlx::core::fast { +struct CustomKernelCache { + std::unordered_map libraries; +}; + +static CustomKernelCache& cache() { + static CustomKernelCache cache_; + return cache_; +}; + +std::string write_signature( + std::string func_name, + const std::string& header, + const std::string& source, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + const std::vector& output_dtypes, + const std::vector>& template_args, + const std::vector& attributes, + const std::vector& shape_infos, + bool atomic_outputs) { + std::string kernel_source; + kernel_source.reserve(header.size() + source.size() + 16384); + kernel_source += header; + // Auto-generate a function signature based on `template_args` + // and the dtype/shape of the arrays passed as `inputs`. + if (!template_args.empty()) { + kernel_source += "template <"; + int i = 0; + for (const auto& [name, arg] : template_args) { + std::string param_type; + if (std::holds_alternative(arg)) { + param_type = "int"; + } else if (std::holds_alternative(arg)) { + param_type = "bool"; + } else if (std::holds_alternative(arg)) { + param_type = "typename"; + } + if (i > 0) { + kernel_source += ", "; + } + kernel_source += param_type; + kernel_source += " "; + kernel_source += name; + i++; + } + kernel_source += ">\n"; + } + kernel_source += "[[kernel]] void "; + kernel_source += func_name; + kernel_source += "(\n"; + + int index = 0; + constexpr int max_constant_array_size = 8; + // Add inputs + for (int i = 0; i < inputs.size(); ++i) { + const auto& name = input_names[i]; + const auto& arr = inputs[i]; + auto dtype = get_type_string(arr.dtype()); + std::string location = + arr.size() < max_constant_array_size ? "constant" : "device"; + std::string ref = arr.ndim() == 0 ? 
"&" : "*"; + kernel_source += " const "; + kernel_source += location; + kernel_source += " "; + kernel_source += dtype; + kernel_source += ref; + kernel_source += " "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]],\n"; + index++; + // Add input shape, strides and ndim if present in the source + if (arr.ndim() > 0) { + if (shape_infos[i].shape) { + kernel_source += + (" const constant int* " + name + "_shape [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + if (shape_infos[i].strides) { + kernel_source += + (" const constant int64_t* " + name + "_strides [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + if (shape_infos[i].ndim) { + kernel_source += + (" const constant int& " + name + "_ndim [[buffer(" + + std::to_string(index) + ")]],\n"); + index++; + } + } + } + // Add outputs + for (int i = 0; i < output_names.size(); ++i) { + const auto& name = output_names[i]; + const auto& dtype = output_dtypes[i]; + kernel_source += " device "; + auto type_string = get_type_string(dtype); + if (atomic_outputs) { + kernel_source += "atomic<"; + } + kernel_source += type_string; + if (atomic_outputs) { + kernel_source += ">"; + } + kernel_source += "* "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]]"; + if (index < inputs.size() + output_names.size() - 1 || + attributes.size() > 0) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + index++; + } + + index = 0; + for (const auto& attr : attributes) { + kernel_source += attr; + if (index < attributes.size() - 1) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + index++; + } + kernel_source += source; + kernel_source += "\n}\n"; + return kernel_source; +} + +std::string write_template( + const std::vector>& template_args) { + std::ostringstream template_def; + template_def << "<"; + int i = 0; + for (const auto& [name, arg] : template_args) { + if (i > 0) { + template_def << ", "; + } + if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << get_type_string(std::get(arg)); + } + i++; + } + template_def << ">"; + return template_def.str(); +} + +MetalKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header /* = "" */, + bool ensure_row_contiguous /* = true */, + bool atomic_outputs /* = false */) { + if (output_names.empty()) { + throw std::invalid_argument( + "[metal_kernel] Must specify at least one output."); + } + std::vector shape_infos; + for (auto& n : input_names) { + CustomKernelShapeInfo shape_info; + shape_info.shape = source.find(n + "_shape") != std::string::npos; + shape_info.strides = source.find(n + "_strides") != std::string::npos; + shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + const std::vector> metal_attributes = { + {"dispatch_quadgroups_per_threadgroup", "uint"}, + {"dispatch_simdgroups_per_threadgroup", "uint"}, + {"dispatch_threads_per_threadgroup", "uint3"}, + {"grid_origin", "uint3"}, + {"grid_size", "uint3"}, + {"quadgroup_index_in_threadgroup", "uint"}, + {"quadgroups_per_threadgroup", "uint"}, + {"simdgroup_index_in_threadgroup", "uint"}, + 
{"simdgroups_per_threadgroup", "uint"}, + {"thread_execution_width", "uint"}, + {"thread_index_in_quadgroup", "uint"}, + {"thread_index_in_simdgroup", "uint"}, + {"thread_index_in_threadgroup", "uint"}, + {"thread_position_in_grid", "uint3"}, + {"thread_position_in_threadgroup", "uint3"}, + {"threadgroup_position_in_grid", "uint3"}, + {"threadgroups_per_grid", "uint3"}, + {"threads_per_grid", "uint3"}, + {"threads_per_simdgroup", "uint"}, + {"threads_per_threadgroup", "uint3"}, + }; + + std::vector attributes; + for (const auto& [attr, dtype] : metal_attributes) { + if (source.find(attr) != std::string::npos) { + attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); + } + } + + return [=, + shape_infos = std::move(shape_infos), + attributes = std::move(attributes)]( + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != input_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." 
<< std::endl; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument("[metal_kernel] Only supports the GPU."); + } + + std::string kernel_name = "custom_kernel_" + name; + std::string template_def = ""; + if (!template_args.empty()) { + std::regex disallowed_chars("\\<|\\>|(, )"); + template_def = write_template(template_args); + auto template_hash = + std::regex_replace(template_def, disallowed_chars, "_"); + template_hash.pop_back(); + kernel_name += "_"; + kernel_name += template_hash; + } + + std::string kernel_source = write_signature( + kernel_name, + header, + source, + input_names, + inputs, + output_names, + output_dtypes, + template_args, + attributes, + shape_infos, + atomic_outputs); + + if (!template_args.empty()) { + template_def = kernel_name + template_def; + kernel_source += "\ntemplate [[host_name(\""; + kernel_source += kernel_name; + kernel_source += "\")]] [[kernel]] decltype("; + kernel_source += template_def; + kernel_source += ") "; + kernel_source += template_def; + kernel_source += ";\n"; + } + + if (verbose) { + std::cout << "Generated source code for `" << name << "`:" << std::endl + << "```" << std::endl + << kernel_source << std::endl + << "```" << std::endl; + } + + return array::make_arrays( + std::move(output_shapes), + std::move(output_dtypes), + std::make_shared( + s, + std::move(kernel_name), + std::move(kernel_source), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value), + std::move(inputs)); + }; +} + void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -39,9 +353,23 @@ void CustomKernel::eval_gpu( } auto& d = metal::device(s.device); - const auto& lib_name = name_; - auto lib = - d.get_library(lib_name, [this] { return metal::utils() + source_; }); + + { + // Clear kernels from the device library cache if needed + auto& kernel_cache = cache(); + if (auto it = kernel_cache.libraries.find(name_); + it != kernel_cache.libraries.end()) { + if (it->second != source_) { + auto& d = metal::device(s.device); + d.clear_library(name_); + it->second = source_; + } + } else { + kernel_cache.libraries.emplace(name_, source_); + } + } + + auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -73,6 +401,16 @@ void CustomKernel::eval_gpu( } const auto [tx, ty, tz] = threadgroup_; + auto tg_size = tx * ty * tz; + auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup(); + if (tg_size > max_tg_size) { + std::ostringstream msg; + msg << "Thread group size (" << tg_size << ") is greater than " + << " the maximum allowed threads per threadgroup (" << max_tg_size + << ")."; + throw std::invalid_argument(msg.str()); + } + const auto [gx, gy, gz] = grid_; MTL::Size group_dims = MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index ebc3cc77f..425274361 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -295,7 +295,7 @@ void CommandEncoder::barrier() { Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); - library_map_ = {{"mlx", load_default_library(device_)}}; + default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); auto arch = arch_.back(); switch (arch) { @@ 
-326,11 +326,11 @@ Device::Device() { Device::~Device() { auto pool = new_scoped_memory_pool(); - for (auto& k : kernel_map_) { - k.second->release(); - } - for (auto& l : library_map_) { - l.second->release(); + for (auto& [l, kernel_map] : library_kernels_) { + l->release(); + for (auto& [_, k] : kernel_map) { + k->release(); + } } stream_map_.clear(); device_->release(); @@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) { return *stream.encoder; } -void Device::register_library( - const std::string& lib_name, - const std::string& lib_path) { - if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - auto new_lib = load_library(device_, lib_name, lib_path.c_str()); - library_map_.insert({lib_name, new_lib}); +MTL::Library* Device::get_library( + const std::string& name, + const std::string& path /* = "" */) { + { + std::shared_lock rlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + return it->second; + } } + + std::unique_lock wlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + return it->second; + } + + auto new_lib = load_library(device_, name, path.c_str()); + library_map_.insert({name, new_lib}); + return new_lib; } MTL::Library* Device::build_library_(const std::string& source_string) { @@ -649,6 +660,19 @@ MTL::Library* Device::get_library( return mtl_lib; } +void Device::clear_library(const std::string& name) { + std::unique_lock wlock(library_mtx_); + if (auto it = library_map_.find(name); it != library_map_.end()) { + auto kernel_map_it = library_kernels_.find(it->second); + for (auto& [_, kernel] : kernel_map_it->second) { + kernel->release(); + } + library_kernels_.erase(kernel_map_it); + it->second->release(); + library_map_.erase(it); + } +} + MTL::LinkedFunctions* Device::get_linked_functions_( const std::vector& funcs) { if (funcs.empty()) { @@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_( std::unique_lock wlock(kernel_mtx_); // Try loading again to avoid loading twice + auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) { return it->second; } @@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel( std::shared_lock lock(kernel_mtx_); // Look for cached kernel + auto& kernel_map_ = library_kernels_[mtl_lib]; if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { return it->second; } @@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel( const std::string& base_name, - const std::string& lib_name /* = "mlx" */, const std::string& hash_name /* = "" */, const MTLFCList& func_consts /* = {} */, const std::vector& linked_functions /* = {} */) { - const auto& kname = hash_name.size() == 0 ? 
base_name : hash_name; - { - // Multiple readers allowed - std::shared_lock lock(kernel_mtx_); - - // Look for cached kernel - if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { - return it->second; - } - } - // Search for cached metal lib - MTL::Library* mtl_lib = get_library_(lib_name); - return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions); + return get_kernel( + base_name, default_library_, hash_name, func_consts, linked_functions); } void Device::set_residency_set(const MTL::ResidencySet* residency_set) { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 660ba65e2..5bfcc6649 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -187,14 +187,16 @@ class Device { CommandEncoder& get_command_encoder(int index); void end_encoding(int index); - void register_library( - const std::string& lib_name, - const std::string& lib_path = ""); + MTL::Library* get_library( + const std::string& name, + const std::string& path = ""); MTL::Library* get_library( const std::string& name, const std::function& builder); + void clear_library(const std::string& name); + MTL::ComputePipelineState* get_kernel( const std::string& base_name, MTL::Library* mtl_lib, @@ -204,7 +206,6 @@ class Device { MTL::ComputePipelineState* get_kernel( const std::string& base_name, - const std::string& lib_name = "mlx", const std::string& hash_name = "", const MTLFCList& func_consts = {}, const std::vector& linked_functions = {}); @@ -258,10 +259,13 @@ class Device { std::unordered_map stream_map_; std::shared_mutex kernel_mtx_; - std::unordered_map kernel_map_; - std::shared_mutex library_mtx_; std::unordered_map library_map_; + MTL::Library* default_library_; + std::unordered_map< + MTL::Library*, + std::unordered_map> + library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; int max_ops_per_buffer_; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 8da147971..b1478d33b 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel( int, int, int) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( @@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int, int, bool) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_gemv_masked_kernel( @@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel( const std::string& hash_name, const metal::MTLFCList& func_consts, const std::string&) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_quantized_kernel( @@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( int, int, bool) { - return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); + return d.get_kernel(kernel_name, hash_name, func_consts); } } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c53289828..d570bf3c0 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu( auto& compute_encoder = d.get_command_encoder(s.index); { - 
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -395,7 +395,7 @@ void LayerNormVJP::eval_gpu( }; { - auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 096d6b906..eef279d1d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); const int NQ = (qL + bq - 1) / bq; @@ -180,7 +180,7 @@ void sdpa_vector( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -281,7 +281,7 @@ void sdpa_vector_2pass( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 409aa2c89..849cbf83e 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -2,6 +2,7 @@ #include "mlx/primitives.h" #include "mlx/distributed/primitives.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ @@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) + +MetalKernelFunction metal_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool ensure_row_contiguous, + bool atomic_outputs) { + throw std::runtime_error("[metal_kernel] No GPU back-end."); +} + } // namespace fast namespace distributed { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 657c0aba8..210c7f729 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,10 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include -#include -#include "mlx/backend/common/compiled.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" @@ -1027,303 +1024,4 @@ std::vector AffineQuantize::output_shapes( } } -std::string write_signature( - std::string func_name, - const std::string& header, - const std::string& source, - const std::vector& input_names, - const std::vector& inputs, - const std::vector& output_names, - const std::vector& output_dtypes, - const std::vector>& template_args, - const std::vector& attributes, - const std::vector& shape_infos, - bool atomic_outputs) { - std::string kernel_source; - kernel_source.reserve(header.size() + source.size() + 16384); - kernel_source += header; - // Auto-generate a function signature based on `template_args` - // and the dtype/shape of the arrays passed as `inputs`. 
- if (!template_args.empty()) { - kernel_source += "template <"; - int i = 0; - for (const auto& [name, arg] : template_args) { - std::string param_type; - if (std::holds_alternative(arg)) { - param_type = "int"; - } else if (std::holds_alternative(arg)) { - param_type = "bool"; - } else if (std::holds_alternative(arg)) { - param_type = "typename"; - } - if (i > 0) { - kernel_source += ", "; - } - kernel_source += param_type; - kernel_source += " "; - kernel_source += name; - i++; - } - kernel_source += ">\n"; - } - kernel_source += "[[kernel]] void "; - kernel_source += func_name; - kernel_source += "(\n"; - - int index = 0; - constexpr int max_constant_array_size = 8; - // Add inputs - for (int i = 0; i < inputs.size(); ++i) { - const auto& name = input_names[i]; - const auto& arr = inputs[i]; - auto dtype = get_type_string(arr.dtype()); - std::string location = - arr.size() < max_constant_array_size ? "constant" : "device"; - std::string ref = arr.ndim() == 0 ? "&" : "*"; - kernel_source += " const "; - kernel_source += location; - kernel_source += " "; - kernel_source += dtype; - kernel_source += ref; - kernel_source += " "; - kernel_source += name; - kernel_source += " [[buffer("; - kernel_source += std::to_string(index); - kernel_source += ")]],\n"; - index++; - // Add input shape, strides and ndim if present in the source - if (arr.ndim() > 0) { - if (shape_infos[i].shape) { - kernel_source += - (" const constant int* " + name + "_shape [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - if (shape_infos[i].strides) { - kernel_source += - (" const constant int64_t* " + name + "_strides [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - if (shape_infos[i].ndim) { - kernel_source += - (" const constant int& " + name + "_ndim [[buffer(" + - std::to_string(index) + ")]],\n"); - index++; - } - } - } - // Add outputs - for (int i = 0; i < output_names.size(); ++i) { - const auto& name = output_names[i]; - const auto& dtype = output_dtypes[i]; - kernel_source += " device "; - auto type_string = get_type_string(dtype); - if (atomic_outputs) { - kernel_source += "atomic<"; - } - kernel_source += type_string; - if (atomic_outputs) { - kernel_source += ">"; - } - kernel_source += "* "; - kernel_source += name; - kernel_source += " [[buffer("; - kernel_source += std::to_string(index); - kernel_source += ")]]"; - if (index < inputs.size() + output_names.size() - 1 || - attributes.size() > 0) { - kernel_source += ",\n"; - } else { - kernel_source += ") {\n"; - } - index++; - } - - index = 0; - for (const auto& attr : attributes) { - kernel_source += attr; - if (index < attributes.size() - 1) { - kernel_source += ",\n"; - } else { - kernel_source += ") {\n"; - } - index++; - } - kernel_source += source; - kernel_source += "\n}\n"; - return kernel_source; -} - -std::string write_template( - const std::vector>& template_args) { - std::ostringstream template_def; - template_def << "<"; - int i = 0; - for (const auto& [name, arg] : template_args) { - if (i > 0) { - template_def << ", "; - } - if (std::holds_alternative(arg)) { - template_def << std::get(arg); - } else if (std::holds_alternative(arg)) { - template_def << std::get(arg); - } else if (std::holds_alternative(arg)) { - template_def << get_type_string(std::get(arg)); - } - i++; - } - template_def << ">"; - return template_def.str(); -} - -MetalKernelFunction metal_kernel( - const std::string& name, - const std::vector& input_names, - const std::vector& output_names, - const std::string& source, - const 
std::string& header /* = "" */, - bool ensure_row_contiguous /* = true */, - bool atomic_outputs /* = false */) { - if (output_names.empty()) { - throw std::invalid_argument( - "[metal_kernel] Must specify at least one output."); - } - std::vector shape_infos; - for (auto& n : input_names) { - CustomKernelShapeInfo shape_info; - shape_info.shape = source.find(n + "_shape") != std::string::npos; - shape_info.strides = source.find(n + "_strides") != std::string::npos; - shape_info.ndim = source.find(n + "_ndim") != std::string::npos; - shape_infos.push_back(shape_info); - } - const std::vector> metal_attributes = { - {"dispatch_quadgroups_per_threadgroup", "uint"}, - {"dispatch_simdgroups_per_threadgroup", "uint"}, - {"dispatch_threads_per_threadgroup", "uint3"}, - {"grid_origin", "uint3"}, - {"grid_size", "uint3"}, - {"quadgroup_index_in_threadgroup", "uint"}, - {"quadgroups_per_threadgroup", "uint"}, - {"simdgroup_index_in_threadgroup", "uint"}, - {"simdgroups_per_threadgroup", "uint"}, - {"thread_execution_width", "uint"}, - {"thread_index_in_quadgroup", "uint"}, - {"thread_index_in_simdgroup", "uint"}, - {"thread_index_in_threadgroup", "uint"}, - {"thread_position_in_grid", "uint3"}, - {"thread_position_in_threadgroup", "uint3"}, - {"threadgroup_position_in_grid", "uint3"}, - {"threadgroups_per_grid", "uint3"}, - {"threads_per_grid", "uint3"}, - {"threads_per_simdgroup", "uint"}, - {"threads_per_threadgroup", "uint3"}, - }; - - std::vector attributes; - for (const auto& [attr, dtype] : metal_attributes) { - if (source.find(attr) != std::string::npos) { - attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); - } - } - - return [=, - shape_infos = std::move(shape_infos), - attributes = std::move(attributes)]( - const std::vector& inputs, - const std::vector& output_shapes, - const std::vector& output_dtypes, - std::tuple grid, - std::tuple threadgroup, - const std::vector>& - template_args = {}, - std::optional init_value = std::nullopt, - bool verbose = false, - StreamOrDevice s_ = {}) { - if (inputs.size() != input_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `inputs` to have size " - << input_names.size() << " but got size " << inputs.size() << "." - << std::endl; - throw std::invalid_argument(msg.str()); - } - if (output_shapes.size() != output_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `output_shapes` to have size " - << output_names.size() << " but got size " << output_shapes.size() - << "." << std::endl; - throw std::invalid_argument(msg.str()); - } - if (output_dtypes.size() != output_names.size()) { - std::ostringstream msg; - msg << "[metal_kernel] Expected `output_dtypes` to have size " - << output_names.size() << " but got size " << output_dtypes.size() - << "." 
<< std::endl; - throw std::invalid_argument(msg.str()); - } - - auto s = to_stream(s_); - if (s.device != Device::gpu) { - throw std::invalid_argument("[metal_kernel] Only supports the GPU."); - } - - std::ostringstream func_name; - std::string template_def = ""; - std::string hash_key = ""; - if (!template_args.empty()) { - std::regex disallowed_chars("\\<|\\>|(, )"); - template_def = write_template(template_args); - hash_key = std::regex_replace(template_def, disallowed_chars, "_"); - hash_key.pop_back(); - } - func_name << "custom_kernel_" << name << hash_key; - std::string kernel_name = func_name.str(); - - std::string kernel_source = write_signature( - kernel_name, - header, - source, - input_names, - inputs, - output_names, - output_dtypes, - template_args, - attributes, - shape_infos, - atomic_outputs); - - if (!template_args.empty()) { - template_def = kernel_name + template_def; - kernel_source += "\ntemplate [[host_name(\""; - kernel_source += kernel_name; - kernel_source += "\")]] [[kernel]] decltype("; - kernel_source += template_def; - kernel_source += ") "; - kernel_source += template_def; - kernel_source += ";\n"; - } - - if (verbose) { - std::cout << "Generated source code for `" << name << "`:" << std::endl - << "```" << std::endl - << kernel_source << std::endl - << "```" << std::endl; - } - - return array::make_arrays( - std::move(output_shapes), - std::move(output_dtypes), - std::make_shared( - s, - std::move(kernel_name), - std::move(kernel_source), - grid, - threadgroup, - shape_infos, - ensure_row_contiguous, - init_value), - std::move(inputs)); - }; -} - } // namespace mlx::core::fast diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 2c90a3755..59c2fc3ef 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase): )[0] self.assertEqual(out.item(), 2) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_caching(self): + def call_kernel(a: mx.array, source): + kernel = mx.fast.metal_kernel( + name="my_kernel", + input_names=["inp"], + output_names=["out"], + source=source, + ) + return kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + + a = mx.random.normal(shape=(32,)) + + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = 0.0; + """ + + out = call_kernel(a, source) + self.assertTrue(mx.array_equal(out, mx.zeros_like(out))) + + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = 1.0; + """ + out = call_kernel(a, source) + self.assertTrue(mx.array_equal(out, mx.ones_like(out))) + if __name__ == "__main__": unittest.main() From 5866b3857bb46b2f41368b69a2df2b2f4c874231 Mon Sep 17 00:00:00 2001 From: Emmanuel Ferdman Date: Sat, 7 Jun 2025 16:12:08 +0300 Subject: [PATCH 078/156] Refactor the lu test (#2250) Signed-off-by: Emmanuel Ferdman --- python/tests/test_linalg.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index f5eeda837..764d11f6e 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -359,36 +359,6 @@ class TestLinalg(mlx_tests.MLXTestCase): mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ) # Non-square matrix - def test_lu(self): - with self.assertRaises(ValueError): - mx.linalg.lu(mx.array(0.0), stream=mx.cpu) - - with self.assertRaises(ValueError): - 
mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu) - - with self.assertRaises(ValueError): - mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu) - - # Test 3x3 matrix - a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - self.assertTrue(mx.allclose(L[P, :] @ U, a)) - - # Test batch dimension - a = mx.broadcast_to(a, (5, 5, 3, 3)) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - L = mx.take_along_axis(L, P[..., None], axis=-2) - self.assertTrue(mx.allclose(L @ U, a)) - - # Test non-square matrix - a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]]) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - self.assertTrue(mx.allclose(L[P, :] @ U, a)) - - a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]]) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - self.assertTrue(mx.allclose(L[P, :] @ U, a)) - def test_eigh(self): tols = {"atol": 1e-5, "rtol": 1e-5} From f8bad606099169e486ef5c8761a4f2a9d158245b Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 9 Jun 2025 22:45:08 +0900 Subject: [PATCH 079/156] CUDA backend: unary ops (#2158) --- mlx/backend/common/copy.h | 4 +- mlx/backend/common/unary.h | 26 ++ mlx/backend/cpu/unary.h | 21 +- mlx/backend/cuda/CMakeLists.txt | 1 + .../cuda/iterators/general_iterator.cuh | 121 ++++++ mlx/backend/cuda/kernel_utils.cuh | 20 + mlx/backend/cuda/kernels/cucomplex_math.cuh | 240 ++++++++++++ mlx/backend/cuda/kernels/fp16_math.cuh | 72 ++++ mlx/backend/cuda/kernels/unary_ops.cuh | 349 ++++++++++++++++++ mlx/backend/cuda/kernels/utils.cuh | 43 +++ mlx/backend/cuda/primitives.cu | 32 -- mlx/backend/cuda/unary.cu | 196 ++++++++++ mlx/backend/metal/unary.cpp | 19 +- 13 files changed, 1074 insertions(+), 70 deletions(-) create mode 100644 mlx/backend/common/unary.h create mode 100644 mlx/backend/cuda/iterators/general_iterator.cuh create mode 100644 mlx/backend/cuda/kernels/cucomplex_math.cuh create mode 100644 mlx/backend/cuda/kernels/unary_ops.cuh create mode 100644 mlx/backend/cuda/kernels/utils.cuh create mode 100644 mlx/backend/cuda/unary.cu diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index 0c9f28c94..c23d2e79a 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -2,7 +2,7 @@ #pragma once -#include "mlx/array.h" +#include "mlx/backend/common/utils.h" namespace mlx::core { @@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { if (ctype == CopyType::Vector) { // If the input is donateable, we are doing a vector copy and the types // have the same size, then the input buffer can hold the output. - if (in.is_donatable() && in.itemsize() == out.itemsize()) { + if (is_donatable(in, out)) { out.copy_shared_buffer(in); return true; } else { diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h new file mode 100644 index 000000000..a27a1f45c --- /dev/null +++ b/mlx/backend/common/unary.h @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +inline void set_unary_output_data(const array& in, array& out) { + if (in.flags().contiguous) { + if (is_donatable(in, out)) { + out.copy_shared_buffer(in); + } else { + out.set_data( + allocator::malloc(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + out.set_data(allocator::malloc(out.nbytes())); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index fa539541c..14c1dd479 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -2,32 +2,13 @@ #pragma once -#include "mlx/allocator.h" -#include "mlx/array.h" -#include "mlx/backend/common/utils.h" +#include "mlx/backend/common/unary.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/utils.h" namespace mlx::core { -void set_unary_output_data(const array& in, array& out) { - if (in.flags().contiguous) { - if (is_donatable(in, out)) { - out.copy_shared_buffer(in); - } else { - auto size = in.data_size(); - out.set_data( - allocator::malloc(size * out.itemsize()), - size, - in.strides(), - in.flags()); - } - } else { - out.set_data(allocator::malloc(out.nbytes())); - } -} - template void unary_op(const T* a, U* out, size_t shape, size_t stride) { for (size_t i = 0; i < shape; i += 1) { diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 9eaf2a6c7..cd73843bf 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -15,6 +15,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) diff --git a/mlx/backend/cuda/iterators/general_iterator.cuh b/mlx/backend/cuda/iterators/general_iterator.cuh new file mode 100644 index 000000000..3c8c098c3 --- /dev/null +++ b/mlx/backend/cuda/iterators/general_iterator.cuh @@ -0,0 +1,121 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/cuda/kernel_utils.cuh" + +namespace mlx::core::cu { + +// Iterating non-contiguous array. 
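// An illustrative usage sketch (the int64_t index type and float element type
// are assumptions for the example; unary.cu later in this patch follows the
// same pattern):
//
//   auto policy = cu::thrust_policy(stream);
//   auto in_ptr = thrust::device_pointer_cast(in.data<float>());
//   auto [shape, strides] = collapse_contiguous_dims(in);
//   auto [begin, end] = cu::make_general_iterators<int64_t>(
//       in_ptr, in.data_size(), shape, strides);
//   thrust::transform(policy, begin, end, out_ptr, Op());
//
// Dereferencing an iterator computes elem_to_loc(index, shape, strides, ndim),
// so the transform reads the input in its strided, non-contiguous layout.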
+template +class general_iterator + : public thrust:: + iterator_adaptor, Iterator> { + public: + using super_t = + thrust::iterator_adaptor, Iterator>; + + using reference = typename super_t::reference; + using difference_type = typename super_t::difference_type; + + __host__ __device__ general_iterator( + Iterator it, + IdxT index, + int ndim, + Shape shape, + Strides strides) + : super_t(it), + index_(index), + ndim_(ndim), + shape_(cuda::std::move(shape)), + strides_(cuda::std::move(strides)) {} + + __host__ __device__ IdxT index() const { + return index_; + } + + __host__ __device__ const Shape& shape() const { + return shape_; + } + + __host__ __device__ const Strides& strides() const { + return strides_; + } + + private: + friend class thrust::iterator_core_access; + + __host__ __device__ bool equal(const general_iterator& other) const { + return this->base() == other.base() && this->index() == other.index(); + } + + __host__ __device__ void advance(difference_type n) { + this->index_ += n; + } + + __host__ __device__ void increment() { + this->index_ += 1; + } + + __host__ __device__ void decrement() { + this->index_ -= 1; + } + + __host__ __device__ difference_type + distance_to(const general_iterator& other) const { + _CCCL_ASSERT( + this->base() == other.base(), + "Underlying iterator must point to same base iterator"); + return other.index() - this->index(); + } + + // The dereference is device-only to avoid accidental running in host. + __device__ typename super_t::reference dereference() const { + IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_); + return *(this->base() + offset); + } + + IdxT index_; + int ndim_; + Shape shape_; + Strides strides_; +}; + +template +__host__ __device__ auto make_general_iterator( + Iterator it, + IdxT index, + int ndim, + Shape shape, + Strides strides) { + return general_iterator( + it, index, ndim, cuda::std::move(shape), cuda::std::move(strides)); +} + +template +auto make_general_iterator( + Iterator it, + const std::vector& shape, + const std::vector& strides) { + return make_general_iterator( + it, 0, shape.size(), const_param(shape), const_param(strides)); +} + +template +auto make_general_iterators( + Iterator it, + IdxT size, + const std::vector& shape, + const std::vector& strides) { + auto ndim = shape.size(); + auto shape_arg = const_param(shape); + auto strides_arg = const_param(strides); + return std::make_pair( + make_general_iterator(it, 0, ndim, shape_arg, strides_arg), + make_general_iterator(it, size, ndim, shape_arg, strides_arg)); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 67ac47449..6430b8c59 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -7,10 +7,12 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/kernels/utils.cuh" #include #include #include +#include namespace mlx::core { @@ -38,6 +40,24 @@ struct CTypeToCudaType { template using cuda_type_t = typename CTypeToCudaType::type; +// Type traits for detecting floating numbers. +template +inline constexpr bool is_floating_v = + cuda::std::is_same_v || cuda::std::is_same_v || + cuda::std::is_same_v || cuda::std::is_same_v; + +// Utility to copy data from vector to array in host. 
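// An illustrative launch-site sketch (names are placeholders; binary.cu later
// in this patch forwards shape/strides this way):
//
//   kernel<<<num_blocks, block_dims, 0, stream>>>(
//       a_ptr, b_ptr, out_ptr, size,
//       const_param(shape), const_param(a_strides), const_param(b_strides));
//
// Only the first vec.size() entries of the returned cuda::std::array are
// written; the remaining entries are left uninitialized, so kernels should
// index no further than ndim.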
+template +inline cuda::std::array const_param(const std::vector& vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + cuda::std::array result; + std::copy_n(vec.begin(), vec.size(), result.begin()); + return result; +} + // Compute the grid and block dimensions, check backend/common/utils.h for docs. dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides); diff --git a/mlx/backend/cuda/kernels/cucomplex_math.cuh b/mlx/backend/cuda/kernels/cucomplex_math.cuh new file mode 100644 index 000000000..612650c06 --- /dev/null +++ b/mlx/backend/cuda/kernels/cucomplex_math.cuh @@ -0,0 +1,240 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2017-2024 The Simons Foundation, Inc. +// +// FINUFFT is licensed under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance with the +// License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h + +#pragma once + +#include + +// This header provides some helper functions for cuComplex types. +// It mainly wraps existing CUDA implementations to provide operator overloads +// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are +// all provided by CUDA + +__forceinline__ __host__ __device__ cuDoubleComplex +operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) { + return cuCadd(a, b); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) { + return cuCsub(a, b); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) { + return cuCmul(a, b); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) { + return cuCdiv(a, b); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) { + double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b)); + double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b)); + return make_cuDoubleComplex(r, i); +} + +__forceinline__ __host__ __device__ bool operator==( + const cuDoubleComplex& a, + const cuDoubleComplex& b) { + return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b); +} + +__forceinline__ __host__ __device__ bool operator!=( + const cuDoubleComplex& a, + const cuDoubleComplex& b) { + return !(a == b); +} + +__forceinline__ __host__ __device__ bool operator>( + const cuDoubleComplex& a, + const cuDoubleComplex& b) { + double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a)); + double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b)); + return mag_a > mag_b; +} + +__forceinline__ __host__ __device__ bool operator>=( + const cuDoubleComplex& a, + const cuDoubleComplex& b) { + return a > b || a == b; +} + +__forceinline__ __host__ __device__ bool operator<( + const cuDoubleComplex& a, + const 
cuDoubleComplex& b) { + return b > a; +} + +__forceinline__ __host__ __device__ bool operator<=( + const cuDoubleComplex& a, + const cuDoubleComplex& b) { + return b > a || a == b; +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator+(const cuDoubleComplex& a, double b) { + return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a)); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator+(double a, const cuDoubleComplex& b) { + return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b)); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator-(const cuDoubleComplex& a, double b) { + return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a)); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator-(double a, const cuDoubleComplex& b) { + return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b)); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator*(const cuDoubleComplex& a, double b) { + return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator*(double a, const cuDoubleComplex& b) { + return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b)); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator/(const cuDoubleComplex& a, double b) { + return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b); +} + +__forceinline__ __host__ __device__ cuDoubleComplex +operator/(double a, const cuDoubleComplex& b) { + double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b); + return make_cuDoubleComplex( + (a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator+(const cuFloatComplex& a, const cuFloatComplex& b) { + return cuCaddf(a, b); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator-(const cuFloatComplex& a, const cuFloatComplex& b) { + return cuCsubf(a, b); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator*(const cuFloatComplex& a, const cuFloatComplex& b) { + return cuCmulf(a, b); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator/(const cuFloatComplex& a, const cuFloatComplex& b) { + return cuCdivf(a, b); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator%(const cuFloatComplex& a, const cuFloatComplex& b) { + float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b)); + float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b)); + return make_cuFloatComplex(r, i); +} + +__forceinline__ __host__ __device__ bool operator==( + const cuFloatComplex& a, + const cuFloatComplex& b) { + return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b); +} + +__forceinline__ __host__ __device__ bool operator!=( + const cuFloatComplex& a, + const cuFloatComplex& b) { + return !(a == b); +} + +__forceinline__ __host__ __device__ bool operator>( + const cuFloatComplex& a, + const cuFloatComplex& b) { + float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a)); + float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b)); + return mag_a > mag_b; +} + +__forceinline__ __host__ __device__ bool operator>=( + const cuFloatComplex& a, + const cuFloatComplex& b) { + return a > b || a == b; +} + +__forceinline__ __host__ __device__ bool operator<( + const cuFloatComplex& a, + const cuFloatComplex& b) { + return b > a; +} + +__forceinline__ __host__ __device__ bool operator<=( + const cuFloatComplex& a, + const cuFloatComplex& b) { + return b > a || a == b; +} + +__forceinline__ 
__host__ __device__ cuFloatComplex +operator+(const cuFloatComplex& a, float b) { + return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a)); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator+(float a, const cuFloatComplex& b) { + return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b)); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator-(const cuFloatComplex& a, float b) { + return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a)); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator-(float a, const cuFloatComplex& b) { + return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b)); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator*(const cuFloatComplex& a, float b) { + return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator*(float a, const cuFloatComplex& b) { + return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b)); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator/(const cuFloatComplex& a, float b) { + return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b); +} + +__forceinline__ __host__ __device__ cuFloatComplex +operator/(float a, const cuFloatComplex& b) { + float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b); + return make_cuFloatComplex( + (a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom); +} diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index edbd953de..cf5def4db 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -9,6 +9,78 @@ namespace mlx::core::cu { +/////////////////////////////////////////////////////////////////////////////// +// Unary ops for half types. +/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 +#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x); \ + } else { \ + return ::NAME(x); \ + } \ + } +#else +#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x); \ + } else if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x); \ + } else { \ + return ::NAME(x); \ + } \ + } +#endif + +#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \ + template \ + __forceinline__ __device__ auto NAME(T x) { \ + if constexpr (cuda::std::is_same_v) { \ + return ::NAME(__half2float(x)); \ + } else if constexpr (cuda::std::is_same_v) { \ + return ::NAME(__bfloat162float(x)); \ + } else { \ + return ::NAME(x); \ + } \ + } + +MLX_DEFINE_UNARY_OP(abs, __habs) +MLX_DEFINE_UNARY_OP(ceil, hceil) +MLX_DEFINE_UNARY_OP(cos, hcos) +MLX_DEFINE_UNARY_OP(exp, hexp) +MLX_DEFINE_UNARY_OP(floor, hfloor) +MLX_DEFINE_UNARY_OP(isnan, __hisnan) +MLX_DEFINE_UNARY_OP(log, hlog) +MLX_DEFINE_UNARY_OP(log2, hlog2) +MLX_DEFINE_UNARY_OP(log10, hlog10) +MLX_DEFINE_UNARY_OP(rint, hrint) +MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt) +MLX_DEFINE_UNARY_OP(sin, hsin) +MLX_DEFINE_UNARY_OP(sqrt, hsqrt) +MLX_DEFINE_UNARY_OP_FALLBCK(acos) +MLX_DEFINE_UNARY_OP_FALLBCK(acosh) +MLX_DEFINE_UNARY_OP_FALLBCK(asin) +MLX_DEFINE_UNARY_OP_FALLBCK(asinh) +MLX_DEFINE_UNARY_OP_FALLBCK(atan) +MLX_DEFINE_UNARY_OP_FALLBCK(atanh) +MLX_DEFINE_UNARY_OP_FALLBCK(cosh) +MLX_DEFINE_UNARY_OP_FALLBCK(log1p) +MLX_DEFINE_UNARY_OP_FALLBCK(sinh) +MLX_DEFINE_UNARY_OP_FALLBCK(tan) +#if 
__CUDA_ARCH__ >= 1280 +MLX_DEFINE_UNARY_OP(tanh, htanh) +#else +MLX_DEFINE_UNARY_OP_FALLBCK(tanh) +#endif + +#undef MLX_DEFINE_UNARY_OP +#undef MLX_DEFINE_UNARY_OP_FALLBCK + /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/kernels/unary_ops.cuh b/mlx/backend/cuda/kernels/unary_ops.cuh new file mode 100644 index 000000000..6637a6eeb --- /dev/null +++ b/mlx/backend/cuda/kernels/unary_ops.cuh @@ -0,0 +1,349 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/backend/cuda/kernels/utils.cuh" + +namespace mlx::core::cu { + +struct Abs { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_unsigned_v) { + return x; + } else if constexpr (cuda::std::is_same_v) { + return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0}; + } else { + return abs(x); + } + } +}; + +struct ArcCos { + template + __device__ T operator()(T x) { + return acos(x); + } +}; + +struct ArcCosh { + template + __device__ T operator()(T x) { + return acosh(x); + } +}; + +struct ArcSin { + template + __device__ T operator()(T x) { + return asin(x); + } +}; + +struct ArcSinh { + template + __device__ T operator()(T x) { + return asinh(x); + } +}; + +struct ArcTan { + template + __device__ T operator()(T x) { + return atan(x); + } +}; + +struct ArcTanh { + template + __device__ T operator()(T x) { + return atanh(x); + } +}; + +struct BitwiseInvert { + template + __device__ T operator()(T x) { + return ~x; + } +}; + +struct Ceil { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_integral_v) { + return x; + } else { + return ceil(x); + } + } +}; + +struct Conjugate { + __device__ cuComplex operator()(cuComplex x) { + return {cuCrealf(x), -cuCimagf(x)}; + } +}; + +struct Cos { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return { + cos(cuCrealf(x)) * cosh(cuCimagf(x)), + -sin(cuCrealf(x)) * sinh(cuCimagf(x))}; + } else { + return cos(x); + } + } +}; + +struct Cosh { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return { + cosh(cuCrealf(x)) * cos(cuCimagf(x)), + sinh(cuCrealf(x)) * sin(cuCimagf(x))}; + } else { + return cosh(x); + } + } +}; + +struct Erf { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return erf(__half2float(x)); + } else if constexpr (cuda::std::is_same_v) { + return erf(__bfloat162float(x)); + } else { + return erf(x); + } + } +}; + +struct ErfInv { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return erfinv(__half2float(x)); + } else if constexpr (cuda::std::is_same_v) { + return erfinv(__bfloat162float(x)); + } else { + return erfinv(x); + } + } +}; + +struct Exp { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + auto m = exp(cuCrealf(x)); + return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))}; + } else { + return exp(x); + } + } +}; + +struct Expm1 { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return expm1(__half2float(x)); + } else if constexpr (cuda::std::is_same_v) { + return expm1(__bfloat162float(x)); + } else { + return expm1(x); + } + } +}; + +struct Floor { + template + __device__ T operator()(T x) { + if constexpr 
(cuda::std::is_integral_v) { + return x; + } else { + return floor(x); + } + } +}; + +struct Imag { + __device__ float operator()(cuComplex x) { + return cuCimagf(x); + } +}; + +struct Log { + template + __device__ T operator()(T x) { + return log(x); + } +}; + +struct Log2 { + template + __device__ T operator()(T x) { + return log2(x); + } +}; + +struct Log10 { + template + __device__ T operator()(T x) { + return log10(x); + } +}; + +struct Log1p { + template + __device__ T operator()(T x) { + return log1p(x); + } +}; + +struct LogicalNot { + __device__ bool operator()(bool x) { + return !x; + } +}; + +struct Negative { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return 0 - x; + } else { + return -x; + } + } +}; + +struct Real { + __device__ float operator()(cuComplex x) { + return cuCrealf(x); + } +}; + +struct Round { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return {rint(cuCrealf(x)), rint(cuCimagf(x))}; + } else { + return rint(x); + } + } +}; + +struct Rsqrt { + template + __device__ T operator()(T x) { + return rsqrt(x); + } +}; + +struct Sigmoid { + template + __device__ T operator()(T x) { + T y = 1 / (1 + exp(-abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_unsigned_v) { + return x != 0; + } else if constexpr (cuda::std::is_same_v) { + if (cuCrealf(x) == 0 && cuCimagf(x) == 0) { + return x; + } else { + return x / Abs()(x); + } + } else if constexpr (cuda::std::is_same_v) { + return static_cast((x > T(0.f)) - (x < T(0.f))); + } else { + return (x > T(0)) - (x < T(0)); + } + } +}; + +struct Sin { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return { + sin(cuCrealf(x)) * cosh(cuCimagf(x)), + cos(cuCrealf(x)) * sinh(cuCimagf(x))}; + } else { + return sin(x); + } + } +}; + +struct Sinh { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + return { + sinh(cuCrealf(x)) * cos(cuCimagf(x)), + cosh(cuCrealf(x)) * sin(cuCimagf(x))}; + } else { + return sinh(x); + } + } +}; + +struct Square { + template + __device__ T operator()(T x) { + return x * x; + } +}; + +struct Sqrt { + template + __device__ T operator()(T x) { + return sqrt(x); + } +}; + +struct Tan { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + float tan_a = tan(cuCrealf(x)); + float tanh_b = tanh(cuCimagf(x)); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + } else { + return tan(x); + } + } +}; + +struct Tanh { + template + __device__ T operator()(T x) { + if constexpr (cuda::std::is_same_v) { + float tanh_a = tanh(cuCrealf(x)); + float tan_b = tan(cuCimagf(x)); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + } else { + return tanh(x); + } + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh new file mode 100644 index 000000000..4d69b7356 --- /dev/null +++ b/mlx/backend/cuda/kernels/utils.cuh @@ -0,0 +1,43 @@ +// Copyright © 2025 Apple Inc. + +// This file must not include any host-only code, utilies that work under both +// host and device can be put here. 
+// +// See more about the requirements at: +// https://docs.nvidia.com/cuda/nvrtc/#language + +#pragma once + +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// CUDA kernel utils +/////////////////////////////////////////////////////////////////////////////// + +// To pass shape/strides to kernels via constant memory, their size must be +// known at compile time. +#define MAX_NDIM 8 + +using Shape = cuda::std::array; +using Strides = cuda::std::array; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +template +inline __host__ __device__ IdxT +elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index fad2d76d3..3d9186892 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -71,39 +71,22 @@ bool fast::ScaledDotProductAttention::use_fallback( throw std::runtime_error(#func " has no CUDA implementation."); \ } -NO_GPU(Abs) NO_GPU(Add) -NO_GPU(ArcCos) -NO_GPU(ArcCosh) -NO_GPU(ArcSin) -NO_GPU(ArcSinh) -NO_GPU(ArcTan) NO_GPU(ArcTan2) -NO_GPU(ArcTanh) NO_GPU(ArgPartition) NO_GPU(ArgReduce) NO_GPU(ArgSort) NO_GPU(BitwiseBinary) -NO_GPU(BitwiseInvert) NO_GPU(BlockMaskedMM) -NO_GPU(Ceil) NO_GPU_MULTI(Compiled) -NO_GPU(Conjugate) NO_GPU(Convolution) -NO_GPU(Cos) -NO_GPU(Cosh) NO_GPU(Divide) NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) NO_GPU(Remainder) NO_GPU(Equal) -NO_GPU(Erf) -NO_GPU(ErfInv) -NO_GPU(Exp) -NO_GPU(Expm1) NO_GPU(FFT) -NO_GPU(Floor) NO_GPU(Gather) NO_GPU(GatherAxis) NO_GPU(GatherMM) @@ -111,13 +94,9 @@ NO_GPU(GatherQMM) NO_GPU(Greater) NO_GPU(GreaterEqual) NO_GPU(Hadamard) -NO_GPU(Imag) NO_GPU(Less) NO_GPU(LessEqual) NO_GPU(Load) -NO_GPU(Log) -NO_GPU(Log1p) -NO_GPU(LogicalNot) NO_GPU(LogicalAnd) NO_GPU(LogicalOr) NO_GPU(LogAddExp) @@ -126,33 +105,22 @@ NO_GPU_MULTI(LUF) NO_GPU(Maximum) NO_GPU(Minimum) NO_GPU(Multiply) -NO_GPU(Negative) NO_GPU(NotEqual) NO_GPU(Partition) NO_GPU(Power) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) -NO_GPU(Real) NO_GPU(Reduce) -NO_GPU(Round) NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) -NO_GPU(Sigmoid) -NO_GPU(Sign) -NO_GPU(Sin) -NO_GPU(Sinh) NO_GPU(SliceUpdate) NO_GPU(Softmax) NO_GPU(Sort) -NO_GPU(Square) -NO_GPU(Sqrt) NO_GPU(Subtract) NO_GPU_MULTI(SVD) -NO_GPU(Tan) -NO_GPU(Tanh) NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eig) diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu new file mode 100644 index 000000000..0ee31ee28 --- /dev/null +++ b/mlx/backend/cuda/unary.cu @@ -0,0 +1,196 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/unary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/general_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" +#include "mlx/backend/cuda/kernels/unary_ops.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +template +constexpr bool supports_unary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && !std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v && + (is_floating_v || std::is_same_v); + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const std::string& op, + const Stream& s) { + auto& in = inputs[0]; + if (in.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + if constexpr (cu::supports_unary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + auto policy = cu::thrust_policy(stream); + auto in_ptr = thrust::device_pointer_cast(in.data()); + auto out_ptr = thrust::device_pointer_cast(out.data()); + if (in.flags().contiguous) { + thrust::transform( + policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); + } else { + auto [shape, strides] = collapse_contiguous_dims(in); + auto [in_begin, in_end] = cu::make_general_iterators( + in_ptr, in.data_size(), shape, strides); + thrust::transform(policy, in_begin, in_end, out_ptr, Op()); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); +} + +template +void unary_op_gpu( + const std::vector& inputs, + array& out, + const std::string& op, + const Stream& s) { + set_unary_output_data(inputs[0], out); + unary_op_gpu_inplace(inputs, out, op, s); +} + +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, get_primitive_string(this), s); \ + } + +UNARY_GPU(Abs) +UNARY_GPU(ArcCos) +UNARY_GPU(ArcCosh) +UNARY_GPU(ArcSin) +UNARY_GPU(ArcSinh) +UNARY_GPU(ArcTan) +UNARY_GPU(ArcTanh) +UNARY_GPU(BitwiseInvert) +UNARY_GPU(Ceil) +UNARY_GPU(Conjugate) 
+UNARY_GPU(Cos) +UNARY_GPU(Cosh) +UNARY_GPU(Erf) +UNARY_GPU(ErfInv) +UNARY_GPU(Exp) +UNARY_GPU(Expm1) +UNARY_GPU(Floor) +UNARY_GPU(Imag) +UNARY_GPU(Log1p) +UNARY_GPU(LogicalNot) +UNARY_GPU(Negative) +UNARY_GPU(Real) +UNARY_GPU(Sigmoid) +UNARY_GPU(Sign) +UNARY_GPU(Sin) +UNARY_GPU(Sinh) +UNARY_GPU(Square) +UNARY_GPU(Tan) +UNARY_GPU(Tanh) + +void Log::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Log::eval_gpu"); + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + switch (base_) { + case Base::e: + unary_op_gpu(inputs, out, op, s); + break; + case Base::two: + unary_op_gpu(inputs, out, op, s); + break; + case Base::ten: + unary_op_gpu(inputs, out, op, s); + break; + } +} + +void Round::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Round::eval_gpu"); + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto& s = out.primitive().stream(); + if (issubdtype(in.dtype(), inexact)) { + unary_op_gpu(inputs, out, get_primitive_string(this), s); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Sort::eval_gpu"); + auto& s = out.primitive().stream(); + if (recip_) { + unary_op_gpu(inputs, out, "Rsqrt", s); + } else { + unary_op_gpu(inputs, out, "Sqrt", s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 850c17376..0b118b72f 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -1,5 +1,6 @@ // Copyright © 2024 Apple Inc. -#include "mlx/backend/common/utils.h" + +#include "mlx/backend/common/unary.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" @@ -99,21 +100,7 @@ void unary_op_gpu( array& out, const std::string op, const Stream& s) { - auto& in = inputs[0]; - bool contig = in.flags().contiguous; - if (contig) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.copy_shared_buffer(in); - } else { - out.set_data( - allocator::malloc(in.data_size() * out.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - out.set_data(allocator::malloc(out.nbytes())); - } + set_unary_output_data(inputs[0], out); unary_op_gpu_inplace(inputs, out, op, s); } From 9ce77798b1588ee58d0366ad8e218327b663d24e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 9 Jun 2025 20:37:27 -0700 Subject: [PATCH 080/156] fix export to work with gather/scatter axis (#2263) --- mlx/export.cpp | 2 ++ python/tests/test_export_import.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/mlx/export.cpp b/mlx/export.cpp index bd2f24ba2..552c35cfb 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -266,6 +266,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Floor), SERIALIZE_PRIMITIVE(Full), SERIALIZE_PRIMITIVE(Gather), + SERIALIZE_PRIMITIVE(GatherAxis), SERIALIZE_PRIMITIVE(GatherMM), SERIALIZE_PRIMITIVE(Greater), SERIALIZE_PRIMITIVE(GreaterEqual), @@ -307,6 +308,7 @@ struct PrimitiveFactory { "CumMax", "CumLogaddexp"), SERIALIZE_PRIMITIVE(Scatter), + SERIALIZE_PRIMITIVE(ScatterAxis), SERIALIZE_PRIMITIVE(Select), SERIALIZE_PRIMITIVE(Sigmoid), SERIALIZE_PRIMITIVE(Sign), diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 0190827bd..ef9827cbe 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -286,6 +286,32 @@ class 
TestExportImport(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): f2(mx.array(10), mx.array([5, 10, 20])) + def test_export_scatter_gather(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(a, b): + return mx.take_along_axis(a, b, axis=0) + + x = mx.random.uniform(shape=(4, 4)) + y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) + mx.export_function(path, fun, (x, y)) + imported_fun = mx.import_function(path) + expected = fun(x, y) + out = imported_fun(x, y)[0] + self.assertTrue(mx.array_equal(expected, out)) + + def fun(a, b, c): + return mx.put_along_axis(a, b, c, axis=0) + + x = mx.random.uniform(shape=(4, 4)) + y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) + z = mx.random.uniform(shape=(2, 4)) + mx.export_function(path, fun, (x, y, z)) + imported_fun = mx.import_function(path) + expected = fun(x, y, z) + out = imported_fun(x, y, z)[0] + self.assertTrue(mx.array_equal(expected, out)) + if __name__ == "__main__": unittest.main() From 7ebb2e01937ef8f2443aff731a9c9669621c4897 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 10 Jun 2025 22:37:40 +0900 Subject: [PATCH 081/156] CUDA backend: binary ops (#2259) --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/binary.cu | 305 ++++++++++++++++++++++++ mlx/backend/cuda/kernel_utils.cuh | 62 +++++ mlx/backend/cuda/kernels/binary_ops.cuh | 278 +++++++++++++++++++++ mlx/backend/cuda/kernels/fp16_math.cuh | 46 ++++ mlx/backend/cuda/kernels/utils.cuh | 61 +++++ mlx/backend/cuda/primitives.cu | 19 -- 7 files changed, 753 insertions(+), 19 deletions(-) create mode 100644 mlx/backend/cuda/binary.cu create mode 100644 mlx/backend/cuda/kernels/binary_ops.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index cd73843bf..c813f8fd4 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu new file mode 100644 index 000000000..360772998 --- /dev/null +++ b/mlx/backend/cuda/binary.cu @@ -0,0 +1,305 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/binary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/binary_ops.cuh" +#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[0], b[0]); + } +} + +template +__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[0], b[index]); + } +} + +template +__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[0]); + } +} + +template +__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[index]); + } +} + +template +__global__ void binary_g_nd( + const In* a, + const In* b, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + out[index] = Op{}(a[a_idx], b[b_idx]); + } +} + +template +constexpr bool supports_binary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && + (is_floating_v || std::is_same_v); + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && std::is_integral_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out = outputs[0]; + if (out.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + 
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + auto [shape, strides] = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + bool large = a.data_size() > UINT32_MAX || + b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = + &cu::binary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); + } else { + MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, LARGE); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.data_size()); + }); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); + }); +} + +template +void binary_op_gpu( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + binary_op_gpu_inplace(inputs, outputs, op, s); +} + +template +void binary_op_gpu( + const std::vector& inputs, + array& out, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + std::vector outputs{out}; + binary_op_gpu_inplace(inputs, outputs, op, s); +} + +#define BINARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, get_primitive_string(this), s); \ + } + +#define BINARY_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = outputs[0].primitive().stream(); \ + binary_op_gpu(inputs, outputs, get_primitive_string(this), s); \ + } + +BINARY_GPU(Add) +BINARY_GPU(ArcTan2) +BINARY_GPU(Divide) +BINARY_GPU(Remainder) +BINARY_GPU(Equal) +BINARY_GPU(Greater) +BINARY_GPU(GreaterEqual) +BINARY_GPU(Less) +BINARY_GPU(LessEqual) +BINARY_GPU(LogicalAnd) +BINARY_GPU(LogicalOr) +BINARY_GPU(LogAddExp) +BINARY_GPU(Maximum) +BINARY_GPU(Minimum) +BINARY_GPU(Multiply) +BINARY_GPU(NotEqual) +BINARY_GPU(Power) +BINARY_GPU(Subtract) + 
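// For reference, an instantiation such as BINARY_GPU(Add) above expands to
// roughly the following (illustrative sketch; the template arguments shown
// are assumed from context):
//
//   void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
//     nvtx3::scoped_range r("Add::eval_gpu");
//     auto& s = out.primitive().stream();
//     binary_op_gpu<cu::Add>(inputs, out, get_primitive_string(this), s);
//   }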
+void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, op, s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, op, s); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 6430b8c59..aeb065206 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -13,9 +13,40 @@ #include #include #include +#include namespace mlx::core { +// Convert a number between 1~3 to constexpr. +#define MLX_SWITCH_1_2_3(N, NDIM, ...) \ + switch (N) { \ + case 1: { \ + constexpr int NDIM = 1; \ + __VA_ARGS__; \ + break; \ + } \ + case 2: { \ + constexpr int NDIM = 2; \ + __VA_ARGS__; \ + break; \ + } \ + case 3: { \ + constexpr int NDIM = 3; \ + __VA_ARGS__; \ + break; \ + } \ + } + +// Like MLX_SWITCH_ALL_TYPES but for booleans. +#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \ + if (BOOL) { \ + constexpr bool BOOL_ALIAS = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool BOOL_ALIAS = false; \ + __VA_ARGS__; \ + } + // Maps CPU types to CUDA types. template struct CTypeToCudaType { @@ -66,4 +97,35 @@ dim3 get_2d_grid_dims( const Strides& strides, size_t divisor); +// Return a block size that achieves maximum potential occupancy for kernel. +template +inline uint max_occupancy_block_dim(T kernel) { + int _, block_dim; + CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + return block_dim; +} + +// Get the num_blocks and block_dims that maximize occupancy for |kernel|, +// assuming each thread handles |work_per_thread| elements of |arr|. +template +inline std::tuple get_launch_args( + T kernel, + const array& arr, + bool large, + int work_per_thread = 1) { + size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + uint block_dim = max_occupancy_block_dim(kernel); + if (block_dim > nthreads) { + block_dim = nthreads; + } + dim3 num_blocks; + if (large) { + num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread); + num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim); + } else { + num_blocks.x = cuda::ceil_div(nthreads, block_dim); + } + return std::make_tuple(num_blocks, block_dim); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/binary_ops.cuh b/mlx/backend/cuda/kernels/binary_ops.cuh new file mode 100644 index 000000000..3bc30eb02 --- /dev/null +++ b/mlx/backend/cuda/kernels/binary_ops.cuh @@ -0,0 +1,278 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/kernels/fp16_math.cuh" + +#include +#include + +namespace mlx::core::cu { + +struct Add { + template + __device__ T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return x / y; + } else { + return trunc(x / y); + } + } +}; + +struct Divide { + template + __device__ T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + if constexpr (cuda::std::is_signed_v) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } else { + return x % y; + } + } else if constexpr (cuda::std::is_same_v) { + return x % y; + } else { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r = r + y; + } + return r; + } + } +}; + +struct Equal { + template + __device__ bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (std::is_same_v) { + return x == y || + (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) && + isnan(cuCimagf(y))) || + (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) && + isnan(cuCimagf(y))) || + (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && + cuCimagf(x) == cuCimagf(y)); + } else { + return x == y || (isnan(x) && isnan(y)); + } + } +}; + +struct Greater { + template + __device__ bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + __device__ bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + __device__ bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + __device__ bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + __device__ T operator()(T x, T y) { + if (isnan(x) || isnan(y)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + T maxval = max(x, y); + T minval = min(x, y); + return (minval == -cuda::std::numeric_limits::infinity() || + maxval == cuda::std::numeric_limits::infinity()) + ? maxval + : T(float(maxval) + log1p(expf(minval - maxval))); + }; +}; + +struct Maximum { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return max(x, y); + } else if constexpr (cuda::std::is_same_v) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) { + return x; + } + return x > y ? x : y; + } else { + if (isnan(x)) { + return x; + } + return x > y ? x : y; + } + } +}; + +struct Minimum { + template + __device__ T operator()(T x, T y) { + if constexpr (cuda::std::is_integral_v) { + return min(x, y); + } else if constexpr (cuda::std::is_same_v) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) { + return x; + } + return x < y ? x : y; + } else { + if (isnan(x)) { + return x; + } + return x < y ? 
x : y; + } + } +}; + +struct Multiply { + template + __device__ T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (std::is_same_v) { + return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y); + } else { + return x != y; + } + } +}; + +struct Power { + template + __device__ T operator()(T base, T exp) { + if constexpr (cuda::std::is_integral_v) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } else if constexpr (cuda::std::is_same_v) { + auto x_theta = atan2f(base.y, base.x); + auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y); + auto mag = expf(exp.x * x_ln_r - exp.y * x_theta); + auto phase = exp.y * x_ln_r + exp.x * x_theta; + return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase)); + } else { + return powf(base, exp); + } + } +}; + +struct Subtract { + template + __device__ T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + __device__ T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + __device__ T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + __device__ T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + __device__ T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + __device__ T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + __device__ T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + __device__ T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + __device__ T operator()(T y, T x) { + return atan2f(y, x); + } +}; + +struct DivMod { + template + __device__ cuda::std::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index cf5def4db..f6fa17bb9 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -81,6 +81,52 @@ MLX_DEFINE_UNARY_OP_FALLBCK(tanh) #undef MLX_DEFINE_UNARY_OP #undef MLX_DEFINE_UNARY_OP_FALLBCK +/////////////////////////////////////////////////////////////////////////////// +// Binary ops for half types. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 +#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x, T y) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else { \ + return ::NAME(x, y); \ + } \ + } +#else +#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \ + template \ + __forceinline__ __device__ auto NAME(T x, T y) { \ + if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else if constexpr (cuda::std::is_same_v) { \ + return HALF_OP(x, y); \ + } else { \ + return ::NAME(x, y); \ + } \ + } +#endif + +MLX_DEFINE_BINARY_OP(max, __hmax) +MLX_DEFINE_BINARY_OP(min, __hmin) + +#undef MLX_DEFINE_BINARY_OP + +template +__forceinline__ __device__ T fmod(T x, T y) { + if constexpr (cuda::std::is_same_v) { + return __float2half(::fmod(__half2float(x), __half2float(y))); +#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800 + } else if constexpr (cuda::std::is_same_v) { + return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y))); +#endif + } else { + return ::fmod(x, y); + } +} + /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh index 4d69b7356..16957d132 100644 --- a/mlx/backend/cuda/kernels/utils.cuh +++ b/mlx/backend/cuda/kernels/utils.cuh @@ -11,6 +11,7 @@ #include #include #include +#include namespace mlx::core::cu { @@ -40,4 +41,64 @@ elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { return loc; } +// Optimize when the ndim is known at compile time. +template +inline __host__ __device__ IdxT +elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) { + IdxT loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides) { + IdxT a_loc = 0; + IdxT b_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * a_strides[i]; + b_loc += dim_idx * b_strides[i]; + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc); +} + +// Optimized version when ndim is larger than 4. 
+template +inline __host__ __device__ IdxT +elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) { + IdxT loc = elem_to_loc_nd<3>(elem, shape, strides); + for (int i = ndim - 1; i >= 3; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + int ndim) { + auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides); + for (int i = ndim - 1; i >= 3; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * a_strides[i]; + b_loc += dim_idx * b_strides[i]; + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc); +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 3d9186892..2c3a73c42 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -71,43 +71,25 @@ bool fast::ScaledDotProductAttention::use_fallback( throw std::runtime_error(#func " has no CUDA implementation."); \ } -NO_GPU(Add) -NO_GPU(ArcTan2) NO_GPU(ArgPartition) NO_GPU(ArgReduce) NO_GPU(ArgSort) -NO_GPU(BitwiseBinary) NO_GPU(BlockMaskedMM) NO_GPU_MULTI(Compiled) NO_GPU(Convolution) -NO_GPU(Divide) NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) -NO_GPU(Remainder) -NO_GPU(Equal) NO_GPU(FFT) NO_GPU(Gather) NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) -NO_GPU(Greater) -NO_GPU(GreaterEqual) NO_GPU(Hadamard) -NO_GPU(Less) -NO_GPU(LessEqual) NO_GPU(Load) -NO_GPU(LogicalAnd) -NO_GPU(LogicalOr) -NO_GPU(LogAddExp) NO_GPU(LogSumExp) NO_GPU_MULTI(LUF) -NO_GPU(Maximum) -NO_GPU(Minimum) -NO_GPU(Multiply) -NO_GPU(NotEqual) NO_GPU(Partition) -NO_GPU(Power) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) @@ -119,7 +101,6 @@ NO_GPU(Select) NO_GPU(SliceUpdate) NO_GPU(Softmax) NO_GPU(Sort) -NO_GPU(Subtract) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) From 004c1d8ef2fcdc04f29acbc497c9fea7c190591a Mon Sep 17 00:00:00 2001 From: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com> Date: Tue, 10 Jun 2025 14:37:50 +0100 Subject: [PATCH 082/156] Report number of missing parameters (#2264) * chore: inform * chore: format --------- Co-authored-by: FL33TW00D --- python/mlx/nn/layers/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 783ef446d..af639dc4e 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -174,11 +174,15 @@ class Module(dict): new_weights = dict(weights) curr_weights = dict(tree_flatten(self.parameters())) if extras := (new_weights.keys() - curr_weights.keys()): - extras = " ".join(extras) - raise ValueError(f"Received parameters not in model: {extras}.") + num_extra = len(extras) + extras = ",\n".join(sorted(extras)) + raise ValueError( + f"Received {num_extra} parameters not in model: \n{extras}." 
+ ) if missing := (curr_weights.keys() - new_weights.keys()): - missing = " ".join(missing) - raise ValueError(f"Missing parameters: {missing}.") + num_missing = len(missing) + missing = ",\n".join(sorted(missing)) + raise ValueError(f"Missing {num_missing} parameters: \n{missing}.") for k, v in curr_weights.items(): v_new = new_weights[k] if not isinstance(v_new, mx.array): From bae9a6b404aa21fa068faab626839051f9d610fe Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 11 Jun 2025 00:59:47 +0900 Subject: [PATCH 083/156] CUDA backend: sort (#2262) Co-authored-by: Awni Hannun --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/primitives.cu | 2 - mlx/backend/cuda/sort.cu | 180 ++++++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/sort.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index c813f8fd4..23ae64cf6 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -16,6 +16,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 2c3a73c42..3f3674c07 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback( NO_GPU(ArgPartition) NO_GPU(ArgReduce) -NO_GPU(ArgSort) NO_GPU(BlockMaskedMM) NO_GPU_MULTI(Compiled) NO_GPU(Convolution) @@ -100,7 +99,6 @@ NO_GPU(ScatterAxis) NO_GPU(Select) NO_GPU(SliceUpdate) NO_GPU(Softmax) -NO_GPU(Sort) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu new file mode 100644 index 000000000..e1c2e8530 --- /dev/null +++ b/mlx/backend/cuda/sort.cu @@ -0,0 +1,180 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core { + +namespace { + +template +struct ModOp { + T divisor; + __device__ T operator()(T x) { + return x % divisor; + } +}; + +// We can not use any op in eval, make an utility. +array swapaxes_in_eval(const array& in, int axis1, int axis2) { + std::vector axes(in.ndim()); + std::iota(axes.begin(), axes.end(), 0); + std::swap(axes[axis1], axes[axis2]); + // TODO: Share the code with Transpose::eval. + Shape shape(axes.size()); + Strides strides(in.ndim()); + for (size_t ax = 0; ax < axes.size(); ++ax) { + shape[ax] = in.shape()[axes[ax]]; + strides[ax] = in.strides()[axes[ax]]; + } + auto flags = in.flags(); + if (flags.contiguous) { + auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides); + flags.row_contiguous = row_contiguous; + flags.col_contiguous = col_contiguous; + } + array out(shape, in.dtype(), nullptr, {}); + out.copy_shared_buffer(in, strides, flags, in.data_size()); + return out; +} + +template +void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. 
+ size_t size; + CHECK_CUDA_ERROR( + cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + temp.data(), size, args...)); +} + +template +void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR( + cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + temp.data(), size, args...)); +} + +void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { + array out = out_; + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + if (axis < 0) { + axis += in.ndim(); + } + int nsort = in.shape(axis); + int nsegments = in.data_size() / nsort; + int last_dim = in.ndim() - 1; + + // If we are not sorting the innermost dimension of a contiguous array, + // transpose and make a copy. + bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; + if (!is_segmented_sort) { + array trans = swapaxes_in_eval(in, axis, last_dim); + in = array(trans.shape(), trans.dtype(), nullptr, {}); + copy_gpu(trans, in, CopyType::General, s); + encoder.add_temporary(in); + out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(out); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + if constexpr (!std::is_same_v) { + using Type = cuda_type_t; + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + [nsort] __device__(int i) { return i * nsort; }); + if (argsort) { + // Indices in the sorted dimension. + array indices( + allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(indices); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(indices.data_size()), + thrust::device_pointer_cast(indices.data()), + ModOp{static_cast(nsort)}); + + // In argsort though we don't need the result of sorted values, the + // API requires us to provide an array to store it. + array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + encoder.add_temporary(discard); + + segmented_sort_pairs( + encoder, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + nsegments, + offsets, + offsets + 1, + stream); + } else { + segmented_sort( + encoder, + in.data(), + out.data(), + in.data_size(), + nsegments, + offsets, + offsets + 1, + stream); + } + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } + }); + }); + + if (!is_segmented_sort) { + // Swap the sorted axis back. + // TODO: Do in-place transpose instead of using a temporary out array. 
+ copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s); + } +} + +} // namespace + +void ArgSort::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ArgSort::eval_gpu"); + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Sort::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Sort::eval_gpu"); + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, false); +} + +} // namespace mlx::core From 7c4eb5d03e8ebfa2d7531e92db322e1b201f9127 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 11 Jun 2025 00:59:56 +0900 Subject: [PATCH 084/156] CUDA backend: random (#2261) --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/random.cu | 181 ++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 1 deletion(-) create mode 100644 mlx/backend/cuda/random.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 23ae64cf6..6ca176ceb 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -15,6 +15,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 3f3674c07..caa2c33ff 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -91,7 +91,6 @@ NO_GPU_MULTI(LUF) NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) -NO_GPU(RandomBits) NO_GPU(Reduce) NO_GPU(Scan) NO_GPU(Scatter) diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu new file mode 100644 index 000000000..d2b1b7dd5 --- /dev/null +++ b/mlx/backend/cuda/random.cu @@ -0,0 +1,181 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace cu { + +__constant__ constexpr uint32_t rotations[2][4] = { + {13, 15, 26, 6}, + {17, 29, 16, 24}}; + +union rbits { + uint2 val; + uint8_t bytes[2][4]; +}; + +__device__ rbits threefry2x32_hash(uint2 key, uint2 count) { + uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; + + rbits v; + v.val.x = count.x + ks[0]; + v.val.y = count.y + ks[1]; + + for (int i = 0; i < 5; ++i) { + for (auto r : rotations[i % 2]) { + v.val.x += v.val.y; + v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); + v.val.y ^= v.val.x; + } + v.val.x += ks[(i + 1) % 3]; + v.val.y += ks[(i + 2) % 3] + i + 1; + } + + return v; +} + +__global__ void rbitsc( + const uint32_t* keys, + uint8_t* out, + dim3 grid_dims, + bool odd, + uint32_t bytes_per_key) { + uint2 index{ + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y}; + if (index.x >= grid_dims.x || index.y >= grid_dims.y) { + return; + } + + auto kidx = 2 * index.x; + auto key = uint2{keys[kidx], keys[kidx + 1]}; + auto half_size = grid_dims.y - odd; + out += index.x * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto bits = threefry2x32_hash( + key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); + size_t idx = size_t(index.y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 
0 : size_t(index.y) + grid_dims.y) << 2; + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +__global__ void rbits( + const uint32_t* keys, + uint8_t* out, + dim3 grid_dims, + bool odd, + uint32_t bytes_per_key, + int32_t ndim, + const __grid_constant__ Shape key_shape, + const __grid_constant__ Strides key_strides) { + uint2 index{ + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y}; + if (index.x >= grid_dims.x || index.y >= grid_dims.y) { + return; + } + + auto kidx = 2 * index.x; + auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim); + auto k2_elem = + elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim); + auto key = uint2{keys[k1_elem], keys[k2_elem]}; + auto half_size = grid_dims.y - odd; + out += size_t(index.x) * bytes_per_key; + bool drop_last = odd && (index.y == half_size); + auto bits = threefry2x32_hash( + key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); + size_t idx = size_t(index.y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; + if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +} // namespace cu + +void RandomBits::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("RandomBits::eval_gpu"); + assert(inputs.size() == 1); + + // keys has shape (N1, ..., NK, 2) + // out has shape (N1, ..., NK, M1, M2, ...) 
+ auto& keys = inputs[0]; + uint32_t num_keys = keys.size() / 2; + + uint32_t elems_per_key = out.size() / num_keys; + uint32_t bytes_per_key = out.itemsize() * elems_per_key; + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4; + uint32_t half_size = out_per_key / 2; + bool odd = out_per_key % 2; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(keys); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + dim3 grid_dims{num_keys, half_size + odd}; + dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1); + dim3 num_blocks{ + cuda::ceil_div(grid_dims.x, block_dims.x), + cuda::ceil_div(grid_dims.y, block_dims.y)}; + if (keys.flags().row_contiguous) { + cu::rbitsc<<>>( + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key); + } else { + cu::rbits<<>>( + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key, + keys.ndim(), + const_param(keys.shape()), + const_param(keys.strides())); + } + }); +} + +} // namespace mlx::core From 62fecf3e13c8f12ee532b56079b073d2afcfa9ef Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 10 Jun 2025 09:34:01 -0700 Subject: [PATCH 085/156] fix conv export (#2265) --- mlx/primitives.h | 2 +- python/tests/test_export_import.py | 34 ++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/mlx/primitives.h b/mlx/primitives.h index cc60bcfb9..4b18430ca 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -719,9 +719,9 @@ class Convolution : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( + kernel_strides_, padding_lo_, padding_hi_, - kernel_strides_, kernel_dilation_, input_dilation_, groups_, diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index ef9827cbe..0fd8bfd87 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -6,6 +6,7 @@ import tempfile import unittest import mlx.core as mx +import mlx.nn as nn import mlx_tests @@ -312,6 +313,39 @@ class TestExportImport(mlx_tests.MLXTestCase): out = imported_fun(x, y, z)[0] self.assertTrue(mx.array_equal(expected, out)) + def test_export_conv(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.c1 = nn.Conv2d( + 3, 16, kernel_size=3, stride=1, padding=1, bias=False + ) + self.c2 = nn.Conv2d( + 16, 16, kernel_size=3, stride=2, padding=1, bias=False + ) + self.c3 = nn.Conv2d( + 16, 16, kernel_size=3, stride=1, padding=2, bias=False + ) + + def __call__(self, x): + return self.c3(self.c2(self.c1(x))) + + model = Model() + mx.eval(model.parameters()) + + def forward(x): + return model(x) + + input_data = mx.random.normal(shape=(4, 32, 32, 3)) + mx.export_function(path, forward, input_data) + + imported_fn = mx.import_function(path) + out = imported_fn(input_data)[0] + expected = forward(input_data) + self.assertTrue(mx.allclose(expected, out)) + if __name__ == "__main__": unittest.main() From 99c33d011d63174f50cea37c3eede002958be6d3 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 11 Jun 2025 02:51:51 +0900 Subject: [PATCH 086/156] rebase + nit (#2260) Co-authored-by: Awni Hannun --- mlx/backend/cuda/CMakeLists.txt | 15 ++- mlx/backend/cuda/copy.cpp | 26 ----- mlx/backend/cuda/copy.cu | 89 +++++++++++++++ mlx/backend/cuda/copy/copy.cuh | 71 ++++++++++++ 
mlx/backend/cuda/copy/copy_contiguous.cu | 56 ++++++++++ mlx/backend/cuda/copy/copy_general.cu | 95 ++++++++++++++++ mlx/backend/cuda/copy/copy_general_dynamic.cu | 105 ++++++++++++++++++ mlx/backend/cuda/copy/copy_general_input.cu | 88 +++++++++++++++ mlx/backend/cuda/kernels/cast_op.cuh | 59 ++++++++++ mlx/backend/cuda/slicing.cpp | 28 ++++- 10 files changed, 604 insertions(+), 28 deletions(-) delete mode 100644 mlx/backend/cuda/copy.cpp create mode 100644 mlx/backend/cuda/copy.cu create mode 100644 mlx/backend/cuda/copy/copy.cuh create mode 100644 mlx/backend/cuda/copy/copy_contiguous.cu create mode 100644 mlx/backend/cuda/copy/copy_general.cu create mode 100644 mlx/backend/cuda/copy/copy_general_dynamic.cu create mode 100644 mlx/backend/cuda/copy/copy_general_input.cu create mode 100644 mlx/backend/cuda/kernels/cast_op.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 6ca176ceb..7ffbcb2d3 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -7,7 +7,11 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu - ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu @@ -28,6 +32,15 @@ target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") +# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive. +# Explicitly pass this flag to suppress the warning, it is safe to set it to +# true but the warning wouldn't be suppressed. +if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + target_compile_options( + mlx + PRIVATE "$<$:--static-global-template-stub=false>") +endif() + # Compute capability 7 is required for synchronization between CPU/GPU with # managed memory. TODO: Add more architectures for potential performance gain. set(MLX_CUDA_ARCHITECTURES diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp deleted file mode 100644 index d0413d989..000000000 --- a/mlx/backend/cuda/copy.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/gpu/copy.h" - -namespace mlx::core { - -void copy_gpu_inplace( - const array& in, - array& out, - const Shape& data_shape, - const Strides& strides_in_pre, - const Strides& strides_out_pre, - int64_t inp_offset, - int64_t out_offset, - CopyType ctype, - const Stream& s, - const std::optional& dynamic_i_offset /* = std::nullopt */, - const std::optional& dynamic_o_offset /* = std::nullopt */) { - throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); -} - -void fill_gpu(const array& val, array& out, const Stream& s) { - throw std::runtime_error("fill_gpu not implemented in CUDA backend."); -} - -} // namespace mlx::core diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu new file mode 100644 index 000000000..8649e1bf9 --- /dev/null +++ b/mlx/backend/cuda/copy.cu @@ -0,0 +1,89 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/copy/copy.cuh" + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in_, + array& out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + int64_t offset_in, + int64_t offset_out, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_offset_in, + const std::optional& dynamic_offset_out) { + if (out.size() == 0) { + return; + } + const array& in = in_.data_shared_ptr() ? in_ : out; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + if (ctype == CopyType::Scalar || ctype == CopyType::Vector) { + copy_contiguous(encoder, ctype, in, out, offset_in, offset_out); + return; + } + + if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { + auto [shape_collapsed, strides_vec] = collapse_contiguous_dims( + shape, std::vector{strides_in, strides_out}, INT32_MAX); + if (ctype == CopyType::General) { + copy_general_input( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0]); + } else { + if (dynamic_offset_in || dynamic_offset_out) { + copy_general_dynamic( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0], + strides_vec[1], + dynamic_offset_in ? *dynamic_offset_in : array(0, int64), + dynamic_offset_out ? *dynamic_offset_out : array(0, int64)); + } else { + copy_general( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0], + strides_vec[1]); + } + } + return; + } +} + +void fill_gpu(const array& in, array& out, const Stream& s) { + if (out.size() == 0) { + return; + } + out.set_data(allocator::malloc(out.nbytes())); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh new file mode 100644 index 000000000..dd1d09d30 --- /dev/null +++ b/mlx/backend/cuda/copy/copy.cuh @@ -0,0 +1,71 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) 
\ + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ + using InType = cuda_type_t; \ + using OutType = cuda_type_t; \ + if constexpr (cu::CastOp::is_castable) { \ + __VA_ARGS__; \ + } else { \ + throw std::runtime_error(fmt::format( \ + "Can not copy data from dtype {} to {}.", \ + dtype_to_string(out.dtype()), \ + dtype_to_string(in.dtype()))); \ + } \ + }); \ + }) + +void copy_contiguous( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out); + +void copy_general( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out); + +void copy_general_dynamic( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + const array& dynamic_offset_in, + const array& dynamic_offset_out); + +void copy_general_input( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu new file mode 100644 index 000000000..fa79f0604 --- /dev/null +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -0,0 +1,56 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_s(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = CastOp{}(in[0]); + } +} + +template +__global__ void copy_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = CastOp{}(in[index]); + } +} + +} // namespace cu + +void copy_contiguous( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t in_offset, + int64_t out_offset) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { + MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + kernel<<>>( + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu new file mode 100644 index 000000000..3c5b3bbb3 --- /dev/null +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -0,0 +1,95 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_gg_nd( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array strides_in, + const __grid_constant__ cuda::std::array strides_out) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_nd( + index, shape.data(), strides_in.data(), strides_out.data()); + out[idx_out] = CastOp{}(in[idx_in]); + } +} + +template +__global__ void copy_gg( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides_in, + const __grid_constant__ Strides strides_out, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_4d( + index, shape.data(), strides_in.data(), strides_out.data(), ndim); + out[idx_out] = CastOp{}(in[idx_in]); + } +} + +} // namespace cu + +void copy_general( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::copy_gg_nd; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.data_size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.data_size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu new file mode 100644 index 000000000..b9774662a --- /dev/null +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -0,0 +1,105 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_gg_dynamic_nd( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array strides_in, + const __grid_constant__ cuda::std::array strides_out, + const int64_t* offset_in, + const int64_t* offset_out) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_nd( + index, shape.data(), strides_in.data(), strides_out.data()); + out[idx_out + *offset_out] = CastOp{}(in[idx_in + *offset_in]); + } +} + +template +__global__ void copy_gg_dynamic( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides_in, + const __grid_constant__ Strides strides_out, + int ndim, + const int64_t* offset_in, + const int64_t* offset_out) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [idx_in, idx_out] = elem_to_loc_4d( + index, shape.data(), strides_in.data(), strides_out.data(), ndim); + out[idx_out + *offset_out] = CastOp{}(in[idx_in + *offset_in]); + } +} + +} // namespace cu + +void copy_general_dynamic( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + const array& dynamic_offset_in, + const array& dynamic_offset_out) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::copy_gg_dynamic_nd; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.data_size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + dynamic_offset_in.data(), + dynamic_offset_out.data()); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.data_size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu new file mode 100644 index 000000000..4f2784927 --- /dev/null +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -0,0 +1,88 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/copy/copy.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void copy_g_nd( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array strides_in) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + IdxT idx_in = elem_to_loc_nd(index, shape.data(), strides_in.data()); + out[index] = CastOp{}(in[idx_in]); + } +} + +template +__global__ void copy_g( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides_in, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim); + out[index] = CastOp{}(in[idx_in]); + } +} + +} // namespace cu + +void copy_general_input( + cu::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::copy_g_nd; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.data_size(), + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.data_size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/cast_op.cuh b/mlx/backend/cuda/kernels/cast_op.cuh new file mode 100644 index 000000000..30b44d46f --- /dev/null +++ b/mlx/backend/cuda/kernels/cast_op.cuh @@ -0,0 +1,59 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::cu { + +// An op that does static_cast, with custom conversions for some types. +template +struct CastOp { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ DstT operator()(SrcT x) { + return static_cast(x); + } +}; + +// Converting a complex number to real number discards the imaginary part. +template +struct CastOp< + cuComplex, + DstT, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ DstT operator()(cuComplex x) { + static_assert(!cuda::std::is_same_v); + return static_cast(cuCrealf(x)); + } +}; + +// Allow converting a real number to complex number. +template +struct CastOp< + SrcT, + cuComplex, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; + + __device__ cuComplex operator()(SrcT x) { + static_assert(!cuda::std::is_same_v); + return cuComplex{static_cast(x), 0}; + } +}; + +// Return an iterator that cast the value to DstT using CastOp. 
+template +__host__ __device__ auto make_cast_iterator(Iterator it) { + using SrcT = typename cuda::std::iterator_traits::value_type; + if constexpr (std::is_same_v) { + return it; + } else { + return thrust::make_transform_iterator(it, CastOp{}); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index bfa742c74..af67fbbdd 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -1,7 +1,11 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" +#include + namespace mlx::core { void concatenate_gpu( @@ -9,7 +13,29 @@ void concatenate_gpu( array& out, int axis, const Stream& s) { - throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); + std::vector sizes; + sizes.push_back(0); + for (auto& p : inputs) { + sizes.push_back(p.shape(axis)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto strides = out.strides(); + auto flags = out.flags(); + flags.row_contiguous = false; + flags.col_contiguous = false; + flags.contiguous = false; + // TODO: Handle concurrent outputs: + // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816 + for (int i = 0; i < inputs.size(); i++) { + array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); + size_t data_offset = strides[axis] * sizes[i]; + out_slice.copy_shared_buffer( + out, strides, flags, out_slice.size(), data_offset); + copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s); + } } } // namespace mlx::core From 095163b8d1d58cec181b8de94c5f7805659fe718 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 11 Jun 2025 09:10:24 +0900 Subject: [PATCH 087/156] Fix building cpp benchmarks on Linux (#2268) --- benchmarks/cpp/irregular_strides.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/cpp/irregular_strides.cpp b/benchmarks/cpp/irregular_strides.cpp index cda76fed6..552461335 100644 --- a/benchmarks/cpp/irregular_strides.cpp +++ b/benchmarks/cpp/irregular_strides.cpp @@ -1,5 +1,6 @@ // Copyright © 2023 Apple Inc. 
+#include #include #include From 8590c0941e5c034b56dea3b33efa108668de540c Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 10 Jun 2025 20:58:16 -0700 Subject: [PATCH 088/156] Add load_safe to the general conv loaders (#2258) --- benchmarks/python/conv_unaligned_bench.py | 107 ++++++++++++++++++ mlx/backend/metal/conv.cpp | 36 ++++-- mlx/backend/metal/jit_kernels.cpp | 4 +- mlx/backend/metal/kernels.h | 2 + .../steel/conv/kernels/steel_conv_general.h | 63 ++++++++--- .../steel/conv/loaders/loader_general.h | 95 ++++++++++++++++ mlx/backend/metal/nojit_kernels.cpp | 4 +- python/tests/test_conv.py | 13 +++ 8 files changed, 302 insertions(+), 22 deletions(-) create mode 100644 benchmarks/python/conv_unaligned_bench.py diff --git a/benchmarks/python/conv_unaligned_bench.py b/benchmarks/python/conv_unaligned_bench.py new file mode 100644 index 000000000..981d7b48b --- /dev/null +++ b/benchmarks/python/conv_unaligned_bench.py @@ -0,0 +1,107 @@ +import math +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 10 +N_iter_bench = 100 +N_iter_func = 5 + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + torch.mps.synchronize() + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): + def mx_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_2D + + +def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): + @torch.no_grad() + def pt_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + torch.mps.synchronize() + return ys + + return pt_conv_2D + + +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kH * kH * C) + a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") + b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps") + + torch.mps.synchronize() + + f_mx = make_mx_conv_2D(strides, padding, groups) + f_pt = make_pt_conv_2D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) + out_pt = torch.conv2d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + dtype = "float32" + shapes = ( + (4, 32, 32, 21, 3, 3, 128), + (4, 32, 32, 21, 3, 3, 37), + (4, 32, 32, 370, 3, 3, 370), + (4, 32, 32, 370, 7, 7, 128), + (2, 320, 640, 21, 7, 7, 21), + ) + for N, H, W, C, kh, kw, O in shapes: + time_mlx, time_torch = bench_shape( + N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + 
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 593b79384..697afa6a1 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu( // Get channel iteration info int channel_k_iters = ((conv_params.C + bk - 1) / bk); int gemm_k_iters = channel_k_iters; + bool align_C = conv_params.C % bk == 0; // Fix host side helper params int sign = (conv_params.flip ? -1 : 1); @@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu( /* const int swizzle_log = */ swizzle_log}; // Determine kernel - std::ostringstream kname; - kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm - << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; + std::string kname; + kname.reserve(64); + concatenate( + kname, + "implicit_gemm_conv_2d_general_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + std::string hash_name; + hash_name.reserve(64); + concatenate(hash_name, kname, "_alC_", align_C); + metal::MTLFCList func_consts = { + {&align_C, MTL::DataType::DataTypeBool, 200}, + }; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = - get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn); + auto kernel = get_steel_conv_general_kernel( + d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn); compute_encoder.set_compute_pipeline_state(kernel); // Deduce grid launch dimensions @@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu( // Direct to winograd conv bool inp_large = - (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; + (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096; bool channels_large = (conv_params.C + conv_params.O) >= 256; + bool out_large = + (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256; if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one && conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && @@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu( return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); } - else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { + else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) { return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); } diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 15e21af6c..467380c3a 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, const array& out, int bm, int bn, @@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( wn); return kernel_source.str(); }); - return d.get_kernel(kernel_name, lib); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); } MTL::ComputePipelineState* get_fft_kernel( diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 6d8864385..1de5fa47c 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -205,6 +205,8 @@ 
MTL::ComputePipelineState* get_gemv_masked_kernel( MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, const array& out, int bm, int bn, diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h index 8253638f1..9afebd307 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -2,6 +2,8 @@ #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" +constant bool align_C [[function_constant(200)]]; + template < typename T, int BM, @@ -118,23 +120,58 @@ implicit_gemm_conv_2d_general( // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); - // Prepare for next iteration - loader_a.next(); - loader_b.next(); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + // Load elements into threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } } threadgroup_barrier(mem_flags::mem_none); diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h index 72335e698..9b7ddc2ee 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h @@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; @@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index b1478d33b..b0375e37f 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel( MTL::ComputePipelineState* get_steel_conv_general_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, const array&, int, int, int, int, int) { - return d.get_kernel(kernel_name); + return d.get_kernel(kernel_name, hash_name, func_consts); } MTL::ComputePipelineState* get_fft_kernel( diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 9fe11286d..c68315a5d 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + def test_conv2d_unaligned_channels(self): + x = mx.random.uniform(shape=(2, 16, 16, 21)) + w = mx.random.uniform(shape=(32, 3, 3, 21)) + y = mx.conv2d(x, w, stream=mx.cpu) + y_hat = mx.conv2d(x, w) + self.assertTrue(mx.allclose(y, y_hat)) + + x = mx.random.uniform(shape=(2, 16, 16, 21)) + w 
= mx.random.uniform(shape=(21, 3, 3, 21)) + y = mx.conv2d(x, w, stream=mx.cpu) + y_hat = mx.conv2d(x, w) + self.assertTrue(mx.allclose(y, y_hat)) + if __name__ == "__main__": unittest.main() From c35f4d089abac69de5e337aa041887d767f6553a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 10 Jun 2025 21:19:47 -0700 Subject: [PATCH 089/156] start cuda circle config (#2256) * rebase * fix metal kernel linking issue on cuda * start cuda circle config --- .circleci/config.yml | 26 ++++++++++++++++++++++++++ mlx/CMakeLists.txt | 3 +++ mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/cuda.cpp | 11 +++++++++++ mlx/backend/cuda/cuda.h | 10 ++++++++++ mlx/backend/cuda/no_cuda.cpp | 11 +++++++++++ mlx/backend/metal/no_metal.cpp | 24 ++++++++++++++++++++++-- mlx/backend/no_gpu/primitives.cpp | 13 ------------- mlx/mlx.h | 1 + python/src/array.cpp | 10 +++++----- python/src/device.cpp | 5 +++++ python/tests/test_array.py | 2 +- python/tests/test_device.py | 8 ++++---- python/tests/test_optimizers.py | 2 +- 14 files changed, 101 insertions(+), 26 deletions(-) create mode 100644 mlx/backend/cuda/cuda.cpp create mode 100644 mlx/backend/cuda/cuda.h create mode 100644 mlx/backend/cuda/no_cuda.cpp diff --git a/.circleci/config.yml b/.circleci/config.yml index 6dc7ec4df..808242f9b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -212,6 +212,29 @@ jobs: METAL_DEBUG_ERROR_MODE=0 \ python -m xmlrunner discover -v python/tests -o test-results/gpu_jit + cuda_build_and_test: + machine: + image: linux-cuda-12:default + resource_class: gpu.nvidia.small.gen2 + steps: + - checkout + - run: + name: Install Python package + command: | + sudo apt-get update + sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev + python -m venv env + source env/bin/activate + CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ + pip install -e ".[dev]" + - run: + name: Run Python tests + command: | + source env/bin/activate + LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v + build_release: parameters: python_version: @@ -348,6 +371,7 @@ workflows: parameters: macosx_deployment_target: ["13.5", "14.0"] - linux_build_and_test + - cuda_build_and_test - build_documentation build_pypi_release: @@ -455,6 +479,8 @@ workflows: macosx_deployment_target: ["13.5", "14.0"] - linux_build_and_test: requires: [ hold ] + - cuda_build_and_test: + requires: [ hold ] nightly_build: when: and: diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index ce921b276..7aa648533 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -55,6 +55,9 @@ endif() if(MLX_BUILD_CUDA) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp) endif() if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7ffbcb2d3..9d9657e1f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu + ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu diff --git a/mlx/backend/cuda/cuda.cpp b/mlx/backend/cuda/cuda.cpp new file mode 100644 index 
000000000..ceb4d7dfe --- /dev/null +++ b/mlx/backend/cuda/cuda.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cuda.h" + +namespace mlx::core::cu { + +bool is_available() { + return true; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/cuda.h b/mlx/backend/cuda/cuda.h new file mode 100644 index 000000000..2c6a5c724 --- /dev/null +++ b/mlx/backend/cuda/cuda.h @@ -0,0 +1,10 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cu { + +/* Check if the CUDA backend is available. */ +bool is_available(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/no_cuda.cpp b/mlx/backend/cuda/no_cuda.cpp new file mode 100644 index 000000000..8a394c9e3 --- /dev/null +++ b/mlx/backend/cuda/no_cuda.cpp @@ -0,0 +1,11 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cuda.h" + +namespace mlx::core::cu { + +bool is_available() { + return false; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp index b6142b280..9785e07c2 100644 --- a/mlx/backend/metal/no_metal.cpp +++ b/mlx/backend/metal/no_metal.cpp @@ -3,8 +3,11 @@ #include #include "mlx/backend/metal/metal.h" +#include "mlx/fast.h" -namespace mlx::core::metal { +namespace mlx::core { + +namespace metal { bool is_available() { return false; @@ -19,4 +22,21 @@ device_info() { "[metal::device_info] Cannot get device info without metal backend"); }; -} // namespace mlx::core::metal +} // namespace metal + +namespace fast { + +MetalKernelFunction metal_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool ensure_row_contiguous, + bool atomic_outputs) { + throw std::runtime_error("[metal_kernel] No GPU back-end."); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 849cbf83e..409aa2c89 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -2,7 +2,6 @@ #include "mlx/primitives.h" #include "mlx/distributed/primitives.h" -#include "mlx/fast.h" #include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ @@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) - -MetalKernelFunction metal_kernel( - const std::string&, - const std::vector&, - const std::vector&, - const std::string&, - const std::string&, - bool ensure_row_contiguous, - bool atomic_outputs) { - throw std::runtime_error("[metal_kernel] No GPU back-end."); -} - } // namespace fast namespace distributed { diff --git a/mlx/mlx.h b/mlx/mlx.h index cef8d806d..de3ee392a 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/metal/metal.h" #include "mlx/compile.h" #include "mlx/device.h" diff --git a/python/src/array.cpp b/python/src/array.cpp index 5ba0aaedc..25889d775 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -17,10 +17,7 @@ #include "python/src/indexing.h" #include "python/src/utils.h" -#include "mlx/device.h" -#include "mlx/ops.h" -#include "mlx/transforms.h" -#include "mlx/utils.h" +#include "mlx/mlx.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -461,9 +458,12 @@ void init_array(nb::module_& m) { .def( "__dlpack_device__", [](const mx::array& a) { + // See + // 
https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74 if (mx::metal::is_available()) { - // Metal device is available return nb::make_tuple(8, 0); + } else if (mx::cu::is_available()) { + return nb::make_tuple(13, 0); } else { // CPU device return nb::make_tuple(1, 0); diff --git a/python/src/device.cpp b/python/src/device.cpp index 85b15dd4d..006a05dc0 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -58,4 +58,9 @@ void init_device(nb::module_& m) { &mx::set_default_device, "device"_a, R"pbdoc(Set the default device.)pbdoc"); + m.def( + "is_available", + &mx::is_available, + "device"_a, + R"pbdoc(Check if a back-end is available for the given device.)pbdoc"); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index e63da17df..c22e0a38f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase): def test_dlx_device_type(self): a = mx.array([1, 2, 3]) device_type, device_id = a.__dlpack_device__() - self.assertIn(device_type, [1, 8]) + self.assertIn(device_type, [1, 8, 13]) self.assertEqual(device_id, 0) if device_type == 8: diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 53826cad7..6793c98d1 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -10,7 +10,7 @@ import mlx_tests class TestDefaultDevice(unittest.TestCase): def test_mlx_default_device(self): device = mx.default_device() - if mx.metal.is_available(): + if mx.is_available(mx.gpu): self.assertEqual(device, mx.Device(mx.gpu)) self.assertEqual(str(device), "Device(gpu, 0)") self.assertEqual(device, mx.gpu) @@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase): self.assertEqual(s2.device, mx.default_device()) self.assertNotEqual(s1, s2) - if mx.metal.is_available(): + if mx.is_available(mx.gpu): s_gpu = mx.default_stream(mx.gpu) self.assertEqual(s_gpu.device, mx.gpu) else: @@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase): s_cpu = mx.new_stream(mx.cpu) self.assertEqual(s_cpu.device, mx.cpu) - if mx.metal.is_available(): + if mx.is_available(mx.gpu): s_gpu = mx.new_stream(mx.gpu) self.assertEqual(s_gpu.device, mx.gpu) else: @@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase): a = mx.add(x, y, stream=mx.default_stream(mx.default_device())) - if mx.metal.is_available(): + if mx.is_available(mx.gpu): b = mx.add(x, y, stream=mx.default_stream(mx.gpu)) self.assertEqual(a.item(), b.item()) s_gpu = mx.new_stream(mx.gpu) diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index ebfe97d80..4943fe662 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) -class TestSchedulers(unittest.TestCase): +class TestSchedulers(mlx_tests.MLXTestCase): def test_decay_lr(self): for optim_class in optimizers_dict.values(): lr_schedule = opt.step_decay(1e-1, 0.9, 1) From c9fa68664a36ca6d8f071fe2155fd136775bc87e Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 03:22:25 +0900 Subject: [PATCH 090/156] CUDA backend: reduce (#2269) --- mlx/backend/cuda/CMakeLists.txt | 4 + mlx/backend/cuda/kernel_utils.cuh | 25 ++ mlx/backend/cuda/kernels/utils.cuh | 198 ++++++++++++++ mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/reduce.cu | 82 ++++++ mlx/backend/cuda/reduce/col_reduce.cu | 278 ++++++++++++++++++++ 
mlx/backend/cuda/reduce/reduce.cuh | 74 ++++++ mlx/backend/cuda/reduce/reduce_ops.cuh | 144 ++++++++++ mlx/backend/cuda/reduce/row_reduce.cu | 250 ++++++++++++++++++ mlx/backend/cuda/reduce/segmented_reduce.cu | 84 ++++++ 10 files changed, 1139 insertions(+), 1 deletion(-) create mode 100644 mlx/backend/cuda/reduce.cu create mode 100644 mlx/backend/cuda/reduce/col_reduce.cu create mode 100644 mlx/backend/cuda/reduce/reduce.cuh create mode 100644 mlx/backend/cuda/reduce/reduce_ops.cuh create mode 100644 mlx/backend/cuda/reduce/row_reduce.cu create mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 9d9657e1f..c053b4428 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -21,6 +21,10 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/random.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index aeb065206..656ddebea 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -47,6 +47,31 @@ namespace mlx::core { __VA_ARGS__; \ } +// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2. +#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \ + { \ + uint32_t _num_threads = NUM_THREADS; \ + if (_num_threads <= WARP_SIZE) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 2) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 4) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 8) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \ + __VA_ARGS__; \ + } else if (_num_threads <= WARP_SIZE * 16) { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \ + __VA_ARGS__; \ + } else { \ + constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \ + __VA_ARGS__; \ + } \ + } + // Maps CPU types to CUDA types. template struct CTypeToCudaType { diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh index 16957d132..7636710dc 100644 --- a/mlx/backend/cuda/kernels/utils.cuh +++ b/mlx/backend/cuda/kernels/utils.cuh @@ -9,6 +9,8 @@ #pragma once #include +#include +#include #include #include #include @@ -19,6 +21,10 @@ namespace mlx::core::cu { // CUDA kernel utils /////////////////////////////////////////////////////////////////////////////// +// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in +// warpSize variable exists, using it would prevent compile-time optimizations. +#define WARP_SIZE 32 + // To pass shape/strides to kernels via constant memory, their size must be // known at compile time. 
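+// MAX_NDIM below bounds the rank of the fixed-size Shape/Strides arrays, so
+// whole argument structs can be passed to kernels by value (for example as
+// __grid_constant__ parameters in the reduce kernels added later in this
+// series).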
#define MAX_NDIM 8 @@ -26,6 +32,94 @@ namespace mlx::core::cu { using Shape = cuda::std::array; using Strides = cuda::std::array; +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T min() { + return cuda::std::numeric_limits::min(); + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { + return cuda::std::numeric_limits::min(); + } +}; + +template +struct Limits< + T, + cuda::std::enable_if_t< + cuda::std::is_same_v || cuda::std::is_same_v>> { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T min() { + return -cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { + return cuda::std::numeric_limits::lowest(); + } +}; + +// CUDA 11 does not have host side arithmatic operators for half types. +template +struct Limits< + T, + cuda::std::enable_if_t< + cuda::std::is_same_v || + cuda::std::is_same_v>> { + static constexpr __host__ __device__ T max() { + return cuda::std::numeric_limits::infinity(); + } + static constexpr __host__ __device__ T min() { +#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 + return -cuda::std::numeric_limits::infinity(); +#else + return -cuda::std::numeric_limits::infinity(); +#endif + } + static constexpr __host__ __device__ T finite_max() { + return cuda::std::numeric_limits::max(); + } + static constexpr __host__ __device__ T finite_min() { +#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 + return cuda::std::numeric_limits::lowest(); +#else + return cuda::std::numeric_limits::lowest(); +#endif + } +}; + +template <> +struct Limits { + static constexpr __host__ __device__ bool max() { + return true; + } + static constexpr __host__ __device__ bool min() { + return false; + } +}; + +template <> +struct Limits { + static constexpr __host__ __device__ cuComplex max() { + return {Limits::max(), Limits::max()}; + } + static constexpr __host__ __device__ cuComplex min() { + return {Limits::min(), Limits::min()}; + } +}; + /////////////////////////////////////////////////////////////////////////////// // Indexing utils /////////////////////////////////////////////////////////////////////////////// @@ -101,4 +195,108 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( return cuda::std::make_tuple(a_loc, b_loc); } +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + __device__ void next(const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + __device__ void next(int n, const int* shape, const 
int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, true, OffsetT> { + int dim; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim) {} + + __device__ void next(const int* shape, const int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, false, OffsetT> { + OffsetT offset{0}; + + __device__ LoopedElemToLoc(int) {} + + __device__ void next(const int*, const int64_t* strides) { + offset += OffsetT(strides[0]); + } + + __device__ void next(int n, const int*, const int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + __device__ OffsetT location() { + return offset; + } +}; + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index caa2c33ff..1b273e959 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -91,7 +91,6 @@ NO_GPU_MULTI(LUF) NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) -NO_GPU(Reduce) NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu new file mode 100644 index 000000000..a740113db --- /dev/null +++ b/mlx/backend/cuda/reduce.cu @@ -0,0 +1,82 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Reduce::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Reduce::eval_gpu"); + assert(inputs.size() == 1); + array in = inputs[0]; + + // Make sure no identity reductions trickle down here. + assert(!axes_.empty()); + assert(out.size() != in.size()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Fill out with init value. + if (in.size() == 0) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, { + using InType = cuda_type_t; + using OutType = cu::ReduceResult::type; + thrust::fill_n( + cu::thrust_policy(stream), + thrust::device_pointer_cast(out.data()), + out.data_size(), + cu::ReduceInit::value()); + }); + }); + }); + return; + } + + // Reduce. + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. 
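+ // The extra copy costs one pass over the data, but it guarantees that the
+ // recomputed plan falls into one of the contiguous cases handled below, so
+ // no fully general gather-style reduction kernel is needed.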
+ if (plan.type == GeneralReduce) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + encoder.add_temporary(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + if ((plan.type == ContiguousAllReduce) || + (plan.type == ContiguousReduce && plan.shape.size() == 1)) { + segmented_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + col_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + throw std::runtime_error("No plan reached in reduce."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu new file mode 100644 index 000000000..1ca50d854 --- /dev/null +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -0,0 +1,278 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). + size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + int column = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + if (column * N_READS >= args.reduction_stride) { + return; + } + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + // Read input to local. 
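+ // Each thread owns an N_READS-wide slice of the reduction_stride columns:
+ // it walks the reduction rows y, y + blockDim.y, y + 2 * blockDim.y, ...
+ // and folds each row's slice into totals[], so no synchronization is
+ // needed until the block-level combine below.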
+ LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next( + block.thread_index().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + for (size_t r = block.thread_index().y; + r < args.non_col_reductions * args.reduction_size; + r += block.dim_threads().y) { + U vals[N_READS]; + cub::LoadDirectBlocked( + column, + make_cast_iterator(in + loop.location()), + vals, + args.reduction_stride, + ReduceInit::value()); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next( + block.dim_threads().y, + args.reduce_shape.data(), + args.reduce_strides.data()); + } + + // Do block reduce when each column has more than 1 element to reduce. + if (block.dim_threads().y > 1) { + __shared__ U shared_vals[32 * 8 * N_READS]; + size_t col = + block.thread_index().y * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + shared_vals[col * N_READS + i] = totals[i]; + } + block.sync(); + if (block.thread_index().y == 0) { + for (int i = 0; i < N_READS; i++) { + totals[i] = shared_vals[block.thread_index().x * N_READS + i]; + } + for (int j = 1; j < block.dim_threads().y; j++) { + col = j * block.dim_threads().x + block.thread_index().x; + for (int i = 0; i < N_READS; i++) { + totals[i] = op(shared_vals[col * N_READS + i], totals[i]); + } + } + } + } + + // Write result. + if (block.thread_index().y == 0) { + cub::StoreDirectBlocked( + column, + out + out_idx * args.reduction_stride, + totals, + args.reduction_stride); + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4> +__global__ void col_reduce_looped( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int n_warps = BN / N_READS; + + int out_idx = grid.block_rank() / grid.dim_blocks().x; + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + // Read input to local. + int r = block.thread_rank() / n_warps; + int column = block.thread_rank() % n_warps; + int in_offset = grid.block_index().x * BN; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); + for (; r < args.non_col_reductions * args.reduction_size; r += BM) { + U vals[N_READS]; + cub::LoadDirectBlocked( + column, + make_cast_iterator(in + loop.location() + in_offset), + vals, + args.reduction_stride - in_offset, + ReduceInit::value()); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / n_warps; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[col + i] = totals[i]; + } + block.sync(); + col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[col + i], op); + } + + // Write result. 
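+ // After the in-warp reduce every lane holds the same n_outputs totals;
+ // lane 0 of warp w stores them to columns [w * n_outputs, (w + 1) * n_outputs)
+ // of this block's BN-wide slice of the output row.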
+ if (warp.thread_rank() == 0) { + size_t out_offset = grid.block_index().x * BN; + cub::StoreDirectBlocked( + warp.meta_group_rank(), + out + out_idx * args.reduction_stride + out_offset, + totals, + args.reduction_stride - out_offset); + } +} + +} // namespace cu + +inline auto output_grid_for_col_reduce( + const array& out, + const cu::ColReduceArgs& args) { + auto out_shape = out.shape(); + auto out_strides = out.strides(); + while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { + out_shape.pop_back(); + out_strides.pop_back(); + } + return get_2d_grid_dims(out_shape, out_strides); +} + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + cu::ColReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = cuda_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = cu::ReduceResult::type; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + constexpr int N_READS = 4; + dim3 block_dims; + dim3 num_blocks = output_grid_for_col_reduce(out, args); + num_blocks.z = num_blocks.y; + num_blocks.y = num_blocks.x; + auto kernel = + cu::col_reduce_small; + size_t total = args.non_col_reductions * args.reduction_size; + if (total < 32) { + size_t stride_blocks = + cuda::ceil_div(args.reduction_stride, N_READS); + block_dims.x = std::min(stride_blocks, 32ul); + block_dims.y = std::min(total, 8ul); + num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x); + } else { + constexpr int BM = 32; + constexpr int BN = 32; + block_dims.x = BM * BN / N_READS; + num_blocks.x = cuda::ceil_div(args.reduction_stride, BN); + kernel = cu:: + col_reduce_looped; + } + kernel<<>>( + in.data(), out.data(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh new file mode 100644 index 000000000..0148022ab --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -0,0 +1,74 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +// Dispatch dynamic ndim to constexpr. +// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. +#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ + if (ndim == 1) { \ + constexpr uint32_t NDIM = 1; \ + __VA_ARGS__; \ + } else if (ndim == 2) { \ + constexpr uint32_t NDIM = 2; \ + __VA_ARGS__; \ + } else { \ + constexpr uint32_t NDIM = 5; \ + __VA_ARGS__; \ + } + +// Dispatch reduce ops to constexpr. +#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) 
\ + if (REDUCE == Reduce::ReduceType::And) { \ + using OP = cu::And; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Or) { \ + using OP = cu::Or; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Sum) { \ + using OP = cu::Sum; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Prod) { \ + using OP = cu::Prod; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Max) { \ + using OP = cu::Max; \ + __VA_ARGS__; \ + } else if (REDUCE == Reduce::ReduceType::Min) { \ + using OP = cu::Min; \ + __VA_ARGS__; \ + } else { \ + throw std::invalid_argument("Unknown reduce type."); \ + } + +void segmented_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void col_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh new file mode 100644 index 000000000..f06eb8541 --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -0,0 +1,144 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/kernels/utils.cuh" + +namespace mlx::core::cu { + +// Reduce ops. +struct And { + __device__ bool operator()(bool a, bool b) { + return a && b; + } +}; + +struct Or { + __device__ bool operator()(bool a, bool b) { + return a || b; + } +}; + +struct Sum { + template + __device__ T operator()(T a, T b) { + return a + b; + } +}; + +struct Prod { + template + __device__ T operator()(T a, T b) { + return a * b; + } +}; + +struct Min { + template + __device__ T operator()(T a, T b) { + return a < b ? a : b; + } +}; + +struct Max { + template + __device__ T operator()(T a, T b) { + return a > b ? a : b; + } +}; + +// Traits to get the result type of reduce op. +template +struct ReduceResult; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = cuda::std::conditional_t< + (cuda::std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +template +struct ReduceResult { + using type = T; +}; + +template +struct ReduceResult { + using type = T; +}; + +// Traits to get the init value of reduce op. 
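+// Each op is paired with the identity element of its reduction, so padding
+// lanes and empty inputs leave the result unchanged: Sum starts from 0, Prod
+// from 1, Min from the largest representable value and Max from the smallest.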
+template +struct ReduceInit; + +template +struct ReduceInit { + static constexpr __host__ __device__ bool value() { + return true; + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ bool value() { + return false; + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ auto value() { + if constexpr (cuda::std::is_same_v) { + return T{0, 0}; + } else { + return typename ReduceResult::type{0}; + } + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ auto value() { + if constexpr (cuda::std::is_same_v) { + return T{1, 1}; + } else { + return typename ReduceResult::type{1}; + } + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::max(); + } +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::min(); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu new file mode 100644 index 000000000..3a5c4a591 --- /dev/null +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -0,0 +1,250 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +struct RowReduceArgs { + // The size of the row being reduced, i.e. the size of last dimension. + int row_size; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes excluding last dimension. + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). 
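+ // For example, a row-contiguous (5, 3, 8) input reduced over axes 0 and 2
+ // gives row_size = 8 and non_row_reductions = 5.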
+ size_t non_row_reductions; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + row_size = plan.shape.back(); + + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size() - 1; + + non_row_reductions = 1; + for (int i = 0; i < reduce_ndim; i++) { + non_row_reductions *= reduce_shape[i]; + } + } +}; + +template +__global__ void row_reduce_small( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + size_t out_idx = cg::this_grid().thread_rank(); + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value(); + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value()); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + out[out_idx] = total_val; +} + +template +__global__ void row_reduce_small_warp( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + size_t out_idx = grid.thread_rank() / WARP_SIZE; + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value(); + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = warp.thread_rank(); n < args.non_row_reductions; + n += WARP_SIZE) { + for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + ReduceInit::value()); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data()); + } + + total_val = cg::reduce(warp, total_val, op); + + if (warp.thread_rank() == 0) { + out[out_idx] = total_val; + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BLOCK_DIM_X, + int N_READS = 4> +__global__ void row_reduce_looped( + const T* in, + U* out, + size_t out_size, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + size_t out_idx = grid.thread_rank() / BLOCK_DIM_X; + if (out_idx >= out_size) { + return; + } + + Op op; + + U total_val = ReduceInit::value(); + LoopedElemToLoc 2)> loop(args.reduce_ndim); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + + for (size_t n = 0; n < args.non_row_reductions; n++) { + for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS); + r++) { + U vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM_X + block.thread_index().x, + make_cast_iterator(in + loop.location()), + vals, + args.row_size, + 
ReduceInit::value()); + total_val = op(total_val, cub::ThreadReduce(vals, op)); + } + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + + typedef cub::BlockReduce BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp; + + total_val = BlockReduceT(temp).Reduce(total_val, op); + + if (block.thread_rank() == 0) { + out[out_idx] = total_val; + } +} + +} // namespace cu + +void row_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + cu::RowReduceArgs args(in, plan, axes); + + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + using InType = cuda_type_t; + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using OutType = cu::ReduceResult::type; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + constexpr size_t N_READS = 4; + dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block_dims, num_blocks; + auto kernel = + cu::row_reduce_small; + if (args.row_size <= 64) { + if ((args.non_row_reductions < 32 && args.row_size <= 8) || + (args.non_row_reductions <= 8)) { + block_dims.x = std::min(out_dims.x, 1024u); + num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x); + num_blocks.y = out_dims.y; + } else { + block_dims.x = WARP_SIZE; + num_blocks.y = out_dims.x; + num_blocks.z = out_dims.y; + kernel = + cu::row_reduce_small_warp; + } + } else { + size_t num_threads = cuda::ceil_div(args.row_size, N_READS); + num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE; + MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, { + num_blocks.y = out_dims.x; + num_blocks.z = out_dims.y; + block_dims.x = BLOCK_DIM_X; + kernel = cu::row_reduce_looped< + InType, + OutType, + OP, + NDIM, + BLOCK_DIM_X, + N_READS>; + }); + } + kernel<<>>( + in.data(), out.data(), out.size(), args); + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu new file mode 100644 index 000000000..563b056e4 --- /dev/null +++ b/mlx/backend/cuda/reduce/segmented_reduce.cu @@ -0,0 +1,84 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +template +void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. + CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); +} + +template +void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { + // Allocate temporary storage. + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + // Run op. 
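+ // Second step of the usual CUB two-phase pattern: the call above with a
+ // null workspace only queried the required temporary-storage size, while
+ // this call performs the actual segmented reduction.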
+ CHECK_CUDA_ERROR( + cub::DeviceSegmentedReduce::Reduce(temp.data(), size, args...)); +} + +struct MultiplyOp { + int factor; + __device__ int operator()(int i) { + return i * factor; + } +}; + +void segmented_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using InType = cuda_type_t; + using OutType = cu::ReduceResult::type; + auto in_iter = cu::make_cast_iterator( + thrust::device_pointer_cast(in.data())); + auto out_ptr = thrust::device_pointer_cast(out.data()); + auto init = cu::ReduceInit::value(); + + if (plan.type == ContiguousAllReduce) { + cub_all_reduce( + encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream); + } else if (plan.type == ContiguousReduce) { + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()}); + cub_segmented_reduce( + encoder, + in_iter, + out_ptr, + out.size(), + offsets, + offsets + 1, + OP(), + init, + stream); + } else { + throw std::runtime_error("Unsupported plan in segmented_reduce."); + } + }); + }); + }); +} + +} // namespace mlx::core From ccf78f566ca8eae9de82b4bf64f043078db2a0a8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 05:26:17 +0900 Subject: [PATCH 091/156] CUDA backend: argreduce (#2270) --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/arg_reduce.cu | 189 ++++++++++++++++++ .../cuda/iterators/strided_iterator.cuh | 60 ++++++ mlx/backend/cuda/primitives.cu | 1 - 4 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 mlx/backend/cuda/arg_reduce.cu create mode 100644 mlx/backend/cuda/iterators/strided_iterator.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index c053b4428..ab0d5fe7c 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu new file mode 100644 index 000000000..7dbd91e46 --- /dev/null +++ b/mlx/backend/cuda/arg_reduce.cu @@ -0,0 +1,189 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/strided_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct IndexValPair { + uint32_t index; + T val; +}; + +template +struct ArgMin { + constexpr __device__ T init() { + return Limits::max(); + } + + __device__ IndexValPair operator()( + const IndexValPair& best, + const IndexValPair& current) { + if (best.val > current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + __device__ IndexValPair + reduce_many(IndexValPair best, T (&vals)[N], uint32_t offset) { + for (int i = 0; i < N; i++) { + if (vals[i] < best.val) { + best.val = vals[i]; + best.index = offset + i; + } + } + return best; + } +}; + +template +struct ArgMax { + constexpr __device__ T init() { + return Limits::min(); + } + + __device__ IndexValPair operator()( + const IndexValPair& best, + const IndexValPair& current) { + if (best.val < current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } + + template + __device__ IndexValPair + reduce_many(IndexValPair best, T (&vals)[N], uint32_t offset) { + for (int i = 0; i < N; i++) { + if (vals[i] > best.val) { + best.val = vals[i]; + best.index = offset + i; + } + } + return best; + } +}; + +template +__global__ void arg_reduce_general( + const T* in, + uint32_t* out, + size_t size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides in_strides, + const __grid_constant__ Strides out_strides, + int32_t ndim, + int64_t axis_stride, + int32_t axis_size) { + auto block = cg::this_thread_block(); + + int64_t index = cg::this_grid().block_rank(); + if (index >= size) { + return; + } + + int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim); + int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim); + + Op op; + T init = op.init(); + IndexValPair best{0, init}; + + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T vals[N_READS]; + auto tid = r * BLOCK_DIM + block.thread_index().z; + cub::LoadDirectBlocked( + tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); + best = op.reduce_many(best, vals, tid * N_READS); + } + + typedef cub::BlockReduce, BLOCK_DIM> BlockReduceT; + __shared__ typename BlockReduceT::TempStorage temp; + + best = BlockReduceT(temp).Reduce(best, op); + + if (block.thread_rank() == 0) { + out[out_idx] = best.index; + } +} + +} // namespace cu + +void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ArgReduce::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + + // Prepare the shapes, strides and axis arguments. + Shape shape = remove_index(in.shape(), axis_); + Strides in_strides = remove_index(in.strides(), axis_); + Strides out_strides = out.ndim() == in.ndim() + ? remove_index(out.strides(), axis_) + : out.strides(); + int64_t axis_stride = in.strides()[axis_]; + int32_t axis_size = in.shape()[axis_]; + int32_t ndim = shape.size(); + + // ArgReduce. 
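+ // One thread block per output element: BLOCK_DIM threads stride along the
+ // reduction axis, each carrying a running (index, value) pair, and a final
+ // cub::BlockReduce selects the winner (ties resolve to the lowest index),
+ // which thread 0 writes out.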
+ auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, { + using InType = cuda_type_t; + constexpr uint32_t N_READS = 4; + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block_dims{1, 1, BLOCK_DIM}; + auto kernel = &cu::arg_reduce_general< + InType, + cu::ArgMax, + BLOCK_DIM, + N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = &cu::arg_reduce_general< + InType, + cu::ArgMin, + BLOCK_DIM, + N_READS>; + } + kernel<<>>( + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/iterators/strided_iterator.cuh b/mlx/backend/cuda/iterators/strided_iterator.cuh new file mode 100644 index 000000000..3ef8d66bd --- /dev/null +++ b/mlx/backend/cuda/iterators/strided_iterator.cuh @@ -0,0 +1,60 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::cu { + +// RandomAccessIterator for strided access to array entries. +template +class strided_iterator + : public thrust:: + iterator_adaptor, Iterator> { + public: + using super_t = + thrust::iterator_adaptor, Iterator>; + + using reference = typename super_t::reference; + using difference_type = typename super_t::difference_type; + + __host__ __device__ strided_iterator(Iterator it, Stride stride) + : super_t(it), stride_(stride) {} + + __host__ __device__ Stride stride() const { + return stride_; + } + + private: + friend class thrust::iterator_core_access; + + __host__ __device__ bool equal(const strided_iterator& other) const { + return this->base() == other.base(); + } + + __host__ __device__ void advance(difference_type n) { + this->base_reference() += n * stride_; + } + + __host__ __device__ void increment() { + this->base_reference() += stride_; + } + + __host__ __device__ void decrement() { + this->base_reference() -= stride_; + } + + __host__ __device__ difference_type + distance_to(const strided_iterator& other) const { + const difference_type dist = other.base() - this->base(); + _CCCL_ASSERT( + dist % stride() == 0, + "Underlying iterator difference must be divisible by the stride"); + return dist / stride(); + } + + Stride stride_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 1b273e959..5cf19711c 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -72,7 +72,6 @@ bool fast::ScaledDotProductAttention::use_fallback( } NO_GPU(ArgPartition) -NO_GPU(ArgReduce) NO_GPU(BlockMaskedMM) NO_GPU_MULTI(Compiled) NO_GPU(Convolution) From c371baf53a7b3faa32e67bec06482f7b239c1995 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 05:55:22 +0900 Subject: [PATCH 092/156] CUDA backend: softmax (#2272) --- mlx/backend/cuda/CMakeLists.txt | 2 + mlx/backend/cuda/logsumexp.cu | 159 +++++++++++++++++++++++++++++++ mlx/backend/cuda/primitives.cu | 2 - mlx/backend/cuda/softmax.cu | 160 ++++++++++++++++++++++++++++++++ 4 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/logsumexp.cu create mode 100644 mlx/backend/cuda/softmax.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index ab0d5fe7c..410e24096 
100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu @@ -27,6 +28,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu new file mode 100644 index 000000000..e539ac559 --- /dev/null +++ b/mlx/backend/cuda/logsumexp.cu @@ -0,0 +1,159 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +inline __device__ T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return __expf(x); +} + +template +__global__ void logsumexp(const T* in, T* out, int axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + + cg::greater max_op; + cg::plus plus_op; + + // Thread reduce. + AccT prevmax; + AccT maxval = Limits::finite_min(); + AccT normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + AccT vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM + block.thread_rank(), + make_cast_iterator(in), + vals, + axis_size, + Limits::min()); + prevmax = maxval; + maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); + // Online normalizer calculation for softmax: + // https://github.com/NVIDIA/online-softmax + normalizer = normalizer * softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // First warp reduce. + prevmax = maxval; + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = cg::reduce(warp, normalizer, plus_op); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce. + prevmax = maxval; + if (warp.thread_rank() == 0) { + local_max[warp.meta_group_rank()] = maxval; + } + block.sync(); + maxval = warp.thread_rank() < warp.meta_group_size() + ? local_max[warp.thread_rank()] + : Limits::finite_min(); + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + if (warp.thread_rank() == 0) { + local_normalizer[warp.meta_group_rank()] = normalizer; + } + block.sync(); + normalizer = warp.thread_rank() < warp.meta_group_size() + ? local_normalizer[warp.thread_rank()] + : AccT{}; + normalizer = cg::reduce(warp, normalizer, plus_op); + + // Write output. + if (block.thread_rank() == 0) { + out[grid.block_rank()] = isinf(maxval) ? 
maxval : log(normalizer) + maxval; + } +} + +} // namespace cu + +void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("LogSumExp::eval_gpu"); + assert(inputs.size() == 1); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + // Make sure that the last dimension is contiguous. + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto in = ensure_contiguous(inputs[0]); + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + } else { + auto n = in.shape(-1); + auto flags = in.flags(); + auto strides = in.strides(); + for (auto& s : strides) { + s /= n; + } + bool col_contig = strides[0] == 1; + for (int i = 1; col_contig && i < strides.size(); ++i) { + col_contig &= + (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); + } + flags.col_contiguous = col_contig; + out.set_data( + allocator::malloc(in.nbytes() / n), + in.data_size() / n, + std::move(strides), + flags); + } + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = cu::logsumexp; + kernel<<>>( + in.data(), out.data(), axis_size); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 5cf19711c..47bf68172 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -85,7 +85,6 @@ NO_GPU(GatherMM) NO_GPU(GatherQMM) NO_GPU(Hadamard) NO_GPU(Load) -NO_GPU(LogSumExp) NO_GPU_MULTI(LUF) NO_GPU(Partition) NO_GPU_MULTI(QRF) @@ -95,7 +94,6 @@ NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) NO_GPU(SliceUpdate) -NO_GPU(Softmax) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu new file mode 100644 index 000000000..605fc0df8 --- /dev/null +++ b/mlx/backend/cuda/softmax.cu @@ -0,0 +1,160 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +inline __device__ T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return __expf(x); +} + +template +__global__ void softmax(const T* in, T* out, int axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + cg::greater max_op; + cg::plus plus_op; + + // Thread reduce. 
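+ // Online softmax: when the running maximum changes from m_old to m_new, the
+ // running sum of exponentials is rescaled by exp(m_old - m_new) before the
+ // new exp(x_i - m_new) terms are added, so a single pass over the row
+ // suffices.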
+ AccT prevmax; + AccT maxval = Limits::finite_min(); + AccT normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + AccT vals[N_READS]; + cub::LoadDirectBlocked( + r * BLOCK_DIM + block.thread_rank(), + make_cast_iterator(in), + vals, + axis_size, + Limits::finite_min()); + prevmax = maxval; + maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); + // Online normalizer calculation for softmax: + // https://github.com/NVIDIA/online-softmax + normalizer = normalizer * softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // First warp reduce. + prevmax = maxval; + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = cg::reduce(warp, normalizer, plus_op); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce. + prevmax = maxval; + if (warp.thread_rank() == 0) { + local_max[warp.meta_group_rank()] = maxval; + } + block.sync(); + maxval = warp.thread_rank() < warp.meta_group_size() + ? local_max[warp.thread_rank()] + : Limits::finite_min(); + maxval = cg::reduce(warp, maxval, max_op); + normalizer = normalizer * softmax_exp(prevmax - maxval); + if (warp.thread_rank() == 0) { + local_normalizer[warp.meta_group_rank()] = normalizer; + } + block.sync(); + normalizer = warp.thread_rank() < warp.meta_group_size() + ? local_normalizer[warp.thread_rank()] + : AccT{}; + normalizer = cg::reduce(warp, normalizer, plus_op); + normalizer = 1 / normalizer; + + // Write output. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T vals[N_READS]; + cub::LoadDirectBlocked(index, in, vals, axis_size); + for (int i = 0; i < N_READS; i++) { + vals[i] = softmax_exp(static_cast(vals[i]) - maxval) * normalizer; + } + cub::StoreDirectBlocked(index, out, vals, axis_size); + } +} + +} // namespace cu + +void Softmax::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Softmax::eval_gpu"); + assert(inputs.size() == 1); + auto& s = stream(); + + // Make sure that the last dimension is contiguous. 
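+  // If the softmax axis is already contiguous, the kernel either runs in
+  // place (donating the input buffer) or writes to an output that reuses the
+  // input's strides and flags; otherwise a general copy first materializes a
+  // row-contiguous temporary and the output shares that buffer.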
+ auto set_output = [&s, &out](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + array in = set_output(inputs[0]); + bool precise = in.dtype() != float32 && precise_; + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + kernel<<>>( + in.data(), out.data(), axis_size); + }); + }); + }); +} + +} // namespace mlx::core From d7e680ffe486edf1aee8b8bb9acbf84dc31b3c4f Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 07:48:32 +0900 Subject: [PATCH 093/156] CUDA backend: layernorm (#2271) --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/layer_norm.cu | 390 ++++++++++++++++++++++++++++++++ mlx/backend/cuda/primitives.cu | 2 - 3 files changed, 391 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/layer_norm.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 410e24096..d5041b2ae 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/random.cu diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu new file mode 100644 index 000000000..5aa287603 --- /dev/null +++ b/mlx/backend/cuda/layer_norm.cu @@ -0,0 +1,390 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/strided_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +inline __device__ float3 plus_f3(const float3& a, const float3& b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +// Similar to cub::BlockReduce, but result is broadcasted to every thread. +template +struct BlockBroadcastReduce { + static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); + static_assert(BLOCK_DIM % WARP_SIZE == 0); + using TempStorage = T[BLOCK_DIM / WARP_SIZE]; + + cg::thread_block& block; + TempStorage& temp; + + template + __device__ T Reduce(const T& input, const Op& op, const T& init_value) { + auto warp = cg::tiled_partition(block); + T x = cg::reduce(warp, input, op); + if (warp.thread_rank() == 0) { + temp[warp.meta_group_rank()] = x; + } + block.sync(); + x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()] + : init_value; + return cg::reduce(warp, x, op); + } + + __device__ T Sum(const T& input) { + return Reduce(input, cg::plus{}, T{}); + } +}; + +template +__global__ void layer_norm( + const T* x, + const T* w, + const T* b, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride, + int64_t b_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceT = BlockBroadcastReduce; + __shared__ typename BlockReduceT::TempStorage temp; + + x += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Sum. + float sum = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + cub::LoadDirectBlocked(index, x, xn, axis_size); + sum += static_cast(cub::ThreadReduce(xn, cuda::std::plus<>{})); + } + sum = BlockReduceT{block, temp}.Sum(sum); + + // Mean. + float mean = sum / axis_size; + + // Normalizer. + float normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size, mean); + for (int i = 0; i < N_READS; ++i) { + float t = static_cast(xn[i]) - mean; + normalizer += t * t; + } + } + normalizer = BlockReduceT{block, temp}.Sum(normalizer); + normalizer = rsqrt(normalizer / axis_size + eps); + + // Outputs. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T bn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float norm = (static_cast(xn[i]) - mean) * normalizer; + xn[i] = wn[i] * static_cast(norm) + bn[i]; + } + cub::StoreDirectBlocked(index, out, xn, axis_size); + } +} + +template +__global__ void layer_norm_vjp( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceF = BlockBroadcastReduce; + using BlockReduceF3 = BlockBroadcastReduce; + __shared__ union { + typename BlockReduceF::TempStorage f; + typename BlockReduceF3::TempStorage f3; + } temp; + + x += grid.block_rank() * axis_size; + g += grid.block_rank() * axis_size; + gx += grid.block_rank() * axis_size; + gw += grid.block_rank() * axis_size; + + // Sum. + float sum = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS] = {}; + cub::LoadDirectBlocked(index, x, xn, axis_size); + sum += static_cast(cub::ThreadReduce(xn, cuda::std::plus<>{})); + } + sum = BlockReduceF{block, temp.f}.Sum(sum); + + // Mean. + float mean = sum / axis_size; + + // Normalizer. 
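+  // A single pass accumulates three row statistics in one float3:
+  //   factors.x = sum_i w_i * g_i
+  //   factors.y = sum_i w_i * g_i * (x_i - mean)
+  //   factors.z = sum_i (x_i - mean)^2
+  // With x_hat = (x - mean) * normalizer, the output loop below then forms
+  //   gx = normalizer * (w * g - mean(w * g))
+  //        - x_hat * mean(w * g * (x - mean)) * normalizer^2
+  // and, when a weight is present, gw = g * x_hat.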
+ float3 factors = {}; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T xn[N_READS]; + T wn[N_READS] = {}; + T gn[N_READS] = {}; + auto index = r * BLOCK_DIM + block.thread_rank(); + cub::LoadDirectBlocked(index, x, xn, axis_size, mean); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float t = static_cast(xn[i]) - mean; + float wi = wn[i]; + float gi = gn[i]; + float wg = wi * gi; + factors = plus_f3(factors, {wg, wg * t, t * t}); + } + } + factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {}); + float meanwg = factors.x / axis_size; + float meanwgxc = factors.y / axis_size; + float normalizer2 = 1 / (factors.z / axis_size + eps); + float normalizer = sqrt(normalizer2); + + // Outputs. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T gn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = (static_cast(xn[i]) - mean) * normalizer; + float wi = wn[i]; + float gi = gn[i]; + xn[i] = normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2; + if constexpr (HAS_W) { + wn[i] = gi * xi; + } + } + cub::StoreDirectBlocked(index, gx, xn, axis_size); + if constexpr (HAS_W) { + cub::StoreDirectBlocked(index, gw, wn, axis_size); + } + } +} + +} // namespace cu + +namespace fast { + +bool LayerNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +// TODO: There are duplicate code with backend/metal/normalization.cpp +void LayerNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("LayerNorm::eval_gpu"); + auto& s = stream(); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + array o = set_output(inputs[0]); + const array& x = o.data_shared_ptr() ? o : out; + const array& w = inputs[1]; + const array& b = inputs[2]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + int64_t b_stride = (b.ndim() == 1) ? 
b.strides()[0] : 0; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, { + using DataType = cuda_type_t; + constexpr uint32_t N_READS = 4; + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = cu::layer_norm; + kernel<<>>( + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); + }); + }); + }); +} + +void LayerNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("LayerNormVJP::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + // Ensure row contiguity. We could relax this step by checking that the array + // is contiguous (no broadcasts or holes) and that the input strides are the + // same as the cotangent strides but for now this is simpler. + auto check_input = [&s](const array& x) -> std::pair { + if (x.flags().row_contiguous) { + return {x, false}; + } + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + return {x_copy, true}; + }; + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[3].is_donatable(); + auto [x, copied] = check_input(inputs[0]); + donate_x |= copied; + const array& w = inputs[1]; + const array& b = inputs[2]; + auto [g, g_copied] = check_input(inputs[3]); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + array& gb = outputs[2]; + + // Check whether we had a weight. + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs. + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(allocator::malloc(gx.nbytes())); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + if (has_w) { + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + encoder.add_temporary(gw_temp); + } + } + gw.set_data(allocator::malloc(gw.nbytes())); + gb.set_data(allocator::malloc(gb.nbytes())); + + // Finish with the gradient for b in case we had a b. 
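+  // The bias gradient does not depend on x or w: it is simply the cotangent
+  // g summed over rows, so it is handled with a strided column reduction
+  // before the main VJP kernel is launched.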
+ if (gb.ndim() == 1 && gb.size() == axis_size) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); + } + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BOOL(has_w, HAS_W, { + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = cu::layer_norm_vjp; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); + }); + + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 47bf68172..8de4f92f9 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -101,8 +101,6 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { -NO_GPU_USE_FALLBACK(LayerNorm) -NO_GPU_MULTI(LayerNormVJP) NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) From c2dd81a8aa95e04d16c20c587a9faedec10c48ed Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 22:03:01 +0900 Subject: [PATCH 094/156] Fix warnings from latest CUDA toolkit (#2275) --- mlx/backend/cuda/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index d5041b2ae..e4f36074a 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -44,12 +44,16 @@ target_compile_options(mlx # CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive. # Explicitly pass this flag to suppress the warning, it is safe to set it to # true but the warning wouldn't be suppressed. -if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) +if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0) target_compile_options( mlx PRIVATE "$<$:--static-global-template-stub=false>") endif() +# Suppress warning when building for compute capability 7 used by V100. +target_compile_options( + mlx PRIVATE "$<$:--Wno-deprecated-gpu-targets>") + # Compute capability 7 is required for synchronization between CPU/GPU with # managed memory. TODO: Add more architectures for potential performance gain. 
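+# A minimal sketch, assuming the architecture list below is left overridable
+# at configure time (hypothetical invocation):
+#   cmake -S . -B build -DMLX_CUDA_ARCHITECTURES="70;80"
+# would restrict compilation to Volta and Ampere and shorten build times.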
set(MLX_CUDA_ARCHITECTURES From f5f65ef48cb3e77561a4a7835356bd4c419ccd15 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Jun 2025 16:48:54 -0700 Subject: [PATCH 095/156] Make sliceUpdate general (#2282) * Make sliceUpdate general * fix --- mlx/backend/cuda/primitives.cu | 1 - mlx/backend/gpu/primitives.cpp | 36 ++++++++++++++++++++++++++++++++ mlx/backend/metal/primitives.cpp | 35 ------------------------------- mlx/distributed/mpi/mpi.cpp | 2 ++ 4 files changed, 38 insertions(+), 36 deletions(-) diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 8de4f92f9..5ffe0e10d 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -93,7 +93,6 @@ NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) -NO_GPU(SliceUpdate) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 938923977..1adb85918 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/primitives.h" +#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" @@ -170,6 +171,41 @@ void Slice::eval_gpu(const std::vector& inputs, array& out) { slice_gpu(in, out, start_indices_, strides_, stream()); } +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); +} + void Squeeze::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Squeeze::eval_gpu"); eval(inputs, out); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 705c3ea76..2ac543ad8 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -322,41 +322,6 @@ void DynamicSliceUpdate::eval_gpu( /* const std::optional& dynamic_o_offset = */ out_offset); } -void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? 
CopyType::Scalar : ctype, stream()); - auto [data_offset, out_strides] = - prepare_slice(out, start_indices_, strides_); - - // Do copy - copy_gpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const Shape& data_shape = */ upd.shape(), - /* const Strides& i_strides = */ upd.strides(), - /* const Strides& o_strides = */ out_strides, - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ data_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ stream()); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index e80a1759f..6a440c319 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -225,6 +225,8 @@ struct MPIWrapper { return mpi_bfloat16_; case float64: return mpi_double_; + default: + throw std::runtime_error("Invalid type"); } } From a4fc671d3e5fbd548433a592d3ec160973b97aa1 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 13 Jun 2025 09:08:39 +0900 Subject: [PATCH 096/156] CUDA backend: compile (#2276) * CUDA backend: compile * Rename kernels/ to device/ --- mlx/backend/cuda/CMakeLists.txt | 23 ++ mlx/backend/cuda/bin2h.cmake | 150 ++++++++ mlx/backend/cuda/binary.cu | 4 +- mlx/backend/cuda/compiled.cpp | 228 ++++++++++++ mlx/backend/cuda/copy/copy.cuh | 2 +- .../cuda/{kernels => device}/arange.cuh | 0 .../cuda/{kernels => device}/binary_ops.cuh | 2 +- .../cuda/{kernels => device}/cast_op.cuh | 0 mlx/backend/cuda/device/config.h | 12 + .../{kernels => device}/cucomplex_math.cuh | 0 .../cuda/{kernels => device}/fp16_math.cuh | 0 .../cuda/{kernels => device}/unary_ops.cuh | 4 +- .../cuda/{kernels => device}/utils.cuh | 8 +- mlx/backend/cuda/jit_module.cpp | 340 ++++++++++++++++++ mlx/backend/cuda/jit_module.h | 113 ++++++ mlx/backend/cuda/kernel_utils.cuh | 4 +- mlx/backend/cuda/logsumexp.cu | 2 +- mlx/backend/cuda/primitives.cu | 5 +- mlx/backend/cuda/reduce/col_reduce.cu | 2 +- mlx/backend/cuda/reduce/reduce.cuh | 2 +- mlx/backend/cuda/reduce/reduce_ops.cuh | 2 +- mlx/backend/cuda/reduce/row_reduce.cu | 2 +- mlx/backend/cuda/reduce/segmented_reduce.cu | 2 +- mlx/backend/cuda/softmax.cu | 4 +- mlx/backend/cuda/unary.cu | 4 +- mlx/backend/cuda/utils.cpp | 17 + mlx/backend/cuda/utils.h | 5 + 27 files changed, 910 insertions(+), 27 deletions(-) create mode 100644 mlx/backend/cuda/bin2h.cmake create mode 100644 mlx/backend/cuda/compiled.cpp rename mlx/backend/cuda/{kernels => device}/arange.cuh (100%) rename mlx/backend/cuda/{kernels => device}/binary_ops.cuh (99%) rename mlx/backend/cuda/{kernels => device}/cast_op.cuh (100%) create mode 100644 mlx/backend/cuda/device/config.h rename mlx/backend/cuda/{kernels => device}/cucomplex_math.cuh (100%) rename mlx/backend/cuda/{kernels => device}/fp16_math.cuh (100%) rename mlx/backend/cuda/{kernels => device}/unary_ops.cuh (98%) rename mlx/backend/cuda/{kernels => device}/utils.cuh (97%) create mode 100644 mlx/backend/cuda/jit_module.cpp create mode 100644 mlx/backend/cuda/jit_module.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index e4f36074a..854c6c116 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu 
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu @@ -18,6 +19,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu @@ -37,6 +39,24 @@ target_sources( target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) +# Embed kernel sources in binary for JIT compilation. +file( + GLOB MLX_JIT_SOURCES + RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.cuh") +string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) +add_custom_command( + OUTPUT gen/cuda_jit_sources.h + COMMAND + ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR} + -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P + "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" + DEPENDS bin2h.cmake ${MLX_JIT_SOURCES}) +add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h) +add_dependencies(mlx cuda_jit_sources) +target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") + # Enable defining device lambda functions. target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") @@ -87,6 +107,9 @@ target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) # Use cublasLt. target_link_libraries(mlx PRIVATE CUDA::cublasLt) +# Use NVRTC and driver APIs. +target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) + # Suppress nvcc warnings on MLX headers. target_compile_options(mlx PRIVATE $<$:-Xcudafe --diag_suppress=997>) diff --git a/mlx/backend/cuda/bin2h.cmake b/mlx/backend/cuda/bin2h.cmake new file mode 100644 index 000000000..b791d3d1a --- /dev/null +++ b/mlx/backend/cuda/bin2h.cmake @@ -0,0 +1,150 @@ +# Based on: https://github.com/sivachandran/cmake-bin2h +# +# Copyright 2020 Sivachandran Paramasivam +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +include(CMakeParseArguments) + +# Function to wrap a given string into multiple lines at the given column +# position. +# +# Parameters: +# +# * VARIABLE - The name of the CMake variable holding the string. +# * AT_COLUMN - The column position at which string will be wrapped. 
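+#
+# Example (as used by bin2h below):
+#   wrap_string(VARIABLE hexString AT_COLUMN 24)
+# reflows the single-line hex dump of a file into 24-character lines.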
+function(WRAP_STRING) + set(oneValueArgs VARIABLE AT_COLUMN) + cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN}) + + string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength) + math(EXPR offset "0") + + while(stringLength GREATER 0) + if(stringLength GREATER ${WRAP_STRING_AT_COLUMN}) + math(EXPR length "${WRAP_STRING_AT_COLUMN}") + else() + math(EXPR length "${stringLength}") + endif() + + string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line) + set(lines "${lines}\n ${line}") + + math(EXPR stringLength "${stringLength} - ${length}") + math(EXPR offset "${offset} + ${length}") + endwhile() + + set(${WRAP_STRING_VARIABLE} + "${lines}" + PARENT_SCOPE) +endfunction() + +# Function to embed contents of a file as byte array in C/C++ header file(.h). +# The header file will contain a byte array and integer variable holding the +# size of the array. +# +# Parameters: +# +# * SOURCE_FILES - The paths of source files whose contents will be embedded in +# the header file. +# * VARIABLE_NAME - The name of the variable for the byte array. The string +# "_SIZE" will be append to this name and will be used a variable name for +# size variable. +# * HEADER_FILE - The path of header file. +# * APPEND - If specified appends to the header file instead of overwriting it +# * HEADER_NAMESPACE - The namespace, where the array should be located in. +# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte +# array. +# +# Usage: +# +# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG") +function(BIN2H) + set(options APPEND NULL_TERMINATE) + set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE) + set(multiValueArgs SOURCE_FILES) + cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + set(arrayDefinition "") + foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES) + # get filename without extension + get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE) + # convert the filename to a valid C identifier + string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME) + + # reads source file contents as hex string + file(READ ${SOURCE_FILE} hexString HEX) + + # append null + if(BIN2H_NULL_TERMINATE) + string(APPEND hexString "00") + endif() + + # wraps the hex string into multiple lines + wrap_string(VARIABLE hexString AT_COLUMN 24) + + # strip the © in source code + string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString}) + + string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues + ${arrayValues}) + + # make a full variable name for the array + set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}") + + # declares byte array and the length variables + string(APPEND arrayDefinition + "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n") + endforeach() + + # add namespace wrapper if defined + if(DEFINED BIN2H_HEADER_NAMESPACE) + set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {") + set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}") + set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n") + endif() + + set(arrayIncludes "#pragma once") + string(PREPEND declarations "${arrayIncludes}\n\n") + + if(BIN2H_APPEND) + file(APPEND ${BIN2H_HEADER_FILE} "${declarations}") + else() + file(WRITE ${BIN2H_HEADER_FILE} "${declarations}") + endif() +endfunction() + +# ----------------------------- CLI args ----------------------------- + +string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES}) +foreach(source 
${MLX_JIT_SOURCES_LIST}) + list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}") +endforeach() + +bin2h( + SOURCE_FILES + ${MLX_JIT_SOURCES_ABS} + NULL_TERMINATE + VARIABLE_NAME + "jit_source" + HEADER_NAMESPACE + "mlx::core" + HEADER_FILE + "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h") diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 360772998..47efc44d2 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -2,9 +2,9 @@ #include "mlx/backend/common/binary.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/binary_ops.cuh" -#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp new file mode 100644 index 000000000..a6b8223e0 --- /dev/null +++ b/mlx/backend/cuda/compiled.cpp @@ -0,0 +1,228 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/graph_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +struct FusedKernelBuilder { + std::string os; + const std::string& kernel_name; + const std::vector& inputs; + const std::vector& outputs; + const std::vector& tape; + const std::function& is_constant; + + void build(const char* name, bool contiguous) { + NodeNamer namer; + + // Function parameters. + std::vector params; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant(i)) { + continue; + } + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + params.push_back( + fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname)); + if (!is_scalar(x) && !contiguous) { + params.push_back(fmt::format( + "const __grid_constant__ cuda::std::array {}_strides", + xname)); + } + } + for (const auto& x : outputs) { + params.push_back(fmt::format( + "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x))); + } + if (!contiguous) { + params.push_back( + "const __grid_constant__ cuda::std::array shape"); + } + params.push_back("IdxT size"); + + // Build function signature. + if (contiguous) { + os += "template \n"; + } else { + os += "template \n"; + } + os += fmt::format("__global__ void {}(\n", kernel_name + name); + for (size_t i = 0; i < params.size(); ++i) { + os += " "; + os += params[i]; + if (i != params.size() - 1) { + os += ",\n"; + } + } + os += ") {\n"; + + // Index. + os += + " IdxT index = cg::this_grid().thread_rank();\n" + " if (index >= size) {\n" + " return;\n" + " }\n"; + + // Read inputs. + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + std::string value; + if (is_constant(i)) { + std::ostringstream ss; + print_constant(ss, x); + value = fmt::format("static_cast<{}>({})", type, ss.str()); + } else if (is_scalar(x)) { + value = fmt::format("{}[0]", xname); + } else if (contiguous) { + value = fmt::format("{}[index]", xname); + } else { + std::string index = fmt::format( + "elem_to_loc_nd(index, shape.data(), {}_strides.data())", + xname); + value = fmt::format("{}[{}]", xname, index); + } + os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); + } + + // Write tape. 
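+    // Each tape entry becomes one assignment of the form
+    //   T tmp_<out> = <Primitive>{}(tmp_<in0>, ..., tmp_<inN>);
+    // using the primitive's printed name as the functor (casts are
+    // special-cased to static_cast), so the fused kernel simply replays the
+    // captured graph on registers.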
+ for (const auto& x : tape) { + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + std::string value; + if (is_static_cast(x.primitive())) { + value = fmt::format( + "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); + } else { + std::ostringstream ss; + x.primitive().print(ss); + value = ss.str(); + value += "{}("; + for (size_t i = 0; i < x.inputs().size() - 1; ++i) { + value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); + } + value += fmt::format("tmp_{})", namer.get_name(x.inputs().back())); + } + os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); + } + + // Write output. + for (const auto& x : outputs) { + os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x)); + } + + os += "}\n"; + } +}; + +} // namespace cu + +constexpr const char* g_jit_includes = R"( +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +)"; + +void Compiled::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("Compiled::eval_gpu"); + auto& s = stream(); + + cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() { + // Build source code. + cu::FusedKernelBuilder builder{ + g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_}; + builder.os += + "namespace mlx::core::cu {\n\n" + "namespace cg = cooperative_groups;\n\n"; + builder.build("_contiguous", true); + builder.os += "\n"; + builder.build("_strided", false); + builder.os += "\n} // namespace mlx::core::cu\n"; + // Build kernel names. + std::vector kernel_names = { + fmt::format("mlx::core::cu::{}_contiguous", lib_name()), + fmt::format("mlx::core::cu::{}_contiguous", lib_name()), + }; + for (int i = 1; i <= MAX_NDIM; ++i) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i)); + kernel_names.push_back( + fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i)); + } + return std::make_pair(std::move(builder.os), std::move(kernel_names)); + }); + + // Collapse contiguous dims to route to a faster kernel if possible. Also + // handle all broadcasting. + auto [contiguous, shape, strides_vec] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); + + // Whether to use large index. + bool large = compiled_use_large_index(inputs, outputs, contiguous); + + // Put inputs. + int strides_index = 1; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant_(i)) { + continue; + } + const auto& x = inputs[i]; + mod.append_arg(x); + if (!contiguous && !is_scalar(x)) { + mod.append_arg(strides_vec[strides_index++]); + } + } + + // Put outputs. + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); + for (auto& x : outputs) { + mod.append_arg(x); + } + + // Put shape and size. + if (!contiguous) { + mod.append_arg(shape); + } + if (large) { + mod.append_arg(outputs[0].data_size()); + } else { + mod.append_arg(outputs[0].data_size()); + } + + // Launch kernel. + const char* index_type = large ? 
"int64_t" : "uint32_t"; + std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name()); + if (contiguous) { + kernel_name += fmt::format("_contiguous<{}>", index_type); + } else { + kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type); + } + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + for (const auto& out : outputs) { + encoder.set_output_array(out); + } + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, outputs[0], large); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index dd1d09d30..0c1eff774 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -3,8 +3,8 @@ #pragma once #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cast_op.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/kernels/arange.cuh b/mlx/backend/cuda/device/arange.cuh similarity index 100% rename from mlx/backend/cuda/kernels/arange.cuh rename to mlx/backend/cuda/device/arange.cuh diff --git a/mlx/backend/cuda/kernels/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh similarity index 99% rename from mlx/backend/cuda/kernels/binary_ops.cuh rename to mlx/backend/cuda/device/binary_ops.cuh index 3bc30eb02..4779a6f33 100644 --- a/mlx/backend/cuda/kernels/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -1,6 +1,6 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" #include #include diff --git a/mlx/backend/cuda/kernels/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh similarity index 100% rename from mlx/backend/cuda/kernels/cast_op.cuh rename to mlx/backend/cuda/device/cast_op.cuh diff --git a/mlx/backend/cuda/device/config.h b/mlx/backend/cuda/device/config.h new file mode 100644 index 000000000..0933cc8b5 --- /dev/null +++ b/mlx/backend/cuda/device/config.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +// This file is used by both CUDA kernel code and host-only C++ code. + +#pragma once + +// The maximum dimensions of shape/strides passed as kernel parameters. +#define MAX_NDIM 8 + +// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in +// warpSize variable exists, using it would prevent compile-time optimizations. 
+#define WARP_SIZE 32 diff --git a/mlx/backend/cuda/kernels/cucomplex_math.cuh b/mlx/backend/cuda/device/cucomplex_math.cuh similarity index 100% rename from mlx/backend/cuda/kernels/cucomplex_math.cuh rename to mlx/backend/cuda/device/cucomplex_math.cuh diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/device/fp16_math.cuh similarity index 100% rename from mlx/backend/cuda/kernels/fp16_math.cuh rename to mlx/backend/cuda/device/fp16_math.cuh diff --git a/mlx/backend/cuda/kernels/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh similarity index 98% rename from mlx/backend/cuda/kernels/unary_ops.cuh rename to mlx/backend/cuda/device/unary_ops.cuh index 6637a6eeb..af7c30e64 100644 --- a/mlx/backend/cuda/kernels/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -2,8 +2,8 @@ #pragma once -#include "mlx/backend/cuda/kernels/fp16_math.cuh" -#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/device/utils.cuh" namespace mlx::core::cu { diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/device/utils.cuh similarity index 97% rename from mlx/backend/cuda/kernels/utils.cuh rename to mlx/backend/cuda/device/utils.cuh index 7636710dc..a1d387201 100644 --- a/mlx/backend/cuda/kernels/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -8,6 +8,8 @@ #pragma once +#include "mlx/backend/cuda/device/config.h" + #include #include #include @@ -21,14 +23,8 @@ namespace mlx::core::cu { // CUDA kernel utils /////////////////////////////////////////////////////////////////////////////// -// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in -// warpSize variable exists, using it would prevent compile-time optimizations. -#define WARP_SIZE 32 - // To pass shape/strides to kernels via constant memory, their size must be // known at compile time. -#define MAX_NDIM 8 - using Shape = cuda::std::array; using Strides = cuda::std::array; diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp new file mode 100644 index 000000000..3c00dd7f0 --- /dev/null +++ b/mlx/backend/cuda/jit_module.cpp @@ -0,0 +1,340 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/device.h" + +#include "cuda_jit_sources.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core::cu { + +namespace { + +#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd)) + +void check_nvrtc_error(const char* name, nvrtcResult err) { + if (err != NVRTC_SUCCESS) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, nvrtcGetErrorString(err))); + } +} + +#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd)) + +void check_cu_error(const char* name, CUresult err) { + if (err != CUDA_SUCCESS) { + const char* err_str = "Unknown error"; + cuGetErrorString(err, &err_str); + throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); + } +} + +// Return the location of the CUDA toolkit. +const char* cuda_home() { + const char* home = std::getenv("CUDA_HOME"); + if (home) { + return home; + } + home = std::getenv("CUDA_PATH"); + if (home) { + return home; + } +#if defined(__linux__) + home = "/usr/local/cuda"; + if (std::filesystem::exists(home)) { + return home; + } +#endif + throw std::runtime_error( + "Environment variable CUDA_HOME or CUDA_PATH is not set."); +} + +// Get the cache directory for storing compiled results. 
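+// The cache lives under <temp_dir>/mlx/ptx: for a module named "foo",
+// foo.ptx holds the compiled PTX (or CUBIN) image and foo.txt maps each
+// kernel name expression to its mangled symbol, one tab-separated pair per
+// line.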
+bool get_ptx_cache_dir(std::filesystem::path* result) { + auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx"; + if (!std::filesystem::is_directory(path)) { + std::error_code error; + if (!std::filesystem::create_directories(path, error)) { + return false; + } + } + *result = path; + return true; +} + +// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. +bool read_cached_ptx( + const std::filesystem::path& cache_dir, + const std::string& module_name, + std::vector* ptx, + std::vector>* ptx_kernels) { + auto ptx_path = cache_dir / (module_name + ".ptx"); + std::error_code error; + auto ptx_size = std::filesystem::file_size(ptx_path, error); + if (error) { + return false; + } + std::ifstream ptx_file(ptx_path, std::ios::binary); + if (!ptx_file.good()) { + return false; + } + ptx->resize(ptx_size); + ptx_file.read(ptx->data(), ptx_size); + + std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + std::string line; + while (std::getline(txt_file, line)) { + auto tab = line.find('\t'); + if (tab != std::string::npos) { + ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1)); + } + } + return true; +} + +// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|. +void write_cached_ptx( + const std::filesystem::path& cache_dir, + const std::string& module_name, + const std::vector& ptx, + const std::vector>& ptx_kernels) { + std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); + if (!ptx.empty()) { + ptx_file.write(&ptx.front(), ptx.size()); + } + std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + for (const auto& [name, mangled] : ptx_kernels) { + txt_file << name << "\t" << mangled << std::endl; + } +} + +// Return if |device|'s version is not newer than |major|.|minor| version. +inline bool version_lower_equal(Device& device, int major, int minor) { + if (device.compute_capability_major() < major) { + return true; + } else if (device.compute_capability_major() == major) { + return device.compute_capability_minor() <= minor; + } else { + return false; + } +} + +// Return whether NVRTC supports compiling to |device|'s SASS code. +bool compiler_supports_device_sass(Device& device) { + int nvrtc_major, nvrtc_minor; + CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor)); + if (nvrtc_major < 9) { + return false; + } else if (nvrtc_major == 9) { + return version_lower_equal(device, 7, 2); + } else if (nvrtc_major == 10) { + return version_lower_equal(device, 7, 5); + } else if (nvrtc_major == 11 && nvrtc_minor == 0) { + return version_lower_equal(device, 8, 0); + } else if (nvrtc_major == 11 && nvrtc_minor < 8) { + return version_lower_equal(device, 8, 6); + } else { + return true; + } +} + +#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" + +constexpr const char* g_include_names[] = { + INCLUDE_PREFIX "binary_ops.cuh", + INCLUDE_PREFIX "cast_op.cuh", + INCLUDE_PREFIX "config.h", + INCLUDE_PREFIX "cucomplex_math.cuh", + INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "unary_ops.cuh", + INCLUDE_PREFIX "utils.cuh", +}; + +#undef INCLUDE_PREFIX + +constexpr const char* g_headers[] = { + jit_source_binary_ops, + jit_source_cast_op, + jit_source_config, + jit_source_cucomplex_math, + jit_source_fp16_math, + jit_source_unary_ops, + jit_source_utils, +}; + +} // namespace + +JitModule::JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder) { + // Check cache. 
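+  // On a cache miss, the builder callback supplies the source and the kernel
+  // name expressions, NVRTC compiles them (to SASS when the toolkit supports
+  // the device's architecture, otherwise to PTX), the lowered names are
+  // recorded, and the result is written back to the cache for later runs.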
+ std::filesystem::path cache_dir; + std::vector ptx; + std::vector> ptx_kernels; + if (!get_ptx_cache_dir(&cache_dir) || + !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) { + // Create program. + auto [source_code, kernel_names] = builder(); + nvrtcProgram prog; + CHECK_NVRTC_ERROR(nvrtcCreateProgram( + &prog, + source_code.c_str(), + (module_name + ".cu").c_str(), + std::size(g_headers), + g_headers, + g_include_names)); + std::unique_ptr prog_freer( + &prog, + [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); + for (const auto& name : kernel_names) { + CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); + } + + // Compile program. + bool use_sass = compiler_supports_device_sass(device); + std::string compute = fmt::format( + "--gpu-architecture={}_{}{}", + use_sass ? "sm" : "compute", + device.compute_capability_major(), + device.compute_capability_minor()); + std::string include = fmt::format("--include-path={}/include", cuda_home()); + const char* args[] = {compute.c_str(), include.c_str()}; + nvrtcResult compile_result = + nvrtcCompileProgram(prog, std::size(args), args); + if (compile_result != NVRTC_SUCCESS) { + size_t log_size; + CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); + std::vector log(log_size + 1, 0); + CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); + throw std::runtime_error( + fmt::format("Failed to compile kernel: {}.", log.data())); + } + + // Get mangled names of kernel names. + for (const auto& name : kernel_names) { + const char* mangled; + CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); + ptx_kernels.emplace_back(name, mangled); + } + + // Get ptx data. + size_t ptx_size; + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size)); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size)); + } + ptx.resize(ptx_size, 0); + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data())); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); + } + write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels); + } + + // Load module. + char jit_log[4089] = {}; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; + void* values[] = {jit_log, reinterpret_cast(std::size(jit_log) - 1)}; + CUresult jit_result = cuModuleLoadDataEx( + &module_, ptx.data(), std::size(options), options, values); + if (jit_result != CUDA_SUCCESS) { + throw std::runtime_error(fmt::format( + "Failed to load compiled {} kernel: {}.", module_name, jit_log)); + } + + // Load kernels. 
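+  // Kernels are fetched from the loaded module by their mangled symbols but
+  // stored under the original name expressions, so callers can look them up
+  // with the readable instantiation names built at compile time.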
+ for (const auto& [name, mangled] : ptx_kernels) { + CUfunction kernel; + CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); + kernels_[name] = kernel; + } +} + +JitModule::~JitModule() { + CHECK_CU_ERROR(cuModuleUnload(module_)); +} + +void JitModule::launch_kernel( + CUstream stream, + const std::string& kernel_name, + const array& arr, + bool large, + int work_per_thread) { + CUfunction kernel = get_kernel(kernel_name); + size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + int _, block_dim; + CHECK_CU_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + if (block_dim > nthreads) { + block_dim = nthreads; + } + Dims num_blocks{1, 1, 1}; + if (large) { + num_blocks = + get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread); + std::get<0>(num_blocks) = + (std::get<0>(num_blocks) + block_dim - 1) / block_dim; + } else { + std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim; + } + launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1}); +} + +void JitModule::launch_kernel( + CUstream stream, + CUfunction kernel, + Dims num_blocks, + Dims block_dims) { + CHECK_CU_ERROR(cuLaunchKernel( + kernel, + std::get<0>(num_blocks), + std::get<1>(num_blocks), + std::get<2>(num_blocks), + std::get<0>(block_dims), + std::get<1>(block_dims), + std::get<2>(block_dims), + 0, + stream, + args_.data(), + nullptr)); + args_.clear(); + storage_.clear(); +} + +CUfunction JitModule::get_kernel(const std::string& kernel_name) { + auto it = kernels_.find(kernel_name); + if (it == kernels_.end()) { + throw std::runtime_error( + fmt::format("There is no kernel named {}.", kernel_name)); + } + return it->second; +} + +void JitModule::append_ptr_arg(const void* v) { + args_.push_back(const_cast(v)); +} + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder) { + static std::unordered_map map; + auto it = map.find(name); + if (it == map.end()) { + it = map.try_emplace(name, cu::device(device), name, builder).first; + } + return it->second; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h new file mode 100644 index 000000000..bbfaa45b0 --- /dev/null +++ b/mlx/backend/cuda/jit_module.h @@ -0,0 +1,113 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device/config.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core::cu { + +class Device; + +using KernelBuilderResult = std::pair< + /* source code */ std::string, + /* kernel names */ std::vector>; +using KernelBuilder = std::function; + +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + + void append_arg(const array& a) { + append_arg(reinterpret_cast(a.data())); + } + + template + void append_arg(T val) { + storage_.emplace_back(val); + append_ptr_arg(&storage_.back()); + } + + template + void append_arg(std::vector vec) { + if (vec.empty()) { + // The nullptr can not be used as arg, pass something not null. + append_arg(std::monostate{}); + } else { + append_ptr_arg(vec.data()); + storage_.emplace_back(std::move(vec)); + } + } + + // Make sure the arg is copied to an array with size of NDIM. 
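+  // For example, a 3-d strides vector passed through append_ndim_arg is
+  // copied into a vector of NDIM entries (the remaining entries stay zero)
+  // so it matches the fixed-size array parameter declared by the kernels.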
+ template + void append_ndim_arg(const std::vector& vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + std::vector copied(NDIM); + std::copy(vec.begin(), vec.end(), copied.data()); + append_arg(std::move(copied)); + } + + // Launch kernel with |kernel_name| that each thread works on + // |work_per_thread| elements of |arr|. + void launch_kernel( + CUstream stream, + const std::string& kernel_name, + const array& arr, + bool large, + int work_per_thread = 1); + + void launch_kernel( + CUstream stream, + CUfunction kernel, + Dims num_blocks, + Dims block_dims); + + CUfunction get_kernel(const std::string& kernel_name); + + private: + void append_ptr_arg(const void* v); + + CUmodule module_{nullptr}; + std::unordered_map kernels_; + std::vector args_; + + // The cuLaunchKernel API requires passing pointers to arguments so store + // temporary values untill kernel is launched. + using Arg = std::variant< + std::monostate, + CUdeviceptr, + int32_t, + uint32_t, + int64_t, + std::vector, + std::vector, + std::vector>; + std::deque storage_; +}; + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 656ddebea..7e957bbbd 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -1,13 +1,13 @@ // Copyright © 2025 Apple Inc. // This file includes host-only utilies for writing CUDA kernels, the difference -// from backend/cuda/kernels/utils.cuh is that the latter file only include +// from backend/cuda/device/utils.cuh is that the latter file only include // device-only code. #pragma once #include "mlx/array.h" -#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include #include diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index e539ac559..f57f82ea8 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -1,8 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cast_op.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 5ffe0e10d..5fef9e150 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -1,9 +1,9 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/arange.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/arange.cuh" -#include "mlx/backend/cuda/kernels/fp16_math.cuh" #include "mlx/distributed/primitives.h" #include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" @@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback( NO_GPU(ArgPartition) NO_GPU(BlockMaskedMM) -NO_GPU_MULTI(Compiled) NO_GPU(Convolution) NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 1ca50d854..9911a6fe0 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 0148022ab..a673e052e 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -1,8 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/reduce.h" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index f06eb8541..832787222 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/cuda/device/utils.cuh" namespace mlx::core::cu { diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 3a5c4a591..ae54a27d6 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu index 563b056e4..114d71809 100644 --- a/mlx/backend/cuda/reduce/segmented_reduce.cu +++ b/mlx/backend/cuda/reduce/segmented_reduce.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 605fc0df8..fc001ae75 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -1,9 +1,9 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cast_op.cuh" -#include "mlx/backend/cuda/kernels/fp16_math.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 0ee31ee28..f9d373455 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -2,10 +2,10 @@ #include "mlx/backend/common/unary.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/iterators/general_iterator.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" -#include "mlx/backend/cuda/kernels/unary_ops.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 2a11a518e..2f5e2a4c8 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" #include @@ -23,4 +24,20 @@ void check_cuda_error(const char* name, cudaError_t err) { } } +const char* dtype_to_cuda_type(const Dtype& dtype) { + if (dtype == float16) { + return "__half"; + } + if (dtype == bfloat16) { + return "__nv_bfloat16"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (dtype == DTYPE) { \ + return #CPP_TYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return nullptr; +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 6eaec8984..6d98cdcd5 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -12,6 +12,8 @@ namespace cu { class Device; } +struct Dtype; + // Cuda stream managed with RAII. class CudaStream { public: @@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +// Convert Dtype to CUDA C++ types. 
+const char* dtype_to_cuda_type(const Dtype& dtype); + } // namespace mlx::core From 918761a25aeb55cb8c73d3d8e605371ae6e93738 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Jun 2025 17:09:49 -0700 Subject: [PATCH 097/156] [CUDA] RMSNorm and VJP (#2280) * rms norm start * nit --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/layer_norm.cu | 3 +- mlx/backend/cuda/primitives.cu | 2 - mlx/backend/cuda/rms_norm.cu | 343 ++++++++++++++++++++++++++++++++ 4 files changed, 345 insertions(+), 4 deletions(-) create mode 100644 mlx/backend/cuda/rms_norm.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 854c6c116..3e7b859a6 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -30,6 +30,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 5aa287603..c71795fad 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -244,8 +244,7 @@ void LayerNorm::eval_gpu( } }; - array o = set_output(inputs[0]); - const array& x = o.data_shared_ptr() ? o : out; + const array x = set_output(inputs[0]); const array& w = inputs[1]; const array& b = inputs[2]; diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 5fef9e150..95ea44f94 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -99,8 +99,6 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { -NO_GPU_USE_FALLBACK(RMSNorm) -NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu new file mode 100644 index 000000000..3c521b90d --- /dev/null +++ b/mlx/backend/cuda/rms_norm.cu @@ -0,0 +1,343 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/iterators/strided_iterator.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +inline __device__ float2 plus_f2(const float2& a, const float2& b) { + return {a.x + b.x, a.y + b.y}; +} + +// Similar to cub::BlockReduce, but result is broadcasted to every thread. +template +struct BlockBroadcastReduce { + static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); + static_assert(BLOCK_DIM % WARP_SIZE == 0); + using TempStorage = T[BLOCK_DIM / WARP_SIZE]; + + cg::thread_block& block; + TempStorage& temp; + + template + __device__ T Reduce(const T& input, const Op& op, const T& init_value) { + auto warp = cg::tiled_partition(block); + T x = cg::reduce(warp, input, op); + if (warp.thread_rank() == 0) { + temp[warp.meta_group_rank()] = x; + } + block.sync(); + x = warp.thread_rank() < warp.meta_group_size() ? 
temp[warp.thread_rank()] + : init_value; + return cg::reduce(warp, x, op); + } + + __device__ T Sum(const T& input) { + return Reduce(input, cg::plus{}, T{}); + } +}; + +template +__global__ void rms_norm( + const T* x, + const T* w, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceT = BlockBroadcastReduce; + __shared__ typename BlockReduceT::TempStorage temp; + + x += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + // Normalizer. + float normalizer = 0; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size, 0); + for (int i = 0; i < N_READS; ++i) { + float t = static_cast(xn[i]); + normalizer += t * t; + } + } + normalizer = BlockReduceT{block, temp}.Sum(normalizer); + normalizer = rsqrt(normalizer / axis_size + eps); + + // Outputs. + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; ++i) { + float norm = static_cast(xn[i]) * normalizer; + xn[i] = wn[i] * static_cast(norm); + } + cub::StoreDirectBlocked(index, out, xn, axis_size); + } +} + +template +__global__ void rms_norm_vjp( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + using BlockReduceF = BlockBroadcastReduce; + using BlockReduceF2 = BlockBroadcastReduce; + __shared__ union { + typename BlockReduceF::TempStorage f; + typename BlockReduceF2::TempStorage f2; + } temp; + + x += grid.block_rank() * axis_size; + g += grid.block_rank() * axis_size; + gx += grid.block_rank() * axis_size; + gw += grid.block_rank() * axis_size; + + // Normalizer. + float2 factors = {}; + for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + T xn[N_READS]; + T wn[N_READS] = {}; + T gn[N_READS] = {}; + auto index = r * BLOCK_DIM + block.thread_rank(); + cub::LoadDirectBlocked(index, x, xn, axis_size, 0); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float t = static_cast(xn[i]); + float wi = wn[i]; + float gi = gn[i]; + float wg = wi * gi; + factors = plus_f2(factors, {wg * t, t * t}); + } + } + factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {}); + float meangwx = factors.x / axis_size; + float normalizer = rsqrt(factors.y / axis_size + eps); + float normalizer3 = normalizer * normalizer * normalizer; + + // Outputs. 
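+  // With rms = sqrt(mean(x * x) + eps) and y = w * x / rms, the cotangents are
+  //   gx = g * w / rms - x * mean(g * w * x) / rms^3
+  //   gw = g * x / rms  (per-row values, summed across rows afterwards)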
+ for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { + auto index = r * BLOCK_DIM + block.thread_rank(); + T xn[N_READS]; + T wn[N_READS]; + T gn[N_READS]; + cub::LoadDirectBlocked(index, x, xn, axis_size); + cub::LoadDirectBlocked(index, g, gn, axis_size); + cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); + for (int i = 0; i < N_READS; i++) { + float xi = xn[i]; + float wi = wn[i]; + float gi = gn[i]; + xn[i] = static_cast(normalizer * wi * gi - xi * meangwx * normalizer3); + if constexpr (HAS_W) { + wn[i] = static_cast(gi * xi * normalizer); + } + } + cub::StoreDirectBlocked(index, gx, xn, axis_size); + if constexpr (HAS_W) { + cub::StoreDirectBlocked(index, gw, wn, axis_size); + } + } +} + +} // namespace cu + +namespace fast { + +bool RMSNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +// TODO: There are duplicate code with backend/metal/normalization.cpp +void RMSNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("RMSNorm::eval_gpu"); + auto& s = stream(); + auto& out = outputs[0]; + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + const array x = set_output(inputs[0]); + const array& w = inputs[1]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, { + using DataType = cuda_type_t; + constexpr uint32_t N_READS = 4; + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = cu::rms_norm; + kernel<<>>( + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); +} + +void RMSNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("RMSNormVJP::eval_gpu"); + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + // Ensure row contiguity. We could relax this step by checking that the array + // is contiguous (no broadcasts or holes) and that the input strides are the + // same as the cotangent strides but for now this is simpler. 
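+  // check_input returns the possibly copied row-contiguous array together with
+  // a flag marking whether a copy was made, so that copies can be donated to
+  // the outputs or released as temporaries after the kernel runs.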
+ auto check_input = [&s](const array& x) -> std::pair { + if (x.flags().row_contiguous) { + return {x, false}; + } + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + return {x_copy, true}; + }; + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[2].is_donatable(); + auto [x, copied] = check_input(inputs[0]); + donate_x |= copied; + const array& w = inputs[1]; + auto [g, g_copied] = check_input(inputs[2]); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + + // Check whether we had a weight. + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs. + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(allocator::malloc(gx.nbytes())); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + if (has_w) { + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); + encoder.add_temporary(gw_temp); + } + } + gw.set_data(allocator::malloc(gw.nbytes())); + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + MLX_SWITCH_BOOL(has_w, HAS_W, { + MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { + auto kernel = cu::rms_norm_vjp; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); + }); + }); + }); + + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core From aa07429bada064b1913031057d230d3a9b5663b0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Jun 2025 17:48:05 -0700 Subject: [PATCH 098/156] Fix cuda build (#2284) --- mlx/backend/cuda/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 3e7b859a6..9b12d84a9 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -44,8 +44,8 @@ target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) file( GLOB MLX_JIT_SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.h" - "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.cuh") + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh") string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) add_custom_command( OUTPUT gen/cuda_jit_sources.h From 2188199ff80fe8dcfc681368740deef0ca6fd207 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Jun 2025 20:24:43 -0700 Subject: [PATCH 099/156] [CUDA] ternary with select op (#2283) * cuda ternary with select op * comment + fix * fix --- mlx/backend/cuda/CMakeLists.txt | 1 + 
mlx/backend/cuda/device/ternary_ops.cuh | 12 ++ mlx/backend/cuda/device/utils.cuh | 41 ++++++ mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/ternary.cu | 177 ++++++++++++++++++++++++ 5 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 mlx/backend/cuda/device/ternary_ops.cuh create mode 100644 mlx/backend/cuda/ternary.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 9b12d84a9..1567feafd 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -34,6 +34,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) diff --git a/mlx/backend/cuda/device/ternary_ops.cuh b/mlx/backend/cuda/device/ternary_ops.cuh new file mode 100644 index 000000000..d1d008ac5 --- /dev/null +++ b/mlx/backend/cuda/device/ternary_ops.cuh @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::cu { + +struct Select { + template + __device__ T operator()(bool condition, T x, T y) { + return condition ? x : y; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index a1d387201..6f9851c94 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -162,6 +162,27 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( return cuda::std::make_tuple(a_loc, b_loc); } +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + const int64_t* c_strides) { + IdxT a_loc = 0; + IdxT b_loc = 0; + IdxT c_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * a_strides[i]; + b_loc += dim_idx * b_strides[i]; + c_loc += dim_idx * c_strides[i]; + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc, c_loc); +} + // Optimized version when ndim is larger than 4. 
template inline __host__ __device__ IdxT @@ -191,6 +212,26 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( return cuda::std::make_tuple(a_loc, b_loc); } +template +inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + const int64_t* c_strides, + int ndim) { + auto [a_loc, b_loc, c_loc] = + elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides); + for (int i = ndim - 1; i >= 3; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * a_strides[i]; + b_loc += dim_idx * b_strides[i]; + c_loc += dim_idx * c_strides[i]; + elem /= shape[i]; + } + return cuda::std::make_tuple(a_loc, b_loc, c_loc); +} + /////////////////////////////////////////////////////////////////////////////// // Elem to loc in a loop utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 95ea44f94..eb451f49d 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -91,7 +91,6 @@ NO_GPU(QuantizedMatmul) NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) -NO_GPU(Select) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu new file mode 100644 index 000000000..bb79d4249 --- /dev/null +++ b/mlx/backend/cuda/ternary.cu @@ -0,0 +1,177 @@ +// Copyright © 2025 Apple Inc. +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/ternary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void +ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(a[index], b[index], c[index]); + } +} + +template +__global__ void ternary_g_nd( + const bool* a, + const T* b, + const T* c, + T* out, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides, + const __grid_constant__ cuda::std::array c_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( + index, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data()); + out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + } +} + +template +__global__ void ternary_g( + const bool* a, + const T* b, + const T* c, + T* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + const __grid_constant__ Strides c_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx, c_idx] = elem_to_loc_4d( + index, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data(), + ndim); + out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + } +} + +} // namespace cu + +template +void ternary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const Stream& s) { + const auto& a = inputs[0]; + const auto& b = inputs[1]; + const auto& c = inputs[2]; + if (out.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + 
encoder.set_input_array(c); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, { + using DType = cuda_type_t; + + auto topt = get_ternary_op_type(a, b, c); + if (topt == TernaryOpType::General) { + auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX || + c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::ternary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } + }); + } else { + MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = cu::ternary_v; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size()); + }); + } + }); + }); +} + +template +void ternary_op_gpu( + const std::vector& inputs, + array& out, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& c = inputs[2]; + auto topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt); + ternary_op_gpu_inplace(inputs, out, s); +} + +void Select::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("select::eval_gpu"); + auto& s = out.primitive().stream(); + ternary_op_gpu(inputs, out, s); +} + +} // namespace mlx::core From c8b4787e4e34f66f81c612fcbe6371e5e7d308a9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 13 Jun 2025 13:44:19 +0900 Subject: [PATCH 100/156] CUDA backend: indexing ops (#2277) --- mlx/backend/cuda/CMakeLists.txt | 5 +- mlx/backend/cuda/device/atomic_ops.cuh | 72 ++++ mlx/backend/cuda/device/gather.cuh | 53 +++ mlx/backend/cuda/device/gather_axis.cuh | 65 ++++ mlx/backend/cuda/device/indexing.cuh | 30 ++ mlx/backend/cuda/device/scatter.cuh | 68 ++++ mlx/backend/cuda/device/scatter_axis.cuh | 67 ++++ mlx/backend/cuda/device/scatter_ops.cuh | 44 +++ mlx/backend/cuda/indexing.cpp | 420 +++++++++++++++++++++++ mlx/backend/cuda/jit_module.cpp | 8 + mlx/backend/cuda/primitives.cu | 4 - 11 files changed, 830 insertions(+), 6 deletions(-) create mode 100644 mlx/backend/cuda/device/atomic_ops.cuh create mode 100644 mlx/backend/cuda/device/gather.cuh create mode 100644 mlx/backend/cuda/device/gather_axis.cuh create mode 100644 mlx/backend/cuda/device/indexing.cuh create mode 100644 mlx/backend/cuda/device/scatter.cuh create mode 100644 mlx/backend/cuda/device/scatter_axis.cuh create mode 100644 mlx/backend/cuda/device/scatter_ops.cuh create mode 100644 mlx/backend/cuda/indexing.cpp diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1567feafd..7cc74353a 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ 
-1,8 +1,8 @@ # Filename rules in cuda backend: # # * Use .cu/.cuh if code contains device code, and .cpp/.h if not. -# * Device-only kernel code should be put in kernels/ subdir. -# * Files in kernels/ subdir should not include files outside. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp @@ -20,6 +20,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh new file mode 100644 index 000000000..b6915606e --- /dev/null +++ b/mlx/backend/cuda/device/atomic_ops.cuh @@ -0,0 +1,72 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" + +#include + +namespace mlx::core::cu { + +template +inline __device__ void atomic_add(T* out, T val) { + cuda::atomic_ref ref(*out); + ref += val; +} + +template +inline __device__ void atomic_prod(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old * val)) { + } +} + +template +inline __device__ void atomic_max(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_max(val); +} + +template +inline __device__ void atomic_min(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_min(val); +} + +// Somehow cuda::atomic_ref does not provide atomic add for following types. +template +inline __device__ void atomic_add_general(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old + val)) { + } +} + +inline __device__ void atomic_add(__half* out, __half val) { + atomicAdd(out, val); +} + +inline __device__ void atomic_add(cuComplex* out, cuComplex val) { +#if __CUDA_ARCH__ < 900 + atomic_add_general(out, val); +#else + atomicAdd(out, val); +#endif +} + +inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { +#if __CUDA_ARCH__ < 800 +#if CCCL_VERSION >= 2008000 + atomic_add_general(out, val); +#else + bool cccl_version_too_old_for_bfloat16_atomic_add = false; + assert(cccl_version_too_old_for_bfloat16_atomic_add); +#endif +#else + atomicAdd(out, val); +#endif +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather.cuh b/mlx/backend/cuda/device/gather.cuh new file mode 100644 index 000000000..7dbd84ac3 --- /dev/null +++ b/mlx/backend/cuda/device/gather.cuh @@ -0,0 +1,53 @@ +// Copyright © 2025 Apple Inc. 
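+//
+// gather: one thread per output element. Each thread splits its flat index
+// into a slice offset and an index position, reads the index values for the
+// given axes, and copies the addressed element of src into out.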
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template +__global__ void gather( + const T* src, + T* out, + LocT size, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim, + const __grid_constant__ Shape slice_sizes, + uint32_t slice_size, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT out_idx = cg::this_grid().thread_rank(); + if (out_idx >= size) { + return; + } + + LocT src_elem = out_idx % slice_size; + LocT idx_elem = out_idx / slice_size; + + LocT src_loc = + elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]); + src_loc += idx_val * src_strides[axis]; + } + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather_axis.cuh b/mlx/backend/cuda/device/gather_axis.cuh new file mode 100644 index 000000000..f863b2d95 --- /dev/null +++ b/mlx/backend/cuda/device/gather_axis.cuh @@ -0,0 +1,65 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + int NDIM, + bool SrcC, + bool IdxC, + typename LocT> +__global__ void gather_axis( + const T* src, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array src_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t src_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT src_loc = idx_val * src_stride_axis; + if constexpr (SrcC) { + src_loc += elem_idx * axis_size + x; + } else { + src_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), src_strides.data()); + } + + LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x; + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/indexing.cuh b/mlx/backend/cuda/device/indexing.cuh new file mode 100644 index 000000000..31cba1a90 --- /dev/null +++ b/mlx/backend/cuda/device/indexing.cuh @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. 
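+//
+// Small helpers shared by the gather/scatter kernels: converting a flat thread
+// index into positions in a 3-d grid and resolving possibly negative indices
+// against an axis size.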
+ +#include +#include + +namespace mlx::core::cu { + +// Convert an absolute index to positions in a 3d grid, assuming the index is +// calculated with: +// index = x * dim1 * dim2 + y * dim2 + z +template +inline __host__ __device__ cuda::std::tuple +index_to_dims(T index, T dim1, T dim2) { + T x = index / (dim1 * dim2); + T y = (index % (dim1 * dim2)) / dim2; + T z = index % dim2; + return cuda::std::make_tuple(x, y, z); +} + +// Get absolute index from possible negative index. +template +inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { + if constexpr (cuda::std::is_unsigned_v) { + return idx; + } else { + return static_cast(idx < 0 ? idx + size : idx); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh new file mode 100644 index 000000000..b2f640350 --- /dev/null +++ b/mlx/backend/cuda/device/scatter.cuh @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + int IDX_NDIM, + typename LocT> +__global__ void scatter( + const T* upd, + T* out, + LocT size, + const __grid_constant__ Shape upd_shape, + const __grid_constant__ Strides upd_strides, + int32_t upd_ndim, + LocT upd_post_idx_size, + const __grid_constant__ Shape out_shape, + const __grid_constant__ Strides out_strides, + int32_t out_ndim, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT upd_idx = cg::this_grid().thread_rank(); + if (upd_idx >= size) { + return; + } + + LocT out_elem = upd_idx % upd_post_idx_size; + LocT idx_elem = upd_idx / upd_post_idx_size; + + LocT out_idx = elem_to_loc( + out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]); + out_idx += idx_val * out_strides[axis]; + } + + LocT upd_loc = elem_to_loc( + out_elem + idx_elem * upd_post_idx_size, + upd_shape.data(), + upd_strides.data(), + upd_ndim); + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_axis.cuh b/mlx/backend/cuda/device/scatter_axis.cuh new file mode 100644 index 000000000..1f30f2ebd --- /dev/null +++ b/mlx/backend/cuda/device/scatter_axis.cuh @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
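+//
+// scatter_axis: one thread per index element. Each thread reads an index value
+// along |axis|, locates the corresponding update element, and applies Op
+// (assignment or an atomic accumulation) at the indexed output position.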
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NDIM, + bool UpdC, + bool IdxC, + typename LocT> +__global__ void scatter_axis( + const T* upd, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array upd_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t upd_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT upd_loc = y * upd_stride_axis; + if constexpr (UpdC) { + upd_loc += elem_idx * idx_size_axis + x; + } else { + upd_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), upd_strides.data()); + } + + LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x; + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_ops.cuh b/mlx/backend/cuda/device/scatter_ops.cuh new file mode 100644 index 000000000..d88f896ad --- /dev/null +++ b/mlx/backend/cuda/device/scatter_ops.cuh @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/atomic_ops.cuh" + +namespace mlx::core::cu { + +struct ScatterAssign { + template + __device__ void operator()(T* out, T val) const { + *out = val; + } +}; + +struct ScatterSum { + template + __device__ void operator()(T* out, T val) const { + atomic_add(out, val); + } +}; + +struct ScatterProd { + template + __device__ void operator()(T* out, T val) const { + atomic_prod(out, val); + } +}; + +struct ScatterMax { + template + __device__ void operator()(T* out, T val) const { + atomic_max(out, val); + } +}; + +struct ScatterMin { + template + __device__ void operator()(T* out, T val) const { + atomic_min(out, val); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp new file mode 100644 index 000000000..3603605c4 --- /dev/null +++ b/mlx/backend/cuda/indexing.cpp @@ -0,0 +1,420 @@ +// Copyright © 2025 Apple Inc. 
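+//
+// Host-side dispatch for the Gather/Scatter/GatherAxis/ScatterAxis primitives.
+// Kernels are JIT-compiled through JitModule and specialized on the data type,
+// index type, number of index arrays / dimensions, and 32- vs 64-bit indexing.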
+ +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include "cuda_jit_sources.h" + +#include +#include + +#include +#include + +namespace mlx::core { + +namespace { + +constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; + +void append_indices_arg( + cu::JitModule& mod, + const std::vector& inputs, + int nidx, + int idx_ndim) { + std::vector indices(nidx); + for (int i = 0; i < nidx; ++i) { + indices[i] = inputs[i + 1].data(); + } + mod.append_arg(std::move(indices)); + std::vector indices_shape(nidx * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy_n( + inputs[i + 1].shape().begin(), + idx_ndim, + indices_shape.data() + i * idx_ndim); + } + mod.append_arg(std::move(indices_shape)); + std::vector indices_strides(nidx * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy_n( + inputs[i + 1].strides().begin(), + idx_ndim, + indices_strides.data() + i * idx_ndim); + } + mod.append_arg(std::move(indices_strides)); +} + +} // namespace + +void Gather::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Gather::eval_gpu"); + assert(inputs.size() > 0); + const auto& src = inputs[0]; + + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + int nidx = inputs.size() - 1; + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || + (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + + uint32_t slice_size = std::accumulate( + slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); + + std::string module_name = fmt::format( + "gather_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx_dtype), + nidx); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::gather<{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + nidx, + ndim, + large ? "int64_t" : "uint32_t")); + } + } + return std::make_pair(jit_source_gather, std::move(kernel_names)); + }); + + mod.append_arg(src); + mod.append_arg(out); + if (large) { + mod.append_arg(out.size()); + } else { + mod.append_arg(out.size()); + } + mod.append_ndim_arg(src.shape()); + mod.append_ndim_arg(src.strides()); + mod.append_arg(src.ndim()); + mod.append_ndim_arg(slice_sizes_); + mod.append_arg(slice_size); + mod.append_arg(axes_); + append_indices_arg(mod, inputs, nidx, idx_ndim); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather<{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + nidx, + idx_ndim, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, out, large); + }); +} + +void Scatter::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Gather::eval_gpu"); + assert(inputs.size() > 1); + auto& upd = inputs.back(); + + // Copy src into out. 
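+  // Scatter updates out in place, so start from a copy of src using the
+  // cheapest copy type its layout allows.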
+ CopyType copy_type; + if (inputs[0].data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (inputs[0].flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(inputs[0], out, copy_type); + + // Empty update. + if (upd.size() == 0) { + return; + } + + int nidx = axes_.size(); + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || + (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + + uint32_t upd_post_idx_size = std::accumulate( + upd.shape().begin() + idx_ndim, + upd.shape().end(), + 1, + std::multiplies()); + + const char* op = g_scatter_ops[reduce_type_]; + std::string module_name = fmt::format( + "scatter_{}_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx_dtype), + op, + nidx); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + op, + nidx, + ndim, + large ? "int64_t" : "uint32_t")); + } + } + return std::make_pair(jit_source_scatter, std::move(kernel_names)); + }); + + mod.append_arg(upd); + mod.append_arg(out); + if (large) { + mod.append_arg(upd.size()); + } else { + mod.append_arg(upd.size()); + } + mod.append_ndim_arg(upd.shape()); + mod.append_ndim_arg(upd.strides()); + mod.append_arg(upd.ndim()); + if (large) { + mod.append_arg(upd_post_idx_size); + } else { + mod.append_arg(upd_post_idx_size); + } + mod.append_ndim_arg(out.shape()); + mod.append_ndim_arg(out.strides()); + mod.append_arg(out.ndim()); + mod.append_arg(axes_); + append_indices_arg(mod, inputs, nidx, idx_ndim); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + op, + nidx, + idx_ndim, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, upd, large); + }); +} + +void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("GatherAxis::eval_gpu"); + assert(inputs.size() > 1); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + + std::string module_name = fmt::format( + "gather_axis_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype())); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? 
"int64_t" : "uint32_t")); + } + } + } + return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + mod.append_arg(src); + mod.append_arg(idx); + mod.append_arg(out); + if (large) { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } else { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } + mod.append_arg(remove_index(idx.shape(), axis_)); + mod.append_arg(remove_index(src.strides(), axis_)); + mod.append_arg(remove_index(idx.strides(), axis_)); + mod.append_arg(axis_); + mod.append_arg(src.shape(axis_)); + mod.append_arg(src.strides(axis_)); + mod.append_arg(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + src.ndim() - 1, + src.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, idx, large); + }); +} + +void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ScatterAxis::eval_gpu"); + assert(inputs.size() > 2); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + const auto& upd = inputs[2]; + + // Copy src into out. + CopyType copy_type; + if (src.data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (src.flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(src, out, copy_type); + + // Empty update. + if (upd.size() == 0) { + return; + } + + bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + + const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; + std::string module_name = fmt::format( + "scatter_axis_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype()), + op); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? 
"int64_t" : "uint32_t")); + } + } + } + return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + mod.append_arg(upd); + mod.append_arg(idx); + mod.append_arg(out); + if (large) { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } else { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } + mod.append_arg(remove_index(idx.shape(), axis_)); + mod.append_arg(remove_index(upd.strides(), axis_)); + mod.append_arg(remove_index(idx.strides(), axis_)); + mod.append_arg(axis_); + mod.append_arg(out.shape(axis_)); + mod.append_arg(upd.strides(axis_)); + mod.append_arg(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + idx.ndim() - 1, + upd.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, idx, large); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 3c00dd7f0..b8be103cc 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -148,24 +148,32 @@ bool compiler_supports_device_sass(Device& device) { #define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" constexpr const char* g_include_names[] = { + INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "cucomplex_math.cuh", INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "indexing.cuh", + INCLUDE_PREFIX "scatter_ops.cuh", INCLUDE_PREFIX "unary_ops.cuh", + INCLUDE_PREFIX "ternary_ops.cuh", INCLUDE_PREFIX "utils.cuh", }; #undef INCLUDE_PREFIX constexpr const char* g_headers[] = { + jit_source_atomic_ops, jit_source_binary_ops, jit_source_cast_op, jit_source_config, jit_source_cucomplex_math, jit_source_fp16_math, + jit_source_indexing, + jit_source_scatter_ops, jit_source_unary_ops, + jit_source_ternary_ops, jit_source_utils, }; diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index eb451f49d..0c4d3a8aa 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -78,8 +78,6 @@ NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) NO_GPU(FFT) -NO_GPU(Gather) -NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) NO_GPU(Hadamard) @@ -89,8 +87,6 @@ NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(Scan) -NO_GPU(Scatter) -NO_GPU(ScatterAxis) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) From fddb6933e1cdcb268467fc5d02be6b471bb232b9 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 13 Jun 2025 10:44:56 -0700 Subject: [PATCH 101/156] Collection of refactors (#2274) * Refactor gemv into a function * Refactor splitk step 1 * Refactor split k axpby * Rearrange steel_gemm_regular * Redirect steel_gemm_regular * Add axpby routing to steel_matmul_regular * Refactor AddMM 
step 1 * Redirect steel_gemm * Update addmm * Comments and format * Some cleanup * Add architecture gen to device * Update no copy condition in normalization to account for axis size 1 --- mlx/backend/metal/conv.cpp | 40 +- mlx/backend/metal/device.cpp | 3 + mlx/backend/metal/device.h | 5 + .../steel/gemm/kernels/steel_gemm_fused.h | 4 +- mlx/backend/metal/matmul.cpp | 1202 ++++++++--------- mlx/backend/metal/matmul.h | 104 +- mlx/backend/metal/normalization.cpp | 4 +- 7 files changed, 720 insertions(+), 642 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 697afa6a1..9eb6a6385 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu( // Perform gemm std::vector copies = {in_unfolded, wt_transpose}; return steel_matmul_regular( - s, - d, - /* a = */ in_unfolded, - /* b = */ wt_transpose, - /* c = */ out, - /* M = */ implicit_M, - /* N = */ implicit_N, - /* K = */ implicit_K, - /* batch_size_out = */ groups, - /* a_cols = */ implicit_K * groups, - /* b_cols = */ implicit_K, - /* out_cols = */ implicit_N * groups, - /* a_transposed = */ false, - /* b_transposed = */ true, - /* batch_shape = */ {1}, - /* batch_strides = */ {0}, - /* A_batch_strides = */ size_t(implicit_K), - /* B_batch_strides = */ size_t(implicit_N) * implicit_K, - /* matrix_stride_out = */ size_t(implicit_N), - /*copies = */ copies); + /* const Stream& s = */ s, + /* Device& d = */ d, + /* const array& a = */ in_unfolded, + /* const array& b = */ wt_transpose, + /* array& c = */ out, + /* int M = */ implicit_M, + /* int N = */ implicit_N, + /* int K = */ implicit_K, + /* int batch_size_out = */ groups, + /* int lda = */ implicit_K * groups, + /* int ldb = */ implicit_K, + /* int ldd = */ implicit_N * groups, + /* bool transpose_a = */ false, + /* bool transpose_b = */ true, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ {1}, + /* Strides batch_strides = */ {0}, + /* int64_t A_batch_strides = */ int64_t(implicit_K), + /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K, + /* int64_t matrix_stride_out = */ int64_t(implicit_N)); } void implicit_gemm_conv_2D_gpu( diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 425274361..88835eb75 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -297,6 +297,9 @@ Device::Device() { device_ = load_device(); default_library_ = load_default_library(device_); arch_ = std::string(device_->architecture()->name()->utf8String()); + int ag_tens = arch_[arch_.size() - 3] - '0'; + int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { case 'p': // phone diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5bfcc6649..f87a8c48b 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -177,6 +177,10 @@ class Device { return arch_; } + int get_architecture_gen() const { + return arch_gen_; + } + void new_queue(int index); MTL::CommandQueue* get_queue(Stream stream); @@ -268,6 +272,7 @@ class Device { library_kernels_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int arch_gen_; int max_ops_per_buffer_; int max_mb_per_buffer_; }; diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index add495d93..85830872d 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h 
+++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index ed96d37ea..be7f3e2f8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -164,11 +164,17 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { wn = 2; \ } -void steel_matmul_regular( +/////////////////////////////////////////////////////////////////////////////// +// Regular steel matmul dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_regular_axpby( const Stream& s, metal::Device& d, const array& a, const array& b, + const array& c, array& out, int M, int N, @@ -179,12 +185,15 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, - std::vector& copies) { + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { using namespace mlx::steel; // Determine dispatch kernel @@ -196,16 +205,21 @@ void steel_matmul_regular( // Prepare kernel name std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; + + // clang-format off + kname << "steel_gemm_fused_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = false; - const bool do_axpby = false; + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; @@ -232,18 +246,18 @@ void steel_matmul_regular( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); compute_encoder.set_compute_pipeline_state(kernel); @@ -286,8 +300,25 @@ void steel_matmul_regular( compute_encoder.set_bytes(params, 4); - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } + + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -295,7 +326,437 @@ void steel_matmul_regular( d.add_temporaries(std::move(copies), s.index); } -void steel_matmul( +/////////////////////////////////////////////////////////////////////////////// +// Split k steel matmul +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_gemm_splitk_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { + using namespace mlx::steel; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc(C_split.nbytes())); + copies.push_back(C_split); + + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_splitk_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(C_split) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn + << "_MN_" << (mn_aligned ? "t" : "n") << "aligned" + << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on + + // Encode and dispatch gemm kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_splitk_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kname.str(), + /* const array& in = */ a, + /* const array& out = */ C_split, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn, + /* bool mn_aligned = */ mn_aligned, + /* bool k_aligned = */ k_aligned); + + compute_encoder.set_compute_pipeline_state(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); + + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Do accum kernel + { + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); + + if (do_axpby) { + kernel_name = kernel_name + "_axbpy"; + } + + auto kernel = get_steel_gemm_splitk_accum_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ kernel_name, + /* const array& in = */ C_split, + /* const array& out = */ out, + /* bool axbpy = */ do_axpby); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set the arguments for the kernel + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); + + if (do_axpby) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + compute_encoder.set_input_array(c, 5); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha, 8); + compute_encoder.set_bytes(beta, 9); + } + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + auto group_dims = get_block_dims(N, M, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + + d.add_temporaries(std::move(copies), s.index); +} + +/////////////////////////////////////////////////////////////////////////////// +// Split matmul routing +/////////////////////////////////////////////////////////////////////////////// + +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int 
K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape /* = {} */, + Strides A_batch_stride /* = {} */, + Strides B_batch_stride /* = {} */, + Strides C_batch_stride /* = {} */, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { + if (batch_shape.empty()) { + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + if constexpr (CHECK_AB) { + auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] = + collapse_batches(a, b, c); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + C_batch_stride = C_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + C_batch_stride = {0}; + batch_shape = {1}; + } + } else { + auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + batch_shape = {1}; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { + return steel_gemm_splitk_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* float alpha = */ alpha, + /* float beta = */ beta); + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + auto batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + if (CHECK_AB && !C_batch_stride.empty()) { + batch_strides.insert( + batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); + } + + int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back(); + int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back(); + int64_t C_batch_stride_ = C_batch_stride.empty() ? 
0 : C_batch_stride.back(); + + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ N, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides batch_strides = */ std::move(batch_strides), + /* int64_t A_batch_stride = */ A_batch_stride_, + /* int64_t B_batch_stride = */ B_batch_stride_, + /* int64_t matrix_stride_out = */ int64_t(M) * N, + /* int64_t C_batch_stride = */ C_batch_stride_, + /* float alpha = */ alpha, + /* float beta = */ beta); +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMV dispatch +/////////////////////////////////////////////////////////////////////////////// + +template +void gemv_axbpy( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? ldb : lda; + + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; + + int stride_mat = batch_strides_mat.back(); + int stride_vec = batch_strides_vec.back(); + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = (batch_shape.size() == 1); + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { + bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + sn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; + + n_out_per_tgp = bm * sm * tm; + kname << "gemv_" << type_to_name(out); + } + + const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); + + // clang-format off + kname << "_bm" << bm << "_bn" << bn + << "_sm" << sm << "_sn" << sn + << "_tm" << tm << "_tn" << tn + << "_nc" << !contiguous_kernel + << "_axpby" << do_axpby; // clang-format on + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder.set_compute_pipeline_state(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); + + if (do_axpby) { + compute_encoder.set_input_array(c, 2); + + compute_encoder.set_bytes(alpha, 7); + compute_encoder.set_bytes(beta, 8); + + compute_encoder.set_vector_bytes(C_batch_stride, 13); + + int bias_stride = c.strides()[c.ndim() - 1]; + compute_encoder.set_bytes(bias_stride, 14); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(copies), s.index); +} + +inline void gemv( const Stream& s, metal::Device& d, const array& a, @@ -310,166 +771,34 @@ void steel_matmul( bool transpose_a, bool transpose_b, std::vector& copies, - Shape batch_shape /* = {} */, - Strides A_batch_stride /* = {} */, - Strides B_batch_stride /* = {} */) { - using namespace mlx::steel; - - if (batch_shape.empty()) { - ///////////////////////////////////////////////////////////////////////////// - // Check and collapse batch dimensions - auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); - - batch_shape = batch_shape_; - A_batch_stride = A_bstride_; - B_batch_stride = B_bstride_; - // Collapse batches into M if needed - if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && - a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && - B_batch_stride.back() == 0) { - M *= batch_shape.back(); - batch_size_out = 1; - - A_batch_stride = {0}; - B_batch_stride = {0}; - batch_shape = {1}; - } - } - - size_t matrix_stride_out = size_t(M) * N; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 
't' : 'n') - << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldc = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int split_k_partitions = */ split_k_partitions, - /* const int split_k_partition_stride = */ split_k_partition_stride, - /* const int split_k_partition_size = */ split_k_partition_size, - /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split); - - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, false); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; - } - - ///////////////////////////////////////////////////////////////////////////// - // Regular kernel dispatch - auto batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - - steel_matmul_regular( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - lda, - ldb, - N, - transpose_a, - transpose_b, - std::move(batch_shape), - std::move(batch_strides), - A_batch_stride.back(), - B_batch_stride.back(), - matrix_stride_out, - copies); + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); } 
+/////////////////////////////////////////////////////////////////////////////// +// Matmul implementation +/////////////////////////////////////////////////////////////////////////////// + void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (!issubdtype(out.dtype(), floating)) { @@ -528,102 +857,26 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 
1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby0"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ a_cols, + /* int ldb = */ b_cols, + /* bool transpose_a = */ a_transposed, + /* bool transpose_b = */ b_transposed, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } + ///////////////////////////////////////////////////////////////////////////// // Gemm specialization @@ -641,12 +894,16 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* int ldb = */ b_cols, /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, - /* std::vector& = */ copies, - /* Shape batch_shape = */ batch_shape, - /* Strides A_batch_stride = */ A_batch_stride, - /* Strides B_batch_stride = */ B_batch_stride); + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ std::move(batch_shape), + /* Strides A_batch_stride = */ std::move(A_batch_stride), + /* Strides B_batch_stride = */ std::move(B_batch_stride)); } +/////////////////////////////////////////////////////////////////////////////// +// AddMM implementation +/////////////////////////////////////////////////////////////////////////////// + void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); if (!issubdtype(out.dtype(), floating)) { @@ -726,346 +983,61 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Route to gemv if needed if (std::min(M, N) == 1) { - // Collect problem info - bool is_b_matrix = N != 1; - - auto& mat = is_b_matrix ? b : a; - auto& vec = is_b_matrix ? a : b; - bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; - int in_vector_len = K; - int out_vector_len = is_b_matrix ? N : M; - - int mat_cols = transpose_mat ? out_vector_len : in_vector_len; - int mat_rows = transpose_mat ? in_vector_len : out_vector_len; - int mat_ld = is_b_matrix ? b_cols : a_cols; - - auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; - auto batch_strides_vec = is_b_matrix ? 
A_batch_stride : B_batch_stride; - - int stride_mat = batch_strides_mat.back(); - int stride_vec = batch_strides_vec.back(); - - // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = (batch_shape.size() == 1); - - int batch_ndim = batch_shape.size(); - - // Determine dispatch kernel - int tm = 4, tn = 4; - int sm = 1, sn = 32; - int bm = 1, bn = 1; - int n_out_per_tgp; - std::ostringstream kname; - - if (transpose_mat) { - if (in_vector_len >= 8192 && out_vector_len >= 2048) { - sm = 4; - sn = 8; - } else { - sm = 8; - sn = 4; - } - - if (out_vector_len >= 2048) { - bn = 16; - } else if (out_vector_len >= 512) { - bn = 4; - } else { - bn = 2; - } - - // Specialized kernel for very small outputs - tn = out_vector_len < tn ? 1 : tn; - - n_out_per_tgp = bn * sn * tn; - kname << "gemv_t_" << type_to_name(out); - - } else { - bm = out_vector_len >= 4096 ? 8 : 4; - sn = 32; - - // Specialized kernel for very small outputs - tm = out_vector_len < tm ? 1 : tm; - - n_out_per_tgp = bm * sm * tm; - kname << "gemv_" << type_to_name(out); - } - - kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" - << tm << "_tn" << tn; - kname << "_nc" << !contiguous_kernel << "_axpby1"; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(32, bn, bm); - MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - - compute_encoder.set_input_array(mat, 0); - compute_encoder.set_input_array(vec, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(in_vector_len, 4); - compute_encoder.set_bytes(out_vector_len, 5); - compute_encoder.set_bytes(mat_ld, 6); - - compute_encoder.set_bytes(alpha_, 7); - compute_encoder.set_bytes(beta_, 8); - - compute_encoder.set_bytes(batch_ndim, 9); - compute_encoder.set_vector_bytes(batch_shape, 10); - compute_encoder.set_vector_bytes(batch_strides_vec, 11); - compute_encoder.set_vector_bytes(batch_strides_mat, 12); - compute_encoder.set_vector_bytes(C_batch_stride, 13); - - int bias_stride = c.strides()[c.ndim() - 1]; - compute_encoder.set_bytes(bias_stride, 14); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); - return; - } - - using namespace mlx::steel; - - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - - int _tm = M / 16; - int _tn = N / 16; - int _tk = K / 16; - - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { - int bm = M < 40 ? 16 : 32; - int bn = N < 40 ? 16 : 32; - int bk = 16; - int wm = 2, wn = 2; - - int split_k_partitions = - _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); - int split_k_partition_stride = M * N; - int gemm_k_iterations = (K / bk) / split_k_partitions; - int split_k_partition_size = gemm_k_iterations * bk; - - array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc(C_split.nbytes())); - copies.push_back(C_split); - - bool mn_aligned = M % bm == 0 && N % bn == 0; - bool k_aligned = K % bk == 0; - - std::ostringstream kname; - kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") - << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; - - // Encode and dispatch gemm kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_splitk_kernel( - d, - kname.str(), - a, - C_split, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn, - mn_aligned, - k_aligned); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_output_array(C_split, 2); - - compute_encoder.set_bytes(params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - // Do accum kernel - { - auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axbpy"; - auto kernel = get_steel_gemm_splitk_accum_kernel( - d, kernel_name, C_split, out, true); - - compute_encoder.set_compute_pipeline_state(kernel); - - // Set the arguments for the kernel - compute_encoder.set_input_array(C_split, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(split_k_partitions, 2); - compute_encoder.set_bytes(split_k_partition_stride, 3); - compute_encoder.set_bytes(N, 4); - compute_encoder.set_input_array(c, 5); - compute_encoder.set_bytes(ldc, 6); - compute_encoder.set_bytes(fdc, 7); - compute_encoder.set_bytes(alpha_, 8); - compute_encoder.set_bytes(beta_, 9); - - // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(N, M, 1); - auto group_dims = get_block_dims(N, M, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - } - - d.add_temporaries(std::move(copies), s.index); - return; + return gemv_axbpy( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } ///////////////////////////////////////////////////////////////////////////// // Regular addmm dispatch - // Determine dispatch kernel - int bm = 64, bn = 64, bk = 16; - int wm = 2, wn = 2; - - char devc = d.get_architecture().back(); - GEMM_TPARAM_MACRO(devc) - - // Prepare kernel name - std::ostringstream kname; - kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n') - << (transpose_b ? 
't' : 'n') << "_" << type_to_name(a) << "_" - << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn; - - std::string base_name = kname.str(); - - const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = true; - const bool do_axpby = !(alpha_ == 1. && beta_ == 1.); - const bool align_M = (M % bm) == 0; - const bool align_N = (N % bn) == 0; - const bool align_K = (K % bk) == 0; - - metal::MTLFCList func_consts = { - {&has_batch, MTL::DataType::DataTypeBool, 10}, - {&use_out_source, MTL::DataType::DataTypeBool, 100}, - {&do_axpby, MTL::DataType::DataTypeBool, 110}, - {&align_M, MTL::DataType::DataTypeBool, 200}, - {&align_N, MTL::DataType::DataTypeBool, 201}, - {&align_K, MTL::DataType::DataTypeBool, 202}, - }; - - // clang-format off - kname << "_has_batch_" << (has_batch ? 't' : 'n') - << "_use_out_source_" << (use_out_source ? 't' : 'n') - << "_do_axpby_" << (do_axpby ? 't' : 'n') - << "_align_M_" << (align_M ? 't' : 'n') - << "_align_N_" << (align_N ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_fused_kernel( - d, - base_name, - hash_name, - func_consts, - out, - transpose_a, - transpose_b, - bm, - bn, - bk, - wm, - wn); - - compute_encoder.set_compute_pipeline_state(kernel); - - int tn = (N + bn - 1) / bn; - int tm = (M + bm - 1) / bm; - - // TODO: Explore device-based tuning for swizzle - int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); - - // Prepare steel matmul params - GEMMParams gemm_params{ - /* const int M = */ M, - /* const int N = */ N, - /* const int K = */ K, - /* const int lda = */ lda, - /* const int ldb = */ ldb, - /* const int ldd = */ N, - /* const int tiles_n = */ tn, - /* const int tiles_m = */ tm, - /* const int64_t batch_stride_a = */ A_batch_stride.back(), - /* const int64_t batch_stride_b = */ B_batch_stride.back(), - /* const int64_t batch_stride_d = */ matrix_stride_out, - /* const int swizzle_log = */ swizzle_log, - /* const int gemm_k_iterations_aligned = */ (K / bk), - /* const int batch_ndim = */ int(batch_shape.size())}; - - GEMMAddMMParams params{ - /* const int ldc = */ ldc, - /* const int fdc = */ fdc, - /* const int64_t batch_stride_c = */ C_batch_stride.back(), - /* const float alpha = */ alpha_, - /* const float beta = */ beta_}; - - int tile = 1 << swizzle_log; - tm = (tm + tile - 1) / tile; - tn = tn * tile; - - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - - Strides batch_strides = A_batch_stride; - batch_strides.insert( - batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - batch_strides.insert( - batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - - // Launch kernel - compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - compute_encoder.set_bytes(gemm_params, 4); - compute_encoder.set_bytes(params, 5); - - compute_encoder.set_vector_bytes(batch_shape, 6); - compute_encoder.set_vector_bytes(batch_strides, 7); - - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const 
array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride, + /* Strides B_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha_, + /* float beta = */ beta_); } +/////////////////////////////////////////////////////////////////////////////// +// BlockMaskedMM implementation +/////////////////////////////////////////////////////////////////////////////// + void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { using namespace mlx::steel; // assert(inputs.size() == 2); @@ -1454,6 +1426,10 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { d.add_temporaries(std::move(copies), s.index); } +/////////////////////////////////////////////////////////////////////////////// +// GatherMM implementation +/////////////////////////////////////////////////////////////////////////////// + void gather_mm_rhs( const array& a_, const array& b_, diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 09ffe05a8..218664b1f 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -6,7 +6,34 @@ namespace mlx::core { -void steel_matmul_regular( +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( const Stream& s, metal::Device& d, const array& a, @@ -21,14 +48,61 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, + std::vector& copies, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, - int64_t matrix_stride_out, - std::vector& copies); + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} -void steel_matmul( +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 
1.0f, + float beta = 0.0f); + +inline void steel_matmul( const Stream& s, metal::Device& d, const array& a, @@ -45,6 +119,26 @@ void steel_matmul( std::vector& copies, Shape batch_shape = {}, Strides A_batch_stride = {}, - Strides B_batch_stride = {}); + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index d570bf3c0..8674eff72 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -26,7 +26,7 @@ void RMSNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -227,7 +227,7 @@ void LayerNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { From 8402a2acf4325a7213211dd7fcb4f397981ca695 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 13 Jun 2025 11:13:00 -0700 Subject: [PATCH 102/156] Fix complex power and print (#2286) * fix complex power and print * fix complex matmul shape --- mlx/backend/cuda/device/binary_ops.cuh | 7 ++++ mlx/backend/metal/kernels/binary_ops.h | 7 ++++ mlx/ops.cpp | 51 +++++++++++++------------- mlx/utils.cpp | 7 +++- python/tests/test_blas.py | 10 +++++ python/tests/test_ops.py | 7 ++++ 6 files changed, 63 insertions(+), 26 deletions(-) diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index 4779a6f33..b96a7f9cc 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -194,6 +194,13 @@ struct Power { } return res; } else if constexpr (cuda::std::is_same_v) { + if (base.y == 0 && base.x == 0) { + if (isnan(exp.x) || isnan(exp.y)) { + auto nan = cuda::std::numeric_limits::quiet_NaN(); + return make_cuFloatComplex(nan, nan); + } + return make_cuFloatComplex(0.0, 0.0); + } auto x_theta = atan2f(base.y, base.x); auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y); auto mag = expf(exp.x * x_ln_r - exp.y * x_theta); diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h index 4aaf2b4da..f4deb860e 100644 --- a/mlx/backend/metal/kernels/binary_ops.h +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -235,6 +235,13 @@ struct Power { template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = 
metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9602f667a..2b861428f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2847,21 +2847,6 @@ array matmul( "[matmul] Got 0 dimension input. Inputs must " "have at least one dimension."); } - if (a.ndim() == 1) { - // Insert a singleton dim in the beginning - a = expand_dims(a, 0, s); - } - if (b.ndim() == 1) { - // Insert a singleton dim at the end - b = expand_dims(b, 1, s); - } - if (a.shape(-1) != b.shape(-2)) { - std::ostringstream msg; - msg << "[matmul] Last dimension of first input with shape " << a.shape() - << " must match second to last dimension of" - << " second input with shape " << b.shape() << "."; - throw std::invalid_argument(msg.str()); - } // complex matmul using Karatsuba's Algorithm if (a.dtype() == complex64 || b.dtype() == complex64) { @@ -2883,6 +2868,22 @@ array matmul( c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); } + if (a.ndim() == 1) { + // Insert a singleton dim in the beginning + a = expand_dims(a, 0, s); + } + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = expand_dims(b, 1, s); + } + if (a.shape(-1) != b.shape(-2)) { + std::ostringstream msg; + msg << "[matmul] Last dimension of first input with shape " << a.shape() + << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); @@ -4240,6 +4241,16 @@ array addmm( "have at least one dimension."); } + // Type promotion + auto out_type = result_type(a, b, c); + + if (out_type == complex64) { + return add( + multiply(matmul(a, b, s), array(alpha), s), + multiply(array(beta), c, s), + s); + } + if (a.ndim() == 1) { // Insert a singleton dim in the beginning a = expand_dims(a, 0, s); @@ -4257,16 +4268,6 @@ array addmm( throw std::invalid_argument(msg.str()); } - // Type promotion - auto out_type = result_type(a, b, c); - - if (out_type == complex64) { - return add( - multiply(matmul(a, b, s), array(alpha), s), - multiply(array(beta), c, s), - s); - } - if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 0b2e66352..61b9da3a2 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) { os << val; } inline void PrintFormatter::print(std::ostream& os, complex64_t val) { - os << val; + os << val.real(); + if (val.imag() >= 0 || std::isnan(val.imag())) { + os << "+" << val.imag() << "j"; + } else { + os << "-" << -val.imag() << "j"; + } } PrintFormatter& get_global_formatter() { diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 8c7a97ba8..2762df8f8 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase): c_np = np.matmul(np.array(a).T, b) self.assertTrue(np.allclose(c, c_np)) + # Check shapes + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + self.assertEqual((a @ b).shape, (2,)) + + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + c = mx.random.normal((2,)) + self.assertEqual(mx.addmm(c, a, b).shape, (2,)) + def test_complex_gemm(self): M = 16 K = 50 diff --git 
a/python/tests/test_ops.py b/python/tests/test_ops.py index f3d48dda3..7c4f3f8e3 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase): ) self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) + def test_complex_power(self): + out = mx.power(mx.array(0j), 2) + self.assertEqual(out.item(), 0j) + + out = mx.power(mx.array(0j), float("nan")) + self.assertTrue(mx.isnan(out)) + class TestBroadcast(mlx_tests.MLXTestCase): def test_broadcast_shapes(self): From 6871e2eeb7f058d4a4dd1c76cb0464fd51aa11ba Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 13 Jun 2025 19:21:46 -0700 Subject: [PATCH 103/156] fix cuda jit (#2287) --- mlx/backend/cuda/jit_module.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index b8be103cc..8a033523c 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) { } } -#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" +#define INCLUDE_PREFIX "mlx/backend/cuda/device/" constexpr const char* g_include_names[] = { INCLUDE_PREFIX "atomic_ops.cuh", From a6d780154f2fe79e893045659d17fbace243802a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 13 Jun 2025 22:10:46 -0700 Subject: [PATCH 104/156] fix cuda gemm for bf16 (#2288) --- mlx/backend/cuda/matmul.cpp | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 89247fd3e..9930c75b8 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -44,9 +44,12 @@ class MatMul { int64_t b_batch_stride) { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; - auto type = dtype_to_cuda_type(dtype); + auto scale_type = dtype_to_cuda_type(dtype); + if (dtype == bfloat16) { + scale_type = CUDA_R_32F; + } CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( - &matmul_desc_, dtype_to_compute_type(dtype), type)); + &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, @@ -65,6 +68,7 @@ class MatMul { &op, sizeof(cublasOperation_t))); + auto type = dtype_to_cuda_type(dtype); a_desc_ = create_matrix_layout( type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); b_desc_ = create_matrix_layout( @@ -187,15 +191,10 @@ class MatMul { private: cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { - case uint8: - case uint16: - case int8: - case int16: - case int32: - return CUBLAS_COMPUTE_32I; case float16: - case bfloat16: return CUBLAS_COMPUTE_16F; + case bfloat16: + return CUBLAS_COMPUTE_32F; case float32: return CUBLAS_COMPUTE_32F; case float64: @@ -209,16 +208,6 @@ class MatMul { cudaDataType_t dtype_to_cuda_type(Dtype dtype) { switch (dtype) { - case uint8: - return CUDA_R_8U; - case uint16: - return CUDA_R_16U; - case int8: - return CUDA_R_8I; - case int16: - return CUDA_R_16I; - case int32: - return CUDA_R_32I; case float16: return CUDA_R_16F; case bfloat16: From a14aaa7c9d2b1cafbc73e61f993ebbeb81faf4d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 14 Jun 2025 17:54:00 -0700 Subject: [PATCH 105/156] Fix cuda arg reduce (#2291) --- mlx/backend/cuda/arg_reduce.cu | 5 ++--- mlx/backend/cuda/matmul.cpp | 8 +++++--- mlx/utils.h | 5 +++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git 
a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 7dbd91e46..c8a5a962a 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -1,5 +1,4 @@ // Copyright © 2025 Apple Inc. - #include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/iterators/strided_iterator.cuh" @@ -113,7 +112,7 @@ __global__ void arg_reduce_general( for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { T vals[N_READS]; - auto tid = r * BLOCK_DIM + block.thread_index().z; + auto tid = r * BLOCK_DIM + block.thread_index().x; cub::LoadDirectBlocked( tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); best = op.reduce_many(best, vals, tid * N_READS); @@ -158,7 +157,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { constexpr uint32_t N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims{1, 1, BLOCK_DIM}; + dim3 block_dims{BLOCK_DIM, 1, 1}; auto kernel = &cu::arg_reduce_general< InType, cu::ArgMax, diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 9930c75b8..5a5e6182e 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -5,6 +5,7 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" +#include "mlx/utils.h" #include #include @@ -45,7 +46,7 @@ class MatMul { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; auto scale_type = dtype_to_cuda_type(dtype); - if (dtype == bfloat16) { + if (dtype == bfloat16 || dtype == float16) { scale_type = CUDA_R_32F; } CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( @@ -192,11 +193,12 @@ class MatMul { cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { case float16: - return CUBLAS_COMPUTE_16F; + return CUBLAS_COMPUTE_32F; case bfloat16: return CUBLAS_COMPUTE_32F; case float32: - return CUBLAS_COMPUTE_32F; + return mlx::core::env::enable_tf32() ? 
CUBLAS_COMPUTE_32F_FAST_TF32 + : CUBLAS_COMPUTE_32F; case float64: case complex64: return CUBLAS_COMPUTE_64F; diff --git a/mlx/utils.h b/mlx/utils.h index f0aa7c2de..f16bf0468 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -149,6 +149,11 @@ inline bool metal_fast_synch() { return metal_fast_synch; } +inline bool enable_tf32() { + static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1); + return enable_tf32_; +} + } // namespace env } // namespace mlx::core From 580776559be625d2149ae13bab61bf325864cc1b Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 15 Jun 2025 06:08:07 -0700 Subject: [PATCH 106/156] RoPE for CUDA (#2293) * First working CUDA rope * Fix random --- mlx/backend/common/utils.cpp | 10 + mlx/backend/common/utils.h | 3 + mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/kernel_utils.cu | 7 + mlx/backend/cuda/kernel_utils.cuh | 1 + mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/random.cu | 64 ++--- mlx/backend/cuda/rope.cu | 385 ++++++++++++++++++++++++++++++ 8 files changed, 443 insertions(+), 29 deletions(-) create mode 100644 mlx/backend/cuda/rope.cu diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 08df53a8e..457ecb7f7 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -209,4 +209,14 @@ Dims get_2d_grid_dims_common( static_cast(grid_x), static_cast(grid_y), 1); } +std::pair get_grid_and_block_common(int dim0, int dim1, int dim2) { + auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2); + auto gx = (dim0 + bx - 1) / bx; + auto gy = (dim1 + by - 1) / by; + auto gz = (dim2 + bz - 1) / bz; + + return std::make_pair( + std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); +} + } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 40bc3efe4..114878846 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -95,6 +95,9 @@ Dims get_2d_grid_dims_common( const Strides& strides, size_t divisor); +// Get both the block and a grid of blocks that covers dim0, dim1 and dim2. 
+std::pair get_grid_and_block_common(int dim0, int dim1, int dim2); + struct ContiguousIterator { inline void step() { int dims = shape_.size(); diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 7cc74353a..d96bb8812 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -32,6 +32,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu + ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu index 575af7cf6..7b87aa5b0 100644 --- a/mlx/backend/cuda/kernel_utils.cu +++ b/mlx/backend/cuda/kernel_utils.cu @@ -23,4 +23,11 @@ dim3 get_2d_grid_dims( return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); } +std::pair get_grid_and_block(int dim0, int dim1, int dim2) { + auto [grid, block] = get_grid_and_block_common(dim0, dim1, dim2); + auto [gx, gy, gz] = grid; + auto [bx, by, bz] = block; + return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz)); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 7e957bbbd..84392a1ec 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -121,6 +121,7 @@ dim3 get_2d_grid_dims( const Shape& shape, const Strides& strides, size_t divisor); +std::pair get_grid_and_block(int dim0, int dim1, int dim2); // Return a block size that achieves maximum potential occupancy for kernel. template diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 0c4d3a8aa..c2362bea2 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -94,7 +94,6 @@ NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { -NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index d2b1b7dd5..0cb550d56 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/primitives.h" +#include #include #include @@ -12,6 +13,8 @@ namespace mlx::core { namespace cu { +namespace cg = cooperative_groups; + __constant__ constexpr uint32_t rotations[2][4] = { {13, 15, 26, 6}, {17, 29, 16, 24}}; @@ -47,27 +50,28 @@ __global__ void rbitsc( dim3 grid_dims, bool odd, uint32_t bytes_per_key) { - uint2 index{ - blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y}; - if (index.x >= grid_dims.x || index.y >= grid_dims.y) { + auto grid = cg::this_grid(); + uint thread_index = grid.thread_rank(); + uint index_x = thread_index % grid_dims.x; + uint index_y = thread_index / grid_dims.x; + if (index_x >= grid_dims.x || index_y >= grid_dims.y) { return; } - auto kidx = 2 * index.x; + auto kidx = 2 * index_x; auto key = uint2{keys[kidx], keys[kidx + 1]}; auto half_size = grid_dims.y - odd; - out += index.x * bytes_per_key; - bool drop_last = odd && (index.y == half_size); + out += index_x * bytes_per_key; + bool drop_last = odd && (index_y == half_size); auto bits = threefry2x32_hash( - key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); - size_t idx = size_t(index.y) << 2; + key, uint2{index_y, drop_last ? 
0 : index_y + grid_dims.y}); + size_t idx = size_t(index_y) << 2; for (int i = 0; i < 4; ++i) { out[idx + i] = bits.bytes[0][i]; } if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { out[idx + i] = bits.bytes[1][i]; @@ -89,30 +93,31 @@ __global__ void rbits( int32_t ndim, const __grid_constant__ Shape key_shape, const __grid_constant__ Strides key_strides) { - uint2 index{ - blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y}; - if (index.x >= grid_dims.x || index.y >= grid_dims.y) { + auto grid = cg::this_grid(); + uint thread_index = grid.thread_rank(); + uint index_x = thread_index % grid_dims.x; + uint index_y = thread_index / grid_dims.x; + if (index_x >= grid_dims.x || index_y >= grid_dims.y) { return; } - auto kidx = 2 * index.x; + auto kidx = 2 * index_x; auto k1_elem = elem_to_loc(kidx, key_shape.data(), key_strides.data(), ndim); auto k2_elem = elem_to_loc(kidx + 1, key_shape.data(), key_strides.data(), ndim); auto key = uint2{keys[k1_elem], keys[k2_elem]}; auto half_size = grid_dims.y - odd; - out += size_t(index.x) * bytes_per_key; - bool drop_last = odd && (index.y == half_size); + out += size_t(index_x) * bytes_per_key; + bool drop_last = odd && (index_y == half_size); auto bits = threefry2x32_hash( - key, uint2{index.y, drop_last ? 0 : index.y + grid_dims.y}); - size_t idx = size_t(index.y) << 2; + key, uint2{index_y, drop_last ? 0 : index_y + grid_dims.y}); + size_t idx = size_t(index_y) << 2; for (int i = 0; i < 4; ++i) { out[idx + i] = bits.bytes[0][i]; } if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dims.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims.y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { int edge_bytes = (bytes_per_key % 4); for (int i = 0; i < edge_bytes; ++i) { out[idx + i] = bits.bytes[1][i]; @@ -153,19 +158,22 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dim3 grid_dims{num_keys, half_size + odd}; - dim3 block_dims = get_block_dims(grid_dims.x, grid_dims.y, 1); - dim3 num_blocks{ - cuda::ceil_div(grid_dims.x, block_dims.x), - cuda::ceil_div(grid_dims.y, block_dims.y)}; + int64_t total = grid_dims.x * grid_dims.y; + int32_t threads_y = 1; + while ((total / threads_y) >= (1U << 31)) { + threads_y *= 2; + } + int32_t threads_x = cuda::ceil_div(total, threads_y); + auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); if (keys.flags().row_contiguous) { - cu::rbitsc<<>>( + cu::rbitsc<<>>( keys.data(), out.data(), grid_dims, odd, bytes_per_key); } else { - cu::rbits<<>>( + cu::rbits<<>>( keys.data(), out.data(), grid_dims, diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu new file mode 100644 index 000000000..1d8307811 --- /dev/null +++ b/mlx/backend/cuda/rope.cu @@ -0,0 +1,385 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace cu { + +template +__device__ void rope_single_impl( + const T* in, + T* out, + int32_t offset, + float inv_freq, + float scale, + int64_t stride, + uint2 pos, + uint2 dims) { + float L = scale * static_cast(offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = cos(theta); + float sintheta = sin(theta); + + // Compute the input and output indices + uint index_1, index_2; + if (traditional) { + index_1 = 2 * pos.x + pos.y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos.x + pos.y * stride; + index_2 = index_1 + dims.x; + } + + // Read and write the output + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +template +__global__ void rope_single( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + int64_t stride, + uint2 dims) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2(-d * base); + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +template +__global__ void rope_single_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + int64_t stride, + uint2 dims, + int64_t freq_stride) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float inv_freq = 1.0 / freqs[freq_stride * pos.x]; + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +template +__device__ void rope_impl( + const T* in, + T* out, + int offset, + float inv_freq, + float scale, + const cuda::std::array strides, + const cuda::std::array out_strides, + int64_t n_batch, + uint3 pos, + uint3 dims) { + float L = scale * static_cast(pos.y + offset); + + // Compute costheta, sintheta + float theta = L * inv_freq; + float costheta = cos(theta); + float sintheta = sin(theta); + + // Compute the input and output indices + size_t in_index_1, in_index_2; + size_t out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + 1; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + N * pos.z * out_strides[0]; + out_index_2 = out_index_1 + dims.x * out_strides[2]; + in_index_1 = + pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + in_index_2 = in_index_1 + dims.x * strides[2]; + } + for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + 
rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); + in_index_1 += strides[0]; + in_index_2 += strides[0]; + out_index_1 += out_strides[0]; + out_index_2 += out_strides[0]; + } +} + +template +__global__ void rope( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + const __grid_constant__ cuda::std::array strides, + const __grid_constant__ cuda::std::array out_strides, + int64_t n_batch, + uint3 dims) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2(-d * base); + rope_impl( + in, + out, + *offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + dims); +} + +template +__global__ void rope_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + float base, + const __grid_constant__ cuda::std::array strides, + const __grid_constant__ cuda::std::array out_strides, + int64_t n_batch, + uint3 dims, + int64_t freq_stride) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float inv_freq = 1.0 / freqs[freq_stride * pos.x]; + rope_impl( + in, + out, + *offset, + inv_freq, + scale, + strides, + out_strides, + n_batch, + pos, + dims); +} + +} // namespace cu + +namespace fast { + +bool RoPE::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void RoPE::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("RoPE::eval_gpu"); + + auto& s = stream(); + auto& in = inputs[0]; + auto& offset = inputs[1]; + auto& out = outputs[0]; + + if (in.ndim() < 3) { + throw std::runtime_error("[RoPE] Input must have at least 3 dimensions"); + } + + cuda::std::array strides; + cuda::std::array out_strides; + bool donated = false; + int ndim = in.ndim(); + int dispatch_ndim = in.ndim(); + while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { + dispatch_ndim--; + } + size_t mat_size = in.shape(-2) * in.shape(-1); + + // We apply rope to less that the whole vector so copy to output and then + // apply in-place. + if (dims_ < in.shape(-1)) { + donated = true; + auto ctype = + (in.flags().row_contiguous) ? 
CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + + // Either copy or apply in-place + else if (in.flags().row_contiguous) { + if (in.is_donatable()) { + donated = true; + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + strides[0] = mat_size; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else if (dispatch_ndim == 3) { + // Handle non-contiguous 3D inputs + out.set_data(allocator::malloc(out.nbytes())); + strides[0] = in.strides()[ndim - 3]; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else { + // Copy non-contiguous > 3D inputs into the output and treat + // input as donated + donated = true; + copy_gpu(in, out, CopyType::General, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; + + // Some flags to help us dispatch below + bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); + bool with_freqs = inputs.size() == 3; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(donated ? out : in); + encoder.set_input_array(offset); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, { + using DataType = cuda_type_t; + MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { + MLX_SWITCH_BOOL(forward_, FORWARD, { + if (single && !with_freqs) { + auto kernel = cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + kernel<<>>( + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = cu::rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + kernel<<>>( + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = cu::rope_freqs; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + kernel<<>>( + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + kernel<<>>( + (donated ? 
out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims); + } + }); + }); + }); + }); +} + +} // namespace fast + +} // namespace mlx::core From 4fda5fbdf94e70eb467b8f0a4900bfdf5f8ce108 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 15 Jun 2025 10:56:48 -0700 Subject: [PATCH 107/156] add python testing for cuda with ability to skip list of tests (#2295) --- .circleci/config.yml | 1 + python/tests/__main__.py | 5 + python/tests/cuda_skip.py | 143 ++++++++++++++++++++++++++ python/tests/mlx_tests.py | 36 +++++++ python/tests/ring_test_distributed.py | 2 +- python/tests/test_array.py | 2 +- python/tests/test_autograd.py | 2 +- python/tests/test_bf16.py | 2 +- python/tests/test_blas.py | 2 +- python/tests/test_compile.py | 2 +- python/tests/test_constants.py | 2 +- python/tests/test_conv.py | 2 +- python/tests/test_conv_transpose.py | 2 +- python/tests/test_device.py | 4 +- python/tests/test_double.py | 2 +- python/tests/test_einsum.py | 2 +- python/tests/test_eval.py | 4 +- python/tests/test_export_import.py | 2 +- python/tests/test_fast.py | 2 +- python/tests/test_fast_sdpa.py | 4 +- python/tests/test_fft.py | 2 +- python/tests/test_graph.py | 2 +- python/tests/test_init.py | 2 +- python/tests/test_linalg.py | 2 +- python/tests/test_load.py | 2 +- python/tests/test_losses.py | 2 +- python/tests/test_memory.py | 2 +- python/tests/test_nn.py | 2 +- python/tests/test_ops.py | 2 +- python/tests/test_optimizers.py | 2 +- python/tests/test_quantized.py | 2 +- python/tests/test_random.py | 2 +- python/tests/test_reduce.py | 2 +- python/tests/test_tree.py | 2 +- python/tests/test_upsample.py | 2 +- python/tests/test_vmap.py | 2 +- 36 files changed, 220 insertions(+), 35 deletions(-) create mode 100644 python/tests/__main__.py create mode 100644 python/tests/cuda_skip.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 808242f9b..0ea9303db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -234,6 +234,7 @@ jobs: command: | source env/bin/activate LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v + LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v build_release: parameters: diff --git a/python/tests/__main__.py b/python/tests/__main__.py new file mode 100644 index 000000000..5230bd535 --- /dev/null +++ b/python/tests/__main__.py @@ -0,0 +1,5 @@ +from . 
import mlx_tests + +__unittest = True + +mlx_tests.MLXTestRunner(module=None) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py new file mode 100644 index 000000000..cda396dcb --- /dev/null +++ b/python/tests/cuda_skip.py @@ -0,0 +1,143 @@ +cuda_skip = { + "TestArray.test_api", + "TestArray.test_setitem", + "TestAutograd.test_cumprod_grad", + "TestAutograd.test_slice_grads", + "TestAutograd.test_split_against_slice", + "TestAutograd.test_stop_gradient", + "TestAutograd.test_topk_grad", + "TestAutograd.test_update_state", + "TestAutograd.test_vjp", + "TestBF16.test_arg_reduction_ops", + "TestBF16.test_binary_ops", + "TestBF16.test_reduction_ops", + "TestBlas.test_block_masked_matmul", + "TestBlas.test_complex_gemm", + "TestBlas.test_gather_matmul", + "TestBlas.test_gather_matmul_grad", + "TestBlas.test_matmul_batched", + "TestBlas.test_matrix_vector_attn", + "TestCompile.test_compile_dynamic_dims", + "TestCompile.test_compile_inf", + "TestCompile.test_inf_constant", + "TestConv.test_1d_conv_with_2d", + "TestConv.test_asymmetric_padding", + "TestConv.test_basic_grad_shapes", + "TestConv.test_conv2d_unaligned_channels", + "TestConv.test_conv_1d_groups_flipped", + "TestConv.test_conv_general_flip_grad", + "TestConv.test_conv_groups_grad", + "TestConv.test_numpy_conv", + "TestConv.test_repeated_conv", + "TestConv.test_torch_conv_1D", + "TestConv.test_torch_conv_1D_grad", + "TestConv.test_torch_conv_2D", + "TestConv.test_torch_conv_2D_grad", + "TestConv.test_torch_conv_3D", + "TestConv.test_torch_conv_3D_grad", + "TestConv.test_torch_conv_depthwise", + "TestConv.test_torch_conv_general", + "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding", + "TestConvTranspose.test_torch_conv_transpose_1D", + "TestConvTranspose.test_torch_conv_transpose_1D_grad", + "TestConvTranspose.test_torch_conv_transpose_2D", + "TestConvTranspose.test_torch_conv_transpose_2D_grad", + "TestConvTranspose.test_torch_conv_transpose_2d_output_padding", + "TestConvTranspose.test_torch_conv_transpose_3D", + "TestConvTranspose.test_torch_conv_transpose_3D_grad", + "TestConvTranspose.test_torch_conv_transpose_3d_output_padding", + "TestEinsum.test_attention", + "TestEinsum.test_ellipses", + "TestEinsum.test_opt_einsum_test_cases", + "TestEval.test_multi_output_eval_during_transform", + "TestExportImport.test_export_conv", + "TestFast.test_rope_grad", + "TestFFT.test_fft", + "TestFFT.test_fft_big_powers_of_two", + "TestFFT.test_fft_contiguity", + "TestFFT.test_fft_exhaustive", + "TestFFT.test_fft_grads", + "TestFFT.test_fft_into_ifft", + "TestFFT.test_fft_large_numbers", + "TestFFT.test_fft_shared_mem", + "TestFFT.test_fftn", + "TestInit.test_orthogonal", + "TestLinalg.test_cholesky", + "TestLinalg.test_cholesky_inv", + "TestLinalg.test_eig", + "TestLinalg.test_eigh", + "TestLinalg.test_inverse", + "TestLinalg.test_lu", + "TestLinalg.test_lu_factor", + "TestLinalg.test_pseudo_inverse", + "TestLinalg.test_qr_factorization", + "TestLinalg.test_svd_decomposition", + "TestLinalg.test_tri_inverse", + "TestLoad.test_load_f8_e4m3", + "TestLosses.test_binary_cross_entropy", + "TestMemory.test_memory_info", + "TestLayers.test_conv1d", + "TestLayers.test_conv2d", + "TestLayers.test_elu", + "TestLayers.test_group_norm", + "TestLayers.test_hard_shrink", + "TestLayers.test_pooling", + "TestLayers.test_quantized_embedding", + "TestLayers.test_sin_pe", + "TestLayers.test_softshrink", + "TestLayers.test_upsample", + "TestOps.test_argpartition", + "TestOps.test_array_equal", + "TestOps.test_as_strided", + 
"TestOps.test_atleast_1d", + "TestOps.test_atleast_2d", + "TestOps.test_atleast_3d", + "TestOps.test_binary_ops", + "TestOps.test_bitwise_grad", + "TestOps.test_complex_ops", + "TestOps.test_divmod", + "TestOps.test_dynamic_slicing", + "TestOps.test_hadamard", + "TestOps.test_hadamard_grad_vmap", + "TestOps.test_irregular_binary_ops", + "TestOps.test_isfinite", + "TestOps.test_kron", + "TestOps.test_log", + "TestOps.test_log10", + "TestOps.test_log1p", + "TestOps.test_log2", + "TestOps.test_logaddexp", + "TestOps.test_logcumsumexp", + "TestOps.test_partition", + "TestOps.test_scans", + "TestOps.test_slice_update_reversed", + "TestOps.test_softmax", + "TestOps.test_sort", + "TestOps.test_tensordot", + "TestOps.test_tile", + "TestOps.test_view", + "TestQuantized.test_gather_matmul_grad", + "TestQuantized.test_gather_qmm", + "TestQuantized.test_gather_qmm_sorted", + "TestQuantized.test_non_multiples", + "TestQuantized.test_qmm", + "TestQuantized.test_qmm_jvp", + "TestQuantized.test_qmm_shapes", + "TestQuantized.test_qmm_vjp", + "TestQuantized.test_qmv", + "TestQuantized.test_quantize_dequantize", + "TestQuantized.test_qvm", + "TestQuantized.test_qvm_splitk", + "TestQuantized.test_small_matrix", + "TestQuantized.test_throw", + "TestQuantized.test_vjp_scales_biases", + "TestReduce.test_axis_permutation_sums", + "TestReduce.test_dtypes", + "TestReduce.test_expand_sums", + "TestReduce.test_many_reduction_axes", + "TestUpsample.test_torch_upsample", + "TestVmap.test_unary", + "TestVmap.test_vmap_conv", + "TestVmap.test_vmap_inverse", + "TestVmap.test_vmap_svd", +} diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index f446b5e67..65bd0e873 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -9,6 +9,42 @@ import mlx.core as mx import numpy as np +class MLXTestRunner(unittest.TestProgram): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def createTests(self, *args, **kwargs): + super().createTests(*args, **kwargs) + + # Asume CUDA backend in this case + device = os.getenv("DEVICE", None) + if device is not None: + device = getattr(mx, device) + else: + device = mx.default_device() + + if not (device == mx.gpu and not mx.metal.is_available()): + return + + from cuda_skip import cuda_skip + + filtered_suite = unittest.TestSuite() + + def filter_and_add(t): + if isinstance(t, unittest.TestSuite): + for sub_t in t: + filter_and_add(sub_t) + else: + t_id = ".".join(t.id().split(".")[-2:]) + if t_id in cuda_skip: + print(f"Skipping {t_id}") + else: + filtered_suite.addTest(t) + + filter_and_add(self.test) + self.test = filtered_suite + + class MLXTestCase(unittest.TestCase): @property def is_apple_silicon(self): diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index 77d45dbad..213f85274 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_array.py b/python/tests/test_array.py index c22e0a38f..c02b524b4 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index ec9d957ea..7973d79be 100644 --- a/python/tests/test_autograd.py +++ 
b/python/tests/test_autograd.py @@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 0b4b49919..2e4e2e0c3 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 2762df8f8..eb45df124 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index f5ce496cd..656553f9d 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_constants.py b/python/tests/test_constants.py index 104e7522d..cfd971fbe 100644 --- a/python/tests/test_constants.py +++ b/python/tests/test_constants.py @@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index c68315a5d..9be22e01b 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 2085e09d7..7289955ed 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 6793c98d1..d51028def 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase): # Restore device mx.set_default_device(device) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") def test_device_context(self): default = mx.default_device() diff = mx.cpu if default == mx.gpu else mx.gpu @@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_double.py b/python/tests/test_double.py index 10fce0db1..fccf3628f 100644 --- a/python/tests/test_double.py +++ b/python/tests/test_double.py @@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_einsum.py b/python/tests/test_einsum.py index 19ea8178e..a73ea3818 100644 --- a/python/tests/test_einsum.py +++ b/python/tests/test_einsum.py @@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index fcd424343..5d6daaec2 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase): post = 
mx.get_peak_memory() self.assertEqual(pre, post) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") def test_multistream_deadlock(self): s1 = mx.default_stream(mx.gpu) s2 = mx.new_stream(mx.gpu) @@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 0fd8bfd87..099be0cc0 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 59c2fc3ef..13c65de99 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 8f55d41e3..a929e91cf 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) - def test_sdpa_prommote_mask(self): + def test_sdpa_promote_mask(self): mask = mx.array(2.0, mx.bfloat16) D = 64 Nq = 4 @@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main(failfast=True) + mlx_tests.MLXTestRunner(failfast=True) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index df9d25edc..07ab62672 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_graph.py b/python/tests/test_graph.py index 4b8f6d86a..7c6a11412 100644 --- a/python/tests/test_graph.py +++ b/python/tests/test_graph.py @@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_init.py b/python/tests/test_init.py index 4b209736f..046a6e836 100644 --- a/python/tests/test_init.py +++ b/python/tests/test_init.py @@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 764d11f6e..81a43ed7f 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 341564dae..35f7016c5 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index 102ec857d..cbc657655 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git 
a/python/tests/test_memory.py b/python/tests/test_memory.py index 7343bdc91..08da7ccc6 100644 --- a/python/tests/test_memory.py +++ b/python/tests/test_memory.py @@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 13e31ad96..10bbe821e 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 7c4f3f8e3..02ada39b4 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3127,4 +3127,4 @@ class TestBroadcast(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index 4943fe662..e07fc8456 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 3c4f03e4d..f402bd444 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 2fc768651..551c32993 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 9012216ba..2b899c099 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main(failfast=True) + mlx_tests.MLXTestRunner(failfast=True) diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py index 63018fdae..bacf6e71d 100644 --- a/python/tests/test_tree.py +++ b/python/tests/test_tree.py @@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 86f41b6e8..631853cce 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 52f1a49ad..a88e59585 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() From c552ff2451f5ab6b6ff2916c53c5ab85b27943d3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Jun 2025 08:45:40 -0700 Subject: [PATCH 108/156] [CUDA] Fix back-end bugs and enable corresponding tests (#2296) * Fix some cuda back-end bugs and enable corresponding tests * more fixes * enable more tests * format --- docs/src/usage/indexing.rst | 10 +++++ mlx/backend/cuda/binary.cu | 20 ++++++--- mlx/backend/cuda/copy.cu | 3 +- mlx/backend/cuda/copy/copy.cuh 
| 21 ++++------ mlx/backend/cuda/copy/copy_contiguous.cu | 3 +- mlx/backend/cuda/copy/copy_general.cu | 8 ++-- mlx/backend/cuda/copy/copy_general_dynamic.cu | 8 ++-- mlx/backend/cuda/copy/copy_general_input.cu | 8 ++-- mlx/backend/cuda/device/cast_op.cuh | 12 ++++++ mlx/backend/cuda/device/unary_ops.cuh | 25 +++++++++-- mlx/backend/cuda/kernel_utils.cuh | 23 ++++++++-- mlx/backend/cuda/ternary.cu | 5 ++- mlx/backend/cuda/unary.cu | 11 +++-- python/tests/cuda_skip.py | 12 ------ python/tests/test_array.py | 2 +- python/tests/test_ops.py | 42 ++----------------- 16 files changed, 115 insertions(+), 98 deletions(-) diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst index c74e357fa..dcbc84c1b 100644 --- a/docs/src/usage/indexing.rst +++ b/docs/src/usage/indexing.rst @@ -107,6 +107,16 @@ same array: >>> a array([1, 2, 0], dtype=int32) + +Note, unlike NumPy, updates to the same location are nondeterministic: + +.. code-block:: shell + + >>> a = mx.array([1, 2, 3]) + >>> a[[0, 0]] = mx.array([4, 5]) + +The first element of ``a`` could be ``4`` or ``5``. + Transformations of functions which use in-place updates are allowed and work as expected. For example: diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 47efc44d2..d4df06f18 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -165,7 +165,7 @@ void binary_op_gpu_inplace( a.data(), b.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides)); @@ -178,7 +178,7 @@ void binary_op_gpu_inplace( a.data(), b.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides), @@ -196,8 +196,8 @@ void binary_op_gpu_inplace( } else if (bopt == BinaryOpType::VectorVector) { kernel = cu::binary_vv; } - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( a.data(), b.data(), @@ -264,7 +264,6 @@ BINARY_GPU(Add) BINARY_GPU(ArcTan2) BINARY_GPU(Divide) BINARY_GPU(Remainder) -BINARY_GPU(Equal) BINARY_GPU(Greater) BINARY_GPU(GreaterEqual) BINARY_GPU(Less) @@ -279,6 +278,17 @@ BINARY_GPU(NotEqual) BINARY_GPU(Power) BINARY_GPU(Subtract) +void Equal::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Equal::eval_gpu"); + auto& s = out.primitive().stream(); + auto op = get_primitive_string(this); + if (equal_nan_) { + binary_op_gpu(inputs, out, op, s); + } else { + binary_op_gpu(inputs, out, op, s); + } +} + void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); auto& s = out.primitive().stream(); diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 8649e1bf9..817860d0a 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -6,7 +6,7 @@ namespace mlx::core { void copy_gpu_inplace( - const array& in_, + const array& in, array& out, const Shape& shape, const Strides& strides_in, @@ -20,7 +20,6 @@ void copy_gpu_inplace( if (out.size() == 0) { return; } - const array& in = in_.data_shared_ptr() ? 
in_ : out; auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index 0c1eff774..789826507 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -10,20 +10,13 @@ namespace mlx::core { -#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ - using InType = cuda_type_t; \ - using OutType = cuda_type_t; \ - if constexpr (cu::CastOp::is_castable) { \ - __VA_ARGS__; \ - } else { \ - throw std::runtime_error(fmt::format( \ - "Can not copy data from dtype {} to {}.", \ - dtype_to_string(out.dtype()), \ - dtype_to_string(in.dtype()))); \ - } \ - }); \ +#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ + using InType = cuda_type_t; \ + using OutType = cuda_type_t; \ + __VA_ARGS__; \ + }); \ }) void copy_contiguous( diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index fa79f0604..5f4c9ca8f 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -43,7 +43,8 @@ void copy_contiguous( if (ctype == CopyType::Vector) { kernel = cu::copy_v; } - auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( in.data() + in_offset, out.data() + out_offset, diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 3c5b3bbb3..9f50c8a31 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -59,9 +59,9 @@ void copy_general( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -70,7 +70,7 @@ void copy_general( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out)); @@ -81,7 +81,7 @@ void copy_general( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index b9774662a..2e1cf4fba 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -65,9 +65,9 @@ void copy_general_dynamic( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -76,7 +76,7 @@ void copy_general_dynamic( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), 
const_param(shape), const_param(strides_in), const_param(strides_out), @@ -89,7 +89,7 @@ void copy_general_dynamic( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index 4f2784927..a3bb37e53 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -54,9 +54,9 @@ void copy_general_input( MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { @@ -65,7 +65,7 @@ void copy_general_input( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in)); }); @@ -75,7 +75,7 @@ void copy_general_input( kernel<<>>( in_ptr, out_ptr, - out.data_size(), + out.size(), const_param(shape), const_param(strides_in), ndim); diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh index 30b44d46f..f15270432 100644 --- a/mlx/backend/cuda/device/cast_op.cuh +++ b/mlx/backend/cuda/device/cast_op.cuh @@ -45,6 +45,18 @@ struct CastOp< } }; +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t>> { + static constexpr bool is_castable = true; + + __device__ SrcT operator()(SrcT x) { + return x; + } +}; + // Return an iterator that cast the value to DstT using CastOp. template __host__ __device__ auto make_cast_iterator(Iterator it) { diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index af7c30e64..efa9133b1 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -5,6 +5,8 @@ #include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/utils.cuh" +#include + namespace mlx::core::cu { struct Abs { @@ -183,21 +185,38 @@ struct Imag { struct Log { template __device__ T operator()(T x) { - return log(x); + if constexpr (cuda::std::is_same_v) { + auto r = log(cuCrealf(Abs{}(x))); + auto i = atan2f(cuCimagf(x), cuCrealf(x)); + return {r, i}; + } else { + return log(x); + } } }; struct Log2 { template __device__ T operator()(T x) { - return log2(x); + if constexpr (cuda::std::is_same_v) { + auto y = Log{}(x); + return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F}; + } else { + return log2(x); + } } }; struct Log10 { template __device__ T operator()(T x) { - return log10(x); + if constexpr (cuda::std::is_same_v) { + auto y = Log{}(x); + return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F}; + return y; + } else { + return log10(x); + } } }; diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index 84392a1ec..b1fe875bd 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -102,6 +102,11 @@ inline constexpr bool is_floating_v = cuda::std::is_same_v || cuda::std::is_same_v || cuda::std::is_same_v || cuda::std::is_same_v; +// Type traits for detecting complex or real floating point numbers. 
+template +inline constexpr bool is_inexact_v = + is_floating_v || cuda::std::is_same_v; + // Utility to copy data from vector to array in host. template inline cuda::std::array const_param(const std::vector& vec) { @@ -136,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) { template inline std::tuple get_launch_args( T kernel, - const array& arr, + size_t size, + const Shape& shape, + const Strides& strides, bool large, int work_per_thread = 1) { - size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + size_t nthreads = cuda::ceil_div(size, work_per_thread); uint block_dim = max_occupancy_block_dim(kernel); if (block_dim > nthreads) { block_dim = nthreads; } dim3 num_blocks; if (large) { - num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread); + num_blocks = get_2d_grid_dims(shape, strides, work_per_thread); num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim); } else { num_blocks.x = cuda::ceil_div(nthreads, block_dim); @@ -154,4 +161,14 @@ inline std::tuple get_launch_args( return std::make_tuple(num_blocks, block_dim); } +template +inline std::tuple get_launch_args( + T kernel, + const array& arr, + bool large, + int work_per_thread = 1) { + return get_launch_args( + kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index bb79d4249..02e46afc1 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -116,7 +116,7 @@ void ternary_op_gpu_inplace( b.data(), c.data(), out.data(), - out.data_size(), + out.size(), const_param(shape), const_param(a_strides), const_param(b_strides), @@ -142,7 +142,8 @@ void ternary_op_gpu_inplace( MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { using IdxT = std::conditional_t; auto kernel = cu::ternary_v; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE); + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index f9d373455..d2fa96381 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -28,11 +28,14 @@ constexpr bool supports_unary_op() { std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { + std::is_same_v || std::is_same_v || + std::is_same_v) { return std::is_same_v && is_floating_v; } + if (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; + } if (std::is_same_v) { return std::is_same_v && std::is_integral_v && !std::is_same_v; @@ -91,7 +94,7 @@ void unary_op_gpu_inplace( } else { auto [shape, strides] = collapse_contiguous_dims(in); auto [in_begin, in_end] = cu::make_general_iterators( - in_ptr, in.data_size(), shape, strides); + in_ptr, in.size(), shape, strides); thrust::transform(policy, in_begin, in_end, out_ptr, Op()); } } else { diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index cda396dcb..0072db192 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,6 +1,5 @@ cuda_skip = { "TestArray.test_api", - "TestArray.test_setitem", "TestAutograd.test_cumprod_grad", "TestAutograd.test_slice_grads", "TestAutograd.test_split_against_slice", @@ -51,7 +50,6 @@ cuda_skip = { "TestEinsum.test_opt_einsum_test_cases", 
"TestEval.test_multi_output_eval_during_transform", "TestExportImport.test_export_conv", - "TestFast.test_rope_grad", "TestFFT.test_fft", "TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_contiguity", @@ -89,9 +87,6 @@ cuda_skip = { "TestOps.test_argpartition", "TestOps.test_array_equal", "TestOps.test_as_strided", - "TestOps.test_atleast_1d", - "TestOps.test_atleast_2d", - "TestOps.test_atleast_3d", "TestOps.test_binary_ops", "TestOps.test_bitwise_grad", "TestOps.test_complex_ops", @@ -100,22 +95,16 @@ cuda_skip = { "TestOps.test_hadamard", "TestOps.test_hadamard_grad_vmap", "TestOps.test_irregular_binary_ops", - "TestOps.test_isfinite", "TestOps.test_kron", - "TestOps.test_log", - "TestOps.test_log10", "TestOps.test_log1p", - "TestOps.test_log2", "TestOps.test_logaddexp", "TestOps.test_logcumsumexp", "TestOps.test_partition", "TestOps.test_scans", - "TestOps.test_slice_update_reversed", "TestOps.test_softmax", "TestOps.test_sort", "TestOps.test_tensordot", "TestOps.test_tile", - "TestOps.test_view", "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", @@ -136,7 +125,6 @@ cuda_skip = { "TestReduce.test_expand_sums", "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", - "TestVmap.test_unary", "TestVmap.test_vmap_conv", "TestVmap.test_vmap_inverse", "TestVmap.test_vmap_svd", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index c02b524b4..3ab41bef7 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase): check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) check_slices( - np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1]) + np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1]) ) # Multiple slices diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 02ada39b4..8521d8f80 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqualArray(result, mx.array(expected)) def test_atleast_1d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_1d(mx.array(array)) np_res = np.atleast_1d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_atleast_2d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_2d(mx.array(array)) np_res = np.atleast_2d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) 
self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_atleast_3d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - # Test 1D input arrays = [ [1], @@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase): for i, array in enumerate(arrays): mx_res = mx.atleast_3d(mx.array(array)) np_res = np.atleast_3d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) - self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i])) def test_issubdtype(self): self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact)) From bc53f8293f88bd94ca38ef6642cb487e240165db Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Jun 2025 13:14:46 -0700 Subject: [PATCH 109/156] Cuda bug fixes 2 (#2298) * more bug fixes * more bug fixes * format --- mlx/backend/cuda/binary.cu | 14 +-- mlx/backend/cuda/compiled.cpp | 2 + mlx/backend/cuda/device/binary_ops.cuh | 22 +++++ mlx/backend/cuda/device/ternary_ops.cuh | 1 + mlx/backend/cuda/device/utils.cuh | 33 ++++++-- mlx/backend/cuda/indexing.cpp | 50 +++++------ mlx/backend/cuda/ternary.cu | 6 +- mlx/backend/cuda/unary.cu | 7 +- mlx/backend/cuda/utils.cpp | 3 + python/tests/cuda_skip.py | 108 +++++++++++------------- python/tests/test_losses.py | 4 +- 11 files changed, 143 insertions(+), 107 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index d4df06f18..e8e8a8988 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -101,10 +101,12 @@ constexpr bool supports_binary_op() { return std::is_same_v && std::is_same_v; } if (std::is_same_v) { - return std::is_same_v && - (is_floating_v || std::is_same_v); + return std::is_same_v && is_inexact_v; } - if (std::is_same_v || std::is_same_v) { + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { return std::is_same_v && is_floating_v; } if (std::is_same_v || std::is_same_v || @@ -150,10 +152,10 @@ void binary_op_gpu_inplace( auto [shape, strides] = collapse_contiguous_dims(a, b, out); auto& a_strides = strides[0]; auto& b_strides = strides[1]; - bool large = a.data_size() > UINT32_MAX || - b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = a.data_size() > INT32_MAX || + b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index a6b8223e0..1aa7ecb92 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -130,11 +130,13 @@ struct FusedKernelBuilder { constexpr const char* g_jit_includes = R"( #include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/ternary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include +#define inf cuda::std::numeric_limits::infinity() )"; void Compiled::eval_gpu( diff --git 
a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index b96a7f9cc..ca5ac35e6 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include #include @@ -122,6 +124,26 @@ struct LogAddExp { ? maxval : T(float(maxval) + log1p(expf(minval - maxval))); }; + + __device__ cuComplex operator()(cuComplex x, cuComplex y) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) || + isnan(cuCimagf(y))) { + return { + cuda::std::numeric_limits::quiet_NaN(), + cuda::std::numeric_limits::quiet_NaN()}; + } + constexpr float inf = cuda::std::numeric_limits::infinity(); + auto maxval = x > y ? x : y; + auto minval = x < y ? x : y; + if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf) + return maxval; + float m = exp(cuCrealf(minval) - cuCrealf(maxval)); + cuComplex dexp{ + m * cos(cuCimagf(minval) - cuCimagf(maxval)), + m * sin(cuCimagf(minval) - cuCimagf(maxval)), + }; + return maxval + log1p(dexp); + } }; struct Maximum { diff --git a/mlx/backend/cuda/device/ternary_ops.cuh b/mlx/backend/cuda/device/ternary_ops.cuh index d1d008ac5..441845471 100644 --- a/mlx/backend/cuda/device/ternary_ops.cuh +++ b/mlx/backend/cuda/device/ternary_ops.cuh @@ -1,4 +1,5 @@ // Copyright © 2025 Apple Inc. +#pragma once namespace mlx::core::cu { diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 6f9851c94..54d551992 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( template inline __host__ __device__ IdxT elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) { - IdxT loc = elem_to_loc_nd<3>(elem, shape, strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0; --i) { loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } @@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( const int64_t* a_strides, const int64_t* b_strides, int ndim) { - auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT a_loc = 0; + IdxT b_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; a_loc += dim_idx * a_strides[i]; b_loc += dim_idx * b_strides[i]; @@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( const int64_t* b_strides, const int64_t* c_strides, int ndim) { - auto [a_loc, b_loc, c_loc] = - elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT a_loc = 0; + IdxT b_loc = 0; + IdxT c_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; a_loc += dim_idx * a_strides[i]; b_loc += dim_idx * b_strides[i]; @@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> { } }; +inline __device__ cuComplex log1p(cuComplex in) { + float x = cuCrealf(in); + float y = cuCimagf(in); + float zabs = sqrt(x * x + y * y); + float theta = atan2f(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1pf(r), theta}; + } else { + auto z0 = sqrt((x + 1) * (x + 1) + y * y); + return {log(z0), theta}; + } +} + } // namespace 
mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 3603605c4..65a175fbd 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; - bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || - (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (src.size() > INT32_MAX) || (out.size() > INT32_MAX); uint32_t slice_size = std::accumulate( slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); @@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_cuda_type(idx_dtype), nidx, ndim, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } return std::make_pair(jit_source_gather, std::move(kernel_names)); @@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(out.size()); } else { - mod.append_arg(out.size()); + mod.append_arg(out.size()); } mod.append_ndim_arg(src.shape()); mod.append_ndim_arg(src.strides()); @@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_cuda_type(idx_dtype), nidx, idx_ndim, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; - bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || - (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (upd.size() > INT32_MAX) || (out.size() > INT32_MAX); - uint32_t upd_post_idx_size = std::accumulate( + int32_t upd_post_idx_size = std::accumulate( upd.shape().begin() + idx_ndim, upd.shape().end(), 1, - std::multiplies()); + std::multiplies()); const char* op = g_scatter_ops[reduce_type_]; std::string module_name = fmt::format( @@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op, nidx, ndim, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } return std::make_pair(jit_source_scatter, std::move(kernel_names)); @@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(upd.size()); } else { - mod.append_arg(upd.size()); + mod.append_arg(upd.size()); } mod.append_ndim_arg(upd.shape()); mod.append_ndim_arg(upd.strides()); @@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(upd_post_idx_size); } else { - mod.append_arg(upd_post_idx_size); + mod.append_arg(upd_post_idx_size); } mod.append_ndim_arg(out.shape()); mod.append_ndim_arg(out.strides()); @@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op, nidx, idx_ndim, - large ? "int64_t" : "uint32_t"); + large ? 
"int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { return; } - bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; std::string module_name = fmt::format( "gather_axis_{}_{}", @@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } } @@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { mod.append_arg(idx_size_axis); mod.append_arg(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); } mod.append_arg(remove_index(idx.shape(), axis_)); mod.append_arg(remove_index(src.strides(), axis_)); @@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { src.ndim() - 1, src.flags().row_contiguous, idx.flags().row_contiguous, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { return; } - bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; std::string module_name = fmt::format( @@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } } @@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { mod.append_arg(idx_size_axis); mod.append_arg(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); } mod.append_arg(remove_index(idx.shape(), axis_)); mod.append_arg(remove_index(upd.strides(), axis_)); @@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { idx.ndim() - 1, upd.flags().row_contiguous, idx.flags().row_contiguous, - large ? "int64_t" : "uint32_t"); + large ? 
"int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 02e46afc1..e33af3c80 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -101,10 +101,10 @@ void ternary_op_gpu_inplace( auto& a_strides = strides[0]; auto& b_strides = strides[1]; auto& c_strides = strides[2]; - bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX || - c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index d2fa96381..e45144eda 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -27,13 +27,12 @@ constexpr bool supports_unary_op() { std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { return std::is_same_v && is_floating_v; } if (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v) { return std::is_same_v && is_inexact_v; } if (std::is_same_v) { diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 2f5e2a4c8..4a3d8be30 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { if (dtype == bfloat16) { return "__nv_bfloat16"; } + if (dtype == complex64) { + return "cuComplex"; + } #define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ if (dtype == DTYPE) { \ return #CPP_TYPE; \ diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 0072db192..23c5fb19c 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,24 +1,50 @@ cuda_skip = { "TestArray.test_api", - "TestAutograd.test_cumprod_grad", - "TestAutograd.test_slice_grads", - "TestAutograd.test_split_against_slice", - "TestAutograd.test_stop_gradient", - "TestAutograd.test_topk_grad", "TestAutograd.test_update_state", - "TestAutograd.test_vjp", "TestBF16.test_arg_reduction_ops", - "TestBF16.test_binary_ops", "TestBF16.test_reduction_ops", - "TestBlas.test_block_masked_matmul", "TestBlas.test_complex_gemm", + "TestCompile.test_compile_dynamic_dims", + "TestEinsum.test_ellipses", + "TestEinsum.test_opt_einsum_test_cases", + "TestLoad.test_load_f8_e4m3", + "TestMemory.test_memory_info", + "TestLayers.test_group_norm", + "TestLayers.test_pooling", + "TestLayers.test_quantized_embedding", + "TestLayers.test_sin_pe", + "TestLayers.test_upsample", + "TestOps.test_array_equal", + "TestOps.test_complex_ops", + "TestOps.test_dynamic_slicing", + "TestOps.test_softmax", + "TestOps.test_sort", + "TestOps.test_tile", + "TestReduce.test_axis_permutation_sums", + "TestReduce.test_dtypes", + "TestReduce.test_expand_sums", + "TestReduce.test_many_reduction_axes", + "TestUpsample.test_torch_upsample", + # DivMod NYI + "TestOps.test_divmod", + "TestEval.test_multi_output_eval_during_transform", + # Partition NYI + "TestAutograd.test_topk_grad", + "TestOps.test_argpartition", + "TestOps.test_partition", + # Block masked matmul 
NYI + "TestBlas.test_block_masked_matmul", + # Gather matmul NYI "TestBlas.test_gather_matmul", "TestBlas.test_gather_matmul_grad", - "TestBlas.test_matmul_batched", - "TestBlas.test_matrix_vector_attn", - "TestCompile.test_compile_dynamic_dims", - "TestCompile.test_compile_inf", - "TestCompile.test_inf_constant", + # Scan NYI + "TestAutograd.test_cumprod_grad", + "TestOps.test_scans", + "TestOps.test_logcumsumexp", + # Hadamard NYI + "TestOps.test_hadamard", + "TestOps.test_hadamard_grad_vmap", + # Convolutions NYI "TestConv.test_1d_conv_with_2d", "TestConv.test_asymmetric_padding", "TestConv.test_basic_grad_shapes", @@ -45,11 +71,11 @@ cuda_skip = { "TestConvTranspose.test_torch_conv_transpose_3D", "TestConvTranspose.test_torch_conv_transpose_3D_grad", "TestConvTranspose.test_torch_conv_transpose_3d_output_padding", - "TestEinsum.test_attention", - "TestEinsum.test_ellipses", - "TestEinsum.test_opt_einsum_test_cases", - "TestEval.test_multi_output_eval_during_transform", "TestExportImport.test_export_conv", + "TestLayers.test_conv1d", + "TestLayers.test_conv2d", + "TestVmap.test_vmap_conv", + # FFTs NYI "TestFFT.test_fft", "TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_contiguity", @@ -59,52 +85,22 @@ cuda_skip = { "TestFFT.test_fft_large_numbers", "TestFFT.test_fft_shared_mem", "TestFFT.test_fftn", - "TestInit.test_orthogonal", + # Lapack ops NYI "TestLinalg.test_cholesky", "TestLinalg.test_cholesky_inv", "TestLinalg.test_eig", "TestLinalg.test_eigh", "TestLinalg.test_inverse", + "TestVmap.test_vmap_inverse", "TestLinalg.test_lu", "TestLinalg.test_lu_factor", "TestLinalg.test_pseudo_inverse", "TestLinalg.test_qr_factorization", + "TestInit.test_orthogonal", "TestLinalg.test_svd_decomposition", + "TestVmap.test_vmap_svd", "TestLinalg.test_tri_inverse", - "TestLoad.test_load_f8_e4m3", - "TestLosses.test_binary_cross_entropy", - "TestMemory.test_memory_info", - "TestLayers.test_conv1d", - "TestLayers.test_conv2d", - "TestLayers.test_elu", - "TestLayers.test_group_norm", - "TestLayers.test_hard_shrink", - "TestLayers.test_pooling", - "TestLayers.test_quantized_embedding", - "TestLayers.test_sin_pe", - "TestLayers.test_softshrink", - "TestLayers.test_upsample", - "TestOps.test_argpartition", - "TestOps.test_array_equal", - "TestOps.test_as_strided", - "TestOps.test_binary_ops", - "TestOps.test_bitwise_grad", - "TestOps.test_complex_ops", - "TestOps.test_divmod", - "TestOps.test_dynamic_slicing", - "TestOps.test_hadamard", - "TestOps.test_hadamard_grad_vmap", - "TestOps.test_irregular_binary_ops", - "TestOps.test_kron", - "TestOps.test_log1p", - "TestOps.test_logaddexp", - "TestOps.test_logcumsumexp", - "TestOps.test_partition", - "TestOps.test_scans", - "TestOps.test_softmax", - "TestOps.test_sort", - "TestOps.test_tensordot", - "TestOps.test_tile", + # Quantization NYI "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", @@ -120,12 +116,4 @@ cuda_skip = { "TestQuantized.test_small_matrix", "TestQuantized.test_throw", "TestQuantized.test_vjp_scales_biases", - "TestReduce.test_axis_permutation_sums", - "TestReduce.test_dtypes", - "TestReduce.test_expand_sums", - "TestReduce.test_many_reduction_axes", - "TestUpsample.test_torch_upsample", - "TestVmap.test_vmap_conv", - "TestVmap.test_vmap_inverse", - "TestVmap.test_vmap_svd", } diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index cbc657655..2ef1fa36c 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -83,14 +83,14 @@ class 
TestLosses(mlx_tests.MLXTestCase): logits, targets, reduction="mean" ) expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' losses_sum = nn.losses.binary_cross_entropy( logits, targets, reduction="sum" ) expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) # With weights, no label smoothing weights = mx.array([1.0, 2.0, 1.0, 2.0]) From b8022c578a50010a785b8aaa9f4eeb8b3fe0891c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Jun 2025 18:49:32 -0700 Subject: [PATCH 110/156] divmod, partition, sort fixes (#2302) --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/binary.cu | 29 +-- mlx/backend/cuda/binary_two.cu | 248 +++++++++++++++++++++++++ mlx/backend/cuda/device/binary_ops.cuh | 4 +- mlx/backend/cuda/device/config.h | 2 +- mlx/backend/cuda/primitives.cu | 3 - mlx/backend/cuda/sort.cu | 21 ++- python/tests/cuda_skip.py | 12 -- 8 files changed, 271 insertions(+), 49 deletions(-) create mode 100644 mlx/backend/cuda/binary_two.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index d96bb8812..ad979a13f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu + ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index e8e8a8988..9c437cde9 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -125,13 +125,12 @@ constexpr bool supports_binary_op() { template void binary_op_gpu_inplace( const std::vector& inputs, - std::vector& outputs, + array& out, std::string_view op, const Stream& s) { assert(inputs.size() > 1); const auto& a = inputs[0]; const auto& b = inputs[1]; - auto& out = outputs[0]; if (out.size() == 0) { return; } @@ -146,7 +145,6 @@ void binary_op_gpu_inplace( if constexpr (cu::supports_binary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); if (bopt == BinaryOpType::General) { auto [shape, strides] = collapse_contiguous_dims(a, b, out); @@ -219,20 +217,6 @@ void binary_op_gpu_inplace( }); } -template -void binary_op_gpu( - const std::vector& inputs, - std::vector& outputs, - std::string_view op, - const Stream& s) { - auto& a = inputs[0]; - auto& b = inputs[1]; - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, outputs[0], bopt); - set_binary_op_output_data(a, b, outputs[1], bopt); - binary_op_gpu_inplace(inputs, outputs, op, s); -} - template void binary_op_gpu( const std::vector& inputs, @@ -243,8 +227,7 @@ void binary_op_gpu( auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, out, bopt); - std::vector outputs{out}; - binary_op_gpu_inplace(inputs, outputs, op, s); + binary_op_gpu_inplace(inputs, out, op, s); } #define BINARY_GPU(func) \ @@ -254,14 +237,6 @@ void binary_op_gpu( binary_op_gpu(inputs, out, get_primitive_string(this), s); \ } -#define BINARY_GPU_MULTI(func) \ - void func::eval_gpu( \ - const std::vector& inputs, std::vector& outputs) { \ - nvtx3::scoped_range r(#func "::eval_gpu"); \ - auto& s = 
outputs[0].primitive().stream(); \ - binary_op_gpu(inputs, outputs, get_primitive_string(this), s); \ - } - BINARY_GPU(Add) BINARY_GPU(ArcTan2) BINARY_GPU(Divide) diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu new file mode 100644 index 000000000..3047e39f0 --- /dev/null +++ b/mlx/backend/cuda/binary_two.cu @@ -0,0 +1,248 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void +binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[0], b[0]); + out_a[0] = out[0]; + out_b[0] = out[1]; + } +} + +template +__global__ void +binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[0], b[index]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void +binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[index], b[0]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void +binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[index], b[index]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void binary_g_nd( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + auto out = Op{}(a[a_idx], b[b_idx]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + auto out = Op{}(a[a_idx], b[b_idx]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +constexpr bool supports_binary_op() { + if (std::is_same_v) { + return std::is_same_v && + (std::is_integral_v || is_floating_v); + } + return false; +} + +} // namespace cu + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out_a = outputs[0]; + auto& out_b = outputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out_a, bopt); + set_binary_op_output_data(a, b, out_b, bopt); + + if (out_a.size() == 0) { + return; + } + + auto& encoder = 
cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out_a); + encoder.set_output_array(out_b); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, { + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + auto [shape, strides] = collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + bool large = a.data_size() > INT32_MAX || + b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = + &cu::binary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); + } else { + MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out_a.data_size(), + out_a.shape(), + out_a.strides(), + LARGE); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.data_size()); + }); + } + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out_a.dtype()))); + } + }); + }); + }); +} + +template +void binary_op_gpu( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + binary_op_gpu_inplace(inputs, outputs, op, s); +} + +void DivMod::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("DivMod::eval_gpu"); + auto& s = outputs[0].primitive().stream(); + binary_op_gpu(inputs, outputs, get_primitive_string(this), s); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index ca5ac35e6..dc4f8e7bb 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -22,7 +22,7 @@ struct FloorDivide { if constexpr (cuda::std::is_integral_v) { return x / y; } else { - return trunc(x / y); + return truncf(x / y); } } }; @@ -132,7 +132,7 @@ struct LogAddExp { cuda::std::numeric_limits::quiet_NaN(), cuda::std::numeric_limits::quiet_NaN()}; } - constexpr float inf = cuda::std::numeric_limits::infinity(); + float inf = cuda::std::numeric_limits::infinity(); auto maxval = x > y ? 
x : y; auto minval = x < y ? x : y; if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf) diff --git a/mlx/backend/cuda/device/config.h b/mlx/backend/cuda/device/config.h index 0933cc8b5..5a3402905 100644 --- a/mlx/backend/cuda/device/config.h +++ b/mlx/backend/cuda/device/config.h @@ -5,7 +5,7 @@ #pragma once // The maximum dimensions of shape/strides passed as kernel parameters. -#define MAX_NDIM 8 +#define MAX_NDIM 10 // All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in // warpSize variable exists, using it would prevent compile-time optimizations. diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index c2362bea2..e32befc9c 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -71,10 +71,8 @@ bool fast::ScaledDotProductAttention::use_fallback( throw std::runtime_error(#func " has no CUDA implementation."); \ } -NO_GPU(ArgPartition) NO_GPU(BlockMaskedMM) NO_GPU(Convolution) -NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) NO_GPU(FFT) @@ -83,7 +81,6 @@ NO_GPU(GatherQMM) NO_GPU(Hadamard) NO_GPU(Load) NO_GPU_MULTI(LUF) -NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(Scan) diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index e1c2e8530..154ca5f32 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { axis += in.ndim(); } int nsort = in.shape(axis); - int nsegments = in.data_size() / nsort; int last_dim = in.ndim() - 1; // If we are not sorting the innermost dimension of a contiguous array, @@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); encoder.add_temporary(out); } else { - out.set_data(allocator::malloc(out.nbytes())); + out.set_data( + allocator::malloc(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); } encoder.launch_kernel([&](cudaStream_t stream) { @@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { indices.data(), out.data(), in.data_size(), - nsegments, + in.data_size() / nsort, offsets, offsets + 1, stream); @@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { in.data(), out.data(), in.data_size(), - nsegments, + in.data_size() / nsort, offsets, offsets + 1, stream); @@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector& inputs, array& out) { gpu_sort(stream(), inputs[0], out, axis_, false); } +void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ArgPartition::eval_gpu"); + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Partition::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Partition::eval_gpu"); + gpu_sort(stream(), inputs[0], out, axis_, false); +} + } // namespace mlx::core diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 23c5fb19c..36388c3c5 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,10 +1,8 @@ cuda_skip = { "TestArray.test_api", - "TestAutograd.test_update_state", "TestBF16.test_arg_reduction_ops", "TestBF16.test_reduction_ops", "TestBlas.test_complex_gemm", - "TestCompile.test_compile_dynamic_dims", "TestEinsum.test_ellipses", "TestEinsum.test_opt_einsum_test_cases", "TestLoad.test_load_f8_e4m3", @@ -14,24 +12,14 @@ cuda_skip = { 
"TestLayers.test_quantized_embedding", "TestLayers.test_sin_pe", "TestLayers.test_upsample", - "TestOps.test_array_equal", "TestOps.test_complex_ops", "TestOps.test_dynamic_slicing", "TestOps.test_softmax", - "TestOps.test_sort", - "TestOps.test_tile", "TestReduce.test_axis_permutation_sums", "TestReduce.test_dtypes", "TestReduce.test_expand_sums", "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", - # DivMod NYI - "TestOps.test_divmod", - "TestEval.test_multi_output_eval_during_transform", - # Partition NYI - "TestAutograd.test_topk_grad", - "TestOps.test_argpartition", - "TestOps.test_partition", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI From cad5c0241c38191c7527f61291d57c6aacf55a70 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 17 Jun 2025 12:03:25 -0700 Subject: [PATCH 111/156] [CUDA] synch properly waits for all tasks to finish and clear (#2303) * cuda synch properly waits for all tasks to finish and clear * fix copy --- mlx/backend/cuda/allocator.cpp | 1 - mlx/backend/cuda/copy/copy_general.cu | 13 +++++++++---- mlx/backend/cuda/device.cpp | 11 +++++++++++ mlx/backend/cuda/device.h | 3 +++ mlx/backend/cuda/eval.cpp | 2 +- mlx/backend/cuda/worker.cpp | 4 +++- python/tests/cuda_skip.py | 1 - 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 00f78fd4f..1d17d7df5 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) { return; } } - cudaFree(buf); } diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 9f50c8a31..2dc08c60a 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -63,25 +63,30 @@ void copy_general( MLX_SWITCH_BOOL(large, LARGE, { using IdxT = std::conditional_t; int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { auto kernel = cu::copy_gg_nd; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + auto [num_blocks, block_dims] = + get_launch_args(kernel, data_size, shape, out.strides(), large); kernel<<>>( in_ptr, out_ptr, - out.size(), + data_size, const_param(shape), const_param(strides_in), const_param(strides_out)); }); } else { // ndim >= 4 auto kernel = cu::copy_gg; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + auto [num_blocks, block_dims] = + get_launch_args(kernel, data_size, shape, out.strides(), large); kernel<<>>( in_ptr, out_ptr, - out.size(), + data_size, const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 8a3d66c8e..fcf7fdf5e 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -6,6 +6,7 @@ #include #include +#include namespace mlx::core { @@ -107,6 +108,16 @@ void CommandEncoder::commit() { worker_.commit(stream_.last_cuda_stream()); } +void CommandEncoder::synchronize() { + stream().synchronize(); + auto p = std::make_shared>(); + std::future f = p->get_future(); + add_completed_handler([p = std::move(p)]() { p->set_value(); }); + worker_.end_batch(); + worker_.commit(); + f.wait(); +} + Device& device(mlx::core::Device device) { static std::unordered_map devices; auto it = devices.find(device.index); diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 
5b2cc0607..744f77f62 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -123,6 +123,9 @@ class CommandEncoder { return has_gpu_work_; } + // Wait until kernels and completion handlers are finished + void synchronize(); + private: Device& device_; DeviceStream& stream_; diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index b309ad60e..21b019cd8 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -62,7 +62,7 @@ void finalize(Stream s) { void synchronize(Stream s) { nvtx3::scoped_range r("gpu::synchronize"); - cu::get_stream(s).synchronize(); + cu::get_command_encoder(s).synchronize(); } } // namespace mlx::core::gpu diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp index 64b5c7679..3b35c830b 100644 --- a/mlx/backend/cuda/worker.cpp +++ b/mlx/backend/cuda/worker.cpp @@ -80,7 +80,9 @@ void Worker::thread_fn() { } worker_tasks_.erase(worker_tasks_.begin(), end); } - for (auto& task : tasks) { + // Make sure tasks are cleared before the next wait + for (int i = 0; i < tasks.size(); ++i) { + auto task = std::move(tasks[i]); task(); } worker_event_.wait(batch + 1); diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 36388c3c5..bcb95dbb7 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -6,7 +6,6 @@ cuda_skip = { "TestEinsum.test_ellipses", "TestEinsum.test_opt_einsum_test_cases", "TestLoad.test_load_f8_e4m3", - "TestMemory.test_memory_info", "TestLayers.test_group_norm", "TestLayers.test_pooling", "TestLayers.test_quantized_embedding", From b3d7b8537610c2db2b1875deb5b1d230c47e8b7b Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 17 Jun 2025 23:55:56 -0700 Subject: [PATCH 112/156] Make ptx cache settable by environment variable (#2304) --- mlx/backend/cuda/jit_module.cpp | 72 ++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 8a033523c..af8f7dc75 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -37,36 +37,46 @@ void check_cu_error(const char* name, CUresult err) { } // Return the location of the CUDA toolkit. -const char* cuda_home() { - const char* home = std::getenv("CUDA_HOME"); - if (home) { - return home; - } - home = std::getenv("CUDA_PATH"); - if (home) { - return home; - } +const std::string& cuda_home() { + static std::string home = []() -> std::string { + const char* home = std::getenv("CUDA_HOME"); + if (home) { + return home; + } + home = std::getenv("CUDA_PATH"); + if (home) { + return home; + } #if defined(__linux__) - home = "/usr/local/cuda"; - if (std::filesystem::exists(home)) { - return home; - } + home = "/usr/local/cuda"; + if (std::filesystem::exists(home)) { + return home; + } #endif - throw std::runtime_error( - "Environment variable CUDA_HOME or CUDA_PATH is not set."); + throw std::runtime_error( + "Environment variable CUDA_HOME or CUDA_PATH is not set."); + }(); + return home; } // Get the cache directory for storing compiled results. 
-bool get_ptx_cache_dir(std::filesystem::path* result) { - auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx"; - if (!std::filesystem::is_directory(path)) { - std::error_code error; - if (!std::filesystem::create_directories(path, error)) { - return false; +const std::filesystem::path& ptx_cache_dir() { + static std::filesystem::path cache = []() -> std::filesystem::path { + std::filesystem::path cache; + if (auto c = std::getenv("MLX_PTX_CACHE"); c) { + cache = c; + } else { + cache = std::filesystem::temp_directory_path() / "mlx" / "ptx"; } - } - *result = path; - return true; + if (!std::filesystem::exists(cache)) { + std::error_code error; + if (!std::filesystem::create_directories(cache, error)) { + return std::filesystem::path(); + } + } + return cache; + }(); + return cache; } // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. @@ -75,6 +85,10 @@ bool read_cached_ptx( const std::string& module_name, std::vector* ptx, std::vector>* ptx_kernels) { + if (cache_dir.empty()) { + return false; + } + auto ptx_path = cache_dir / (module_name + ".ptx"); std::error_code error; auto ptx_size = std::filesystem::file_size(ptx_path, error); @@ -105,6 +119,10 @@ void write_cached_ptx( const std::string& module_name, const std::vector& ptx, const std::vector>& ptx_kernels) { + if (cache_dir.empty()) { + return; + } + std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); if (!ptx.empty()) { ptx_file.write(&ptx.front(), ptx.size()); @@ -184,11 +202,9 @@ JitModule::JitModule( const std::string& module_name, const KernelBuilder& builder) { // Check cache. - std::filesystem::path cache_dir; std::vector ptx; std::vector> ptx_kernels; - if (!get_ptx_cache_dir(&cache_dir) || - !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) { + if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { // Create program. auto [source_code, kernel_names] = builder(); nvrtcProgram prog; @@ -246,7 +262,7 @@ JitModule::JitModule( } else { CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); } - write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels); + write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels); } // Load module. 
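The patch above makes the CUDA JIT resolve its PTX cache directory from the MLX_PTX_CACHE environment variable, falling back to a temporary directory when the variable is unset and skipping the on-disk cache entirely if the directory cannot be created. A minimal usage sketch, not part of the patch itself: only the MLX_PTX_CACHE variable name comes from the change; the cache path is an arbitrary example and the trailing command stands in for any MLX program whose JIT-compiled kernels should be cached.

    # hypothetical location; any writable directory works
    export MLX_PTX_CACHE="$HOME/.cache/mlx-ptx"
    # later runs reuse the cached PTX instead of recompiling with NVRTC
    python my_mlx_script.py
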
From 76831ed83d6041eaf0f9649f00b47c1f0d7e166f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 19 Jun 2025 15:26:36 -0700 Subject: [PATCH 113/156] Build CUDA release in Circle (#2306) * cuda release * add license --- .circleci/config.yml | 65 ++++++++++++++++++++++++++++++++--- docs/src/install.rst | 59 +++++++++++++++++++++++++++++++ mlx/backend/cuda/device.cpp | 2 +- python/scripts/repair_cuda.sh | 17 +++++++++ setup.py | 8 ++++- 5 files changed, 144 insertions(+), 7 deletions(-) create mode 100644 python/scripts/repair_cuda.sh diff --git a/.circleci/config.yml b/.circleci/config.yml index 0ea9303db..205a930af 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,6 +16,9 @@ parameters: linux_release: type: boolean default: false + cuda_release: + type: boolean + default: false jobs: build_documentation: @@ -104,7 +107,7 @@ jobs: command: | echo "stubs" pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs - run: name: Run Python tests command: | @@ -162,7 +165,7 @@ jobs: command: | source env/bin/activate pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs - run: name: Run Python tests command: | @@ -223,7 +226,6 @@ jobs: command: | sudo apt-get update sudo apt-get install libblas-dev liblapack-dev liblapacke-dev - sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev python -m venv env source env/bin/activate CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ @@ -283,7 +285,7 @@ jobs: command: | source env/bin/activate pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs - run: name: Build Python package command: | @@ -342,7 +344,7 @@ jobs: CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ pip install . -v pip install typing_extensions - python setup.py generate_stubs + python setup.py generate_stubs << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python -m build --wheel @@ -356,6 +358,48 @@ jobs: - store_artifacts: path: wheelhouse/ + build_cuda_release: + parameters: + python_version: + type: string + default: "3.9" + extra_env: + type: string + default: "DEV_RELEASE=1" + machine: + image: linux-cuda-12:default + resource_class: gpu.nvidia.small.gen2 + steps: + - checkout + - run: + name: Build wheel + command: | + sudo apt-get update + sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + python -m venv env + source env/bin/activate + pip install auditwheel + pip install patchelf + pip install build + pip install twine + << parameters.extra_env >> \ + CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ + pip install ".[dev]" -v + python setup.py generate_stubs + << parameters.extra_env >> \ + CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ + python -m build --wheel + bash python/scripts/repair_cuda.sh + - run: + name: Upload package + command: | + source env/bin/activate + twine upload wheelhouse/*.whl + - store_artifacts: + path: wheelhouse/ + workflows: build_and_test: when: @@ -625,3 +669,14 @@ workflows: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] extra_env: ["PYPI_RELEASE=1"] + cuda_test_release: + when: + and: + - equal: [ main, << pipeline.git.branch >> ] + - << pipeline.parameters.cuda_release >> + jobs: + - build_cuda_release: + matrix: + parameters: + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + extra_env: ["PYPI_RELEASE=1"] diff --git a/docs/src/install.rst 
b/docs/src/install.rst index 059b2cba4..22de94f90 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do: conda install conda-forge::mlx +CUDA +^^^^ + +MLX has a CUDA backend which you can use on any Linux platform with CUDA 12 +and SM 7.0 (Volta) and up. To install MLX with CUDA support, run: + +.. code-block:: shell + + pip install mlx-cuda + Troubleshooting ^^^^^^^^^^^^^^^ @@ -65,6 +75,8 @@ Build Requirements Python API ^^^^^^^^^^ +.. _python install: + To build and install the MLX python library from source, first, clone MLX from `its GitHub repo `_: @@ -107,6 +119,8 @@ IDE: C++ API ^^^^^^^ +.. _cpp install: + Currently, MLX must be built and installed from source. Similarly to the python library, to build and install the MLX C++ library start @@ -185,6 +199,7 @@ should point to the path to the built metal library. xcrun -sdk macosx --show-sdk-version + Binary Size Minimization ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the application. Once a kernel is compiled, it will be cached by the system. The Metal kernel cache persists across reboots. +Linux +^^^^^ + +To build from source on Linux (CPU only), install the BLAS and LAPACK headers. +For example on Ubuntu, run the following: + +.. code-block:: shell + + apt-get update -y + apt-get install libblas-dev liblapack-dev liblapacke-dev -y + +From here follow the instructions to install either the :ref:`Python ` or :ref:`C++ ` APIs. + +CUDA +^^^^ + +To build from source on Linux with CUDA, install the BLAS and LAPACK headers +and the CUDA toolkit. For example on Ubuntu, run the following: + +.. code-block:: shell + + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb + dpkg -i cuda-keyring_1.1-1_all.deb + apt-get update -y + apt-get -y install cuda-toolkit-12-9 + apt-get install libblas-dev liblapack-dev liblapacke-dev -y + + +When building either the Python or C++ APIs make sure to pass the cmake flag +``MLX_BUILD_CUDA=ON``. For example, to build the Python API run: + +.. code-block:: shell + + CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]" + +To build the C++ package run: + +.. code-block:: shell + + mkdir -p build && cd build + cmake .. -DMLX_BUILD_CUDA=ON && make -j + + Troubleshooting ^^^^^^^^^^^^^^^ diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index fcf7fdf5e..ba31c0e45 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -114,7 +114,7 @@ void CommandEncoder::synchronize() { std::future f = p->get_future(); add_completed_handler([p = std::move(p)]() { p->set_value(); }); worker_.end_batch(); - worker_.commit(); + commit(); f.wait(); } diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh new file mode 100644 index 000000000..21e6a977a --- /dev/null +++ b/python/scripts/repair_cuda.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +auditwheel repair dist/* \ + --plat manylinux_2_35_x86_64 \ + --exclude libcublas* \ + --exclude libnvrtc* + +cd wheelhouse +repaired_wheel=$(find . 
-name "*.whl" -print -quit) +unzip -q "${repaired_wheel}" +core_so=$(find mlx -name "core*.so" -print -quit) +rpath=$(patchelf --print-rpath "${core_so}") +rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib +patchelf --force-rpath --set-rpath "$rpath" "$core_so" + +# Re-zip the repaired wheel +zip -r -q "${repaired_wheel}" . diff --git a/setup.py b/setup.py index d742e6595..35f2e68ef 100644 --- a/setup.py +++ b/setup.py @@ -174,20 +174,26 @@ if __name__ == "__main__": ) package_dir = {"": "python"} package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} + install_requires = [] + build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") + if build_cuda: + install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"] setup( - name="mlx", + name="mlx-cuda" if build_cuda else "mlx", version=get_version(), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", long_description=long_description, long_description_content_type="text/markdown", + license="MIT", url="https://github.com/ml-explore/mlx", packages=packages, package_dir=package_dir, package_data=package_data, include_package_data=True, + install_requires=install_requires, extras_require={ "dev": [ "nanobind==2.4.0", From c9a91805841ffd1032261a4dda8e298b642e0eec Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 20 Jun 2025 14:50:57 -0700 Subject: [PATCH 114/156] Cuda perf tuning (#2307) * perf tuning * fix adding inputs arrays in matmul / srot * format * fix --- mlx/backend/cuda/allocator.cpp | 12 +++++++- mlx/backend/cuda/copy.cu | 1 - mlx/backend/cuda/device/utils.cuh | 20 ++++++------- mlx/backend/cuda/matmul.cpp | 49 +++++++++++++++++++++++++------ mlx/backend/cuda/sort.cu | 5 ++-- 5 files changed, 63 insertions(+), 24 deletions(-) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 1d17d7df5..6cc7145b5 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/worker.h" +#include "mlx/utils.h" #include #include @@ -14,9 +15,11 @@ namespace mlx::core { namespace cu { +constexpr int page_size = 16384; + CudaAllocator::CudaAllocator() : buffer_cache_( - getpagesize(), + page_size, [](CudaBuffer* buf) { return buf->size; }, [this](CudaBuffer* buf) { cuda_free(buf->data); @@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator() Buffer CudaAllocator::malloc(size_t size) { // Find available buffer from cache. 
+ auto orig_size = size; std::unique_lock lock(mutex_); + if (size < page_size) { + size = next_power_of_2(size); + } else { + size = page_size * ((size + page_size - 1) / page_size); + } + CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); if (!buf) { // If we have a lot of memory pressure or are over the maximum cache size, diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 817860d0a..321806720 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -24,7 +24,6 @@ void copy_gpu_inplace( auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - if (ctype == CopyType::Scalar || ctype == CopyType::Vector) { copy_contiguous(encoder, ctype, in, out, offset_in, offset_out); return; diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 54d551992..6e8abdd7c 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( #pragma unroll for (int i = NDIM - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; - a_loc += dim_idx * a_strides[i]; - b_loc += dim_idx * b_strides[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); elem /= shape[i]; } return cuda::std::make_tuple(a_loc, b_loc); @@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( #pragma unroll for (int i = NDIM - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; - a_loc += dim_idx * a_strides[i]; - b_loc += dim_idx * b_strides[i]; - c_loc += dim_idx * c_strides[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + c_loc += dim_idx * IdxT(c_strides[i]); elem /= shape[i]; } return cuda::std::make_tuple(a_loc, b_loc, c_loc); @@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( IdxT b_loc = 0; for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; - a_loc += dim_idx * a_strides[i]; - b_loc += dim_idx * b_strides[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); elem /= shape[i]; } return cuda::std::make_tuple(a_loc, b_loc); @@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( IdxT c_loc = 0; for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; - a_loc += dim_idx * a_strides[i]; - b_loc += dim_idx * b_strides[i]; - c_loc += dim_idx * c_strides[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + c_loc += dim_idx * IdxT(c_strides[i]); elem /= shape[i]; } return cuda::std::make_tuple(a_loc, b_loc, c_loc); diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 5a5e6182e..c32cecc03 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -162,11 +162,15 @@ class MatMul { } } - array workspace( - allocator::malloc(heuristic_.workspaceSize), - {static_cast(heuristic_.workspaceSize)}, - int8); - encoder.add_temporary(workspace); + void* workspace_ptr = nullptr; + if (heuristic_.workspaceSize > 0) { + array workspace( + allocator::malloc(heuristic_.workspaceSize), + {static_cast(heuristic_.workspaceSize)}, + int8); + encoder.add_temporary(workspace); + workspace_ptr = workspace.data(); + } encoder.launch_kernel([&](cudaStream_t stream) { CHECK_CUBLAS_ERROR(cublasLtMatmul( @@ -183,8 +187,8 @@ class MatMul { out, out_desc_, &heuristic_.algo, - workspace.data(), - workspace.nbytes(), + workspace_ptr, + heuristic_.workspaceSize, stream)); }); } @@ 
-358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { a_batch_strides.back(), b_batch_strides.back()); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + auto nbatch = batch_count / batch_shape.back(); + if (nbatch == 1) { + matmul.run(encoder, out.data(), a.data(), b.data()); + return; + } + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); - for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { + for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, out.data() + out.itemsize() * i * batch_shape.back() * M * N, @@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { b_batch_strides.back(), c_batch_strides.back()); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(c); + encoder.set_output_array(out); + + auto nbatch = batch_count / batch_shape.back(); + if (nbatch == 1) { + matmul.run( + encoder, + out.data(), + a.data(), + b.data(), + c.data(), + alpha_, + beta_); + return; + } + ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); - for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) { + for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, out.data() + out.itemsize() * i * batch_shape.back() * M * N, diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 154ca5f32..5cbffc0f4 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { array out = out_; auto& encoder = cu::get_command_encoder(s); - encoder.set_input_array(in); - encoder.set_output_array(out); - if (axis < 0) { axis += in.ndim(); } @@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { in.flags()); } + encoder.set_input_array(in); + encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { if constexpr (!std::is_same_v) { From 5adf185f861383fed84d2c0177397cf152970176 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 20 Jun 2025 17:19:46 -0700 Subject: [PATCH 115/156] Fix `update_modules()` when providing a subset (#2308) --- python/mlx/nn/layers/base.py | 2 +- python/tests/test_nn.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index af639dc4e..ce2ccb209 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -413,7 +413,7 @@ class Module(dict): f'Module does not have sub-module named "{k}".' 
) elif isinstance(modules, list): - for i in range(len(dst)): + for i in range(len(modules)): current_value = dst[i] new_value = modules[i] if self.is_module(current_value) and self.is_module(new_value): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 10bbe821e..7753224b3 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): m = m.update_modules({"list": ["hi"]}) + # Allow updating a strict subset + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) + self.assertEqual(m.layers[1].weight.shape, (4, 3)) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): From 81bb9a2a9e21a54b9658a59f06d8f8b3d677f5e5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 24 Jun 2025 10:18:52 -0700 Subject: [PATCH 116/156] Compile float64 functions on CPU (#2311) --- mlx/backend/common/compiled.cpp | 4 ++++ mlx/backend/common/compiled.h | 8 ++++++-- python/src/convert.cpp | 2 ++ python/tests/test_compile.py | 12 ++++++++++++ 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 98c48cca9..44e2a432b 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -14,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) { return print_float_constant(os, x); case bfloat16: return print_float_constant(os, x); + case float64: + return print_float_constant(os, x); case complex64: return print_complex_constant(os, x); case int8: @@ -50,6 +52,8 @@ std::string get_type_string(Dtype d) { return "float16_t"; case bfloat16: return "bfloat16_t"; + case float64: + return "double"; case complex64: return "complex64_t"; case bool_: diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index 6fccaacd6..e92a6d0ad 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -18,8 +18,12 @@ std::string get_type_string(Dtype d); template void print_float_constant(std::ostream& os, const array& x) { auto old_precision = os.precision(); - os << std::setprecision(std::numeric_limits::digits10 + 1) - << x.item() << std::setprecision(old_precision); + if constexpr (std::is_same_v) { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } else { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } + os << x.item() << std::setprecision(old_precision); } template diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 00f8395fc..1340b663a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -205,6 +205,8 @@ nb::object to_scalar(mx::array& a) { return nb::cast(static_cast(a.item())); case mx::complex64: return nb::cast(a.item>()); + case mx::float64: + return nb::cast(a.item()); default: throw nb::type_error("type cannot be converted to Python scalar."); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 656553f9d..ca33c2d3a 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -2,6 +2,7 @@ import gc import io +import math import unittest from functools import partial @@ -979,6 +980,17 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_double_constant(self): + with mx.stream(mx.cpu): + x = mx.array(1.0, dtype=mx.float64) + + def fun(x): + return (x + math.pi) * 2.0 + + y = fun(x).item() + y_compiled = mx.compile(fun)(x).item() + 
self.assertEqual(y, y_compiled) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From 656ed7f7808266ae7923a010a6b1f5d166cf6256 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Jun 2025 13:03:09 -0700 Subject: [PATCH 117/156] Fix get 2d grid dims (#2316) --- mlx/backend/common/utils.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 457ecb7f7..9766e5e0c 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -199,12 +199,15 @@ Dims get_2d_grid_dims_common( } } } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { throw std::runtime_error("Unable to safely factor shape."); } if (grid_y > grid_x) { std::swap(grid_x, grid_y); } + if (divisor > 1) { + grid_x = ((grid_x + divisor - 1) / divisor) * divisor; + } return std::make_tuple( static_cast(grid_x), static_cast(grid_y), 1); } From 2c11d10f8d8e4d124cb447af731b9199374695bb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 26 Jun 2025 22:08:18 -0700 Subject: [PATCH 118/156] Split broadcast so it is always fused in compile (#2318) --- mlx/compile.cpp | 36 +++++++++++++++++++++++++++++++++--- python/tests/test_compile.py | 23 +++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 79a55ba8f..0cb3b5a85 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -245,6 +245,30 @@ void merge(array& dst, array& src, ParentsMap& parents_map) { } } +// Any parent in the divider will continue to refer to `x` but any parent not +// in the divider will refer to a copy of the operation. +array split_one( + const array& x, + ParentsMap& parents_map, + const std::unordered_set& divider) { + array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs()); + + auto& x_parents = parents_map[x.id()]; + auto& y_parents = parents_map[y.id()]; + + for (auto it = x_parents.begin(); it != x_parents.end();) { + if (divider.find(it->first.id()) != divider.end()) { + it->first.inputs()[it->second] = y; + y_parents.emplace_back(std::move(*it)); + it = x_parents.erase(it); + } else { + it++; + } + } + + return std::move(y); +} + template std::uintptr_t get_function_address(const std::function& fun) { using FunType = T (*)(U...); @@ -669,10 +693,16 @@ void compile_fuse( } // Arrays with a mix of parents outside the compilable section - // are not fusable + // are not fusable except for broadcast which we can split to avoid + // stopping fusion if (!all_parents_in) { - // Possible input - input_set.insert(a.id()); + if (a.has_primitive() && is_broadcast(a.primitive())) { + array b = split_one(a, parents_map, cache); + recurse(b, depth, s, shape); + } else { + // Possible input + input_set.insert(a.id()); + } return; } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index ca33c2d3a..ada2b1484 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -5,6 +5,7 @@ import io import math import unittest from functools import partial +from io import StringIO import mlx.core as mx import mlx_tests @@ -991,6 +992,28 @@ class TestCompile(mlx_tests.MLXTestCase): y_compiled = mx.compile(fun)(x).item() self.assertEqual(y, y_compiled) + def test_shared_broadcast(self): + def fun(x, y, z): + yy = mx.broadcast_to(y, z.shape) + return (x + yy * z), yy.sum() + + a = mx.random.normal((10, 10)) + b = mx.array(0.1) + c = mx.random.normal((10, 10)) + mx.eval(a, b, c) + fc 
= mx.compile(fun) + d = fc(a, b, c) + + s = StringIO() + mx.export_to_dot(s, a=a, b=b, c=c, d1=d[0], d2=d[1]) + s.seek(0) + s = s.read() + + self.assertTrue("CompiledBroadcastMultiplyAdd" in s) + d_hat = fun(a, b, c) + self.assertTrue(mx.allclose(d[0], d_hat[0])) + self.assertTrue(mx.allclose(d[1], d_hat[1])) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From 772f471ff265ad21996565161fa48811b9ed6b91 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 27 Jun 2025 12:59:20 -0700 Subject: [PATCH 119/156] [CUDA] Fix reductions (#2314) --- benchmarks/python/comparative/bench_torch.py | 22 +- mlx/backend/common/reduce.cpp | 15 +- mlx/backend/common/reduce.h | 4 + mlx/backend/cuda/CMakeLists.txt | 3 +- mlx/backend/cuda/binary_two.cu | 2 +- mlx/backend/cuda/reduce.cu | 38 +- mlx/backend/cuda/reduce/all_reduce.cu | 150 ++++++++ mlx/backend/cuda/reduce/col_reduce.cu | 296 +++++++------- mlx/backend/cuda/reduce/init_reduce.cu | 50 +++ mlx/backend/cuda/reduce/reduce.cuh | 12 +- mlx/backend/cuda/reduce/reduce_ops.cuh | 55 ++- mlx/backend/cuda/reduce/reduce_utils.cuh | 158 ++++++++ mlx/backend/cuda/reduce/row_reduce.cu | 383 ++++++++++++------- mlx/backend/cuda/reduce/segmented_reduce.cu | 84 ---- mlx/backend/cuda/softmax.cu | 4 +- python/tests/cuda_skip.py | 5 - 16 files changed, 862 insertions(+), 419 deletions(-) create mode 100644 mlx/backend/cuda/reduce/all_reduce.cu create mode 100644 mlx/backend/cuda/reduce/init_reduce.cu create mode 100644 mlx/backend/cuda/reduce/reduce_utils.cuh delete mode 100644 mlx/backend/cuda/reduce/segmented_reduce.cu diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a2157707b..dd3436d9a 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -5,6 +5,7 @@ import os import time import torch +import torch.cuda import torch.mps @@ -44,8 +45,10 @@ def bench(f, *args): def sync_if_needed(x): - if x.device != torch.device("cpu"): + if x.device == torch.device("mps"): torch.mps.synchronize() + elif x.device == torch.device("cuda"): + torch.cuda.synchronize() @torch.no_grad() @@ -99,6 +102,14 @@ def reduction(op, axis, x): sync_if_needed(x) +@torch.no_grad() +def sum_and_add(axis, x, y): + z = x.sum(axis=axis, keepdims=True) + for i in range(50): + z = (z + y).sum(axis=axis, keepdims=True) + sync_if_needed(x) + + @torch.no_grad() def softmax(axis, x): ys = [] @@ -340,7 +351,11 @@ if __name__ == "__main__": args.axis.pop(0) torch.set_num_threads(1) - device = "cpu" if args.cpu else "mps" + device = "mps" + if torch.cuda.is_available(): + device = "cuda" + if args.cpu: + device = "cpu" types = args.dtype if not types: @@ -460,5 +475,8 @@ if __name__ == "__main__": elif args.benchmark == "selu": print(bench(selu, x)) + elif args.benchmark == "sum_and_add": + print(bench(sum_and_add, axis, *xs)) + else: raise ValueError(f"Unknown benchmark `{args.benchmark}`.") diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 5c7f63b75..ceef46400 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -5,11 +5,9 @@ namespace mlx::core { std::pair shapes_without_reduction_axes( - const array& x, + Shape shape, + Strides strides, const std::vector& axes) { - auto shape = x.shape(); - auto strides = x.strides(); - for (int i = axes.size() - 1; i >= 0; i--) { int a = axes[i]; shape.erase(shape.begin() + a); @@ -19,6 +17,15 @@ std::pair shapes_without_reduction_axes( return std::make_pair(shape, strides); } 
+std::pair shapes_without_reduction_axes( + const array& x, + const std::vector& axes) { + auto shape = x.shape(); + auto strides = x.strides(); + return shapes_without_reduction_axes( + std::move(shape), std::move(strides), axes); +} + ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // The data is all there and we are reducing over everything if (x.size() == x.data_size() && axes.size() == x.ndim() && diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index ddb5c3492..8b24f4f53 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes); +std::pair shapes_without_reduction_axes( + Shape shape, + Strides strides, + const std::vector& axes); } // namespace mlx::core diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index ad979a13f..8130d396f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -29,9 +29,10 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu - ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 3047e39f0..074c947da 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -157,7 +157,7 @@ void binary_op_gpu_inplace( if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { auto kernel = - &cu::binary_g_nd; + cu::binary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large); kernel<<>>( diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu index a740113db..8350eebb7 100644 --- a/mlx/backend/cuda/reduce.cu +++ b/mlx/backend/cuda/reduce.cu @@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(!axes_.empty()); assert(out.size() != in.size()); - out.set_data(allocator::malloc(out.nbytes())); - auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - encoder.set_input_array(in); - encoder.set_output_array(out); - // Fill out with init value. if (in.size() == 0) { - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, { - using InType = cuda_type_t; - using OutType = cu::ReduceResult::type; - thrust::fill_n( - cu::thrust_policy(stream), - thrust::device_pointer_cast(out.data()), - out.data_size(), - cu::ReduceInit::value()); - }); - }); - }); + init_reduce(encoder, in, out, reduce_type_); return; } @@ -51,7 +34,19 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // If it is a general reduce then copy the input to a contiguous array and // recompute the plan. - if (plan.type == GeneralReduce) { + // + // TODO: Instead of copying we can use elem-to-loc to deal with broadcasting + // like we do in Metal. When it comes to broadcasted reduction axes + // some can be ignored eg for min/max. 
+ bool broadcasted = false; + for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + broadcasted = in.strides(i) == 0; + } + } + if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { array in_copy(in.shape(), in.dtype(), nullptr, {}); copy_gpu(in, in_copy, CopyType::General, s); encoder.add_temporary(in_copy); @@ -59,9 +54,8 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { plan = get_reduction_plan(in, axes_); } - if ((plan.type == ContiguousAllReduce) || - (plan.type == ContiguousReduce && plan.shape.size() == 1)) { - segmented_reduce(encoder, in, out, reduce_type_, axes_, plan); + if (plan.type == ContiguousAllReduce) { + all_reduce(encoder, in, out, reduce_type_); return; } diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu new file mode 100644 index 000000000..5a7c28041 --- /dev/null +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -0,0 +1,150 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { + // TODO: Process multiple "rows" in each thread + constexpr int M = 1; + + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = cu::ReduceInit::value(); + ReduceOp op; + + T vals[N]; + U accs[M]; + accs[0] = init; + + size_t start = grid.block_rank() * block_step; + size_t end = start + block_step; + size_t check = min(end, size); + + size_t i = start; + for (; i + block.size() * N <= check; i += block.size() * N) { + cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], __cast(vals[j])); + } + } + + if (i < check) { + cub::LoadDirectBlocked( + block.thread_rank(), in + i, vals, check - i, __cast(init)); + for (int i = 0; i < N; i++) { + accs[0] = op(accs[0], __cast(vals[i])); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[grid.block_rank()] = accs[0]; + } +} + +} // namespace cu + +void all_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + constexpr int N_READS = 8; + + out.set_data(allocator::malloc(out.nbytes())); + + auto get_args = [](size_t size, int N) { + int threads = std::min(512UL, (size + N - 1) / N); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = + (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + int blocks, threads; + size_t block_step; + size_t insize = in.size(); + Dtype dt = in.dtype(); + + // Cub doesn't like const pointers for load (sigh). 
+ void* indata = const_cast(in.data()); + + // Large array so allocate an intermediate and accumulate there + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(in); + if (blocks > 1) { + array intermediate({blocks}, out.dtype(), nullptr, {}); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); + encoder.add_temporary(intermediate); + encoder.set_output_array(intermediate); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(dt, CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + auto kernel = cu::all_reduce; + kernel<<>>( + static_cast(indata), + intermediate.data(), + block_step, + insize); + }); + }); + }); + + // Set the input for the next step and recalculate the blocks + indata = intermediate.data(); + dt = intermediate.dtype(); + insize = intermediate.size(); + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(intermediate); + } + + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(dt, CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + auto kernel = cu::all_reduce; + kernel<<>>( + static_cast(indata), out.data(), block_step, insize); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 9911a6fe0..192a9b3e8 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +#include + #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" @@ -36,19 +38,36 @@ struct ColReduceArgs { const array& in, const ReductionPlan& plan, const std::vector& axes) { + using ShapeVector = decltype(plan.shape); + using StridesVector = decltype(plan.strides); + + ShapeVector shape_vec; + StridesVector strides_vec; + assert(!plan.shape.empty()); reduction_size = plan.shape.back(); reduction_stride = plan.strides.back(); int64_t stride_back = 1; - auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); while (!shape_vec.empty() && stride_back < reduction_stride) { stride_back *= shape_vec.back(); shape_vec.pop_back(); strides_vec.pop_back(); } + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + ShapeVector sorted_shape; + StridesVector sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(shape_vec, strides_vec); + collapse_contiguous_dims(sorted_shape, sorted_strides); shape = const_param(shape_vec); strides = const_param(strides_vec); ndim = shape_vec.size(); @@ -64,86 +83,6 @@ struct ColReduceArgs { } }; -template -__global__ void col_reduce_small( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - - int column = - grid.block_index().x * block.dim_threads().x + block.thread_index().x; - if (column * N_READS >= args.reduction_stride) { - return; - } - - int out_idx = 
grid.block_rank() / grid.dim_blocks().x; - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - - Op op; - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = ReduceInit::value(); - } - - // Read input to local. - LoopedElemToLoc 2)> loop(args.reduce_ndim); - loop.next( - block.thread_index().y, - args.reduce_shape.data(), - args.reduce_strides.data()); - for (size_t r = block.thread_index().y; - r < args.non_col_reductions * args.reduction_size; - r += block.dim_threads().y) { - U vals[N_READS]; - cub::LoadDirectBlocked( - column, - make_cast_iterator(in + loop.location()), - vals, - args.reduction_stride, - ReduceInit::value()); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(vals[i], totals[i]); - } - loop.next( - block.dim_threads().y, - args.reduce_shape.data(), - args.reduce_strides.data()); - } - - // Do block reduce when each column has more than 1 element to reduce. - if (block.dim_threads().y > 1) { - __shared__ U shared_vals[32 * 8 * N_READS]; - size_t col = - block.thread_index().y * block.dim_threads().x + block.thread_index().x; - for (int i = 0; i < N_READS; i++) { - shared_vals[col * N_READS + i] = totals[i]; - } - block.sync(); - if (block.thread_index().y == 0) { - for (int i = 0; i < N_READS; i++) { - totals[i] = shared_vals[block.thread_index().x * N_READS + i]; - } - for (int j = 1; j < block.dim_threads().y; j++) { - col = j * block.dim_threads().x + block.thread_index().x; - for (int i = 0; i < N_READS; i++) { - totals[i] = op(shared_vals[col * N_READS + i], totals[i]); - } - } - } - } - - // Write result. - if (block.thread_index().y == 0) { - cub::StoreDirectBlocked( - column, - out + out_idx * args.reduction_stride, - totals, - args.reduction_stride); - } -} - template < typename T, typename U, @@ -152,67 +91,94 @@ template < int BM, int BN, int N_READS = 4> -__global__ void col_reduce_looped( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args) { +__global__ void +col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - constexpr int n_warps = BN / N_READS; + constexpr int threads_per_row = BN / N_READS; - int out_idx = grid.block_rank() / grid.dim_blocks().x; - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + // Compute the indices for the tile + size_t tile_idx = grid.block_rank(); + size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); + size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); + // Compute the indices for the thread within the tile + short thread_x = block.thread_rank() % threads_per_row; + short thread_y = block.thread_rank() / threads_per_row; + + // Move the input pointer + in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + + tile_x * BN; + + // Initialize the running totals Op op; U totals[N_READS]; for (int i = 0; i < N_READS; i++) { totals[i] = ReduceInit::value(); } - // Read input to local. 
- int r = block.thread_rank() / n_warps; - int column = block.thread_rank() % n_warps; - int in_offset = grid.block_index().x * BN; LoopedElemToLoc 2)> loop(args.reduce_ndim); - loop.next(r, args.reduce_shape.data(), args.reduce_strides.data()); - for (; r < args.non_col_reductions * args.reduction_size; r += BM) { - U vals[N_READS]; - cub::LoadDirectBlocked( - column, - make_cast_iterator(in + loop.location() + in_offset), - vals, - args.reduction_stride - in_offset, - ReduceInit::value()); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(vals[i], totals[i]); + loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data()); + size_t total = args.non_col_reductions * args.reduction_size; + if (tile_x * BN + BN <= args.reduction_stride) { + if (args.reduction_stride % N_READS == 0) { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], __cast(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } else { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], __cast(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + } else { + for (size_t r = thread_y; r < total; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked( + thread_x, + in + loop.location(), + vals, + args.reduction_stride - tile_x * BN, + __cast(ReduceInit::value())); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], __cast(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } // Do warp reduce for each output. - constexpr int n_outputs = BN / n_warps; + constexpr int n_outputs = BN / threads_per_row; static_assert(BM == 32 && n_outputs == N_READS); __shared__ U shared_vals[BM * BN]; - size_t col = block.thread_index().y * BN + block.thread_index().x * N_READS; + short s_idx = thread_y * BN + thread_x * N_READS; for (int i = 0; i < N_READS; i++) { - shared_vals[col + i] = totals[i]; + shared_vals[s_idx + i] = totals[i]; } block.sync(); - col = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; for (int i = 0; i < n_outputs; i++) { - totals[i] = cg::reduce(warp, shared_vals[col + i], op); + totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); } // Write result. 
if (warp.thread_rank() == 0) { - size_t out_offset = grid.block_index().x * BN; cub::StoreDirectBlocked( warp.meta_group_rank(), - out + out_idx * args.reduction_stride + out_offset, + out + tile_y * args.reduction_stride + tile_x * BN, totals, - args.reduction_stride - out_offset); + args.reduction_stride - tile_x * BN); } } @@ -220,14 +186,55 @@ __global__ void col_reduce_looped( inline auto output_grid_for_col_reduce( const array& out, - const cu::ColReduceArgs& args) { - auto out_shape = out.shape(); - auto out_strides = out.strides(); - while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { - out_shape.pop_back(); - out_strides.pop_back(); + const cu::ColReduceArgs& args, + int bn) { + int gx, gy = 1; + size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); + size_t n_outer_blocks = out.size() / args.reduction_stride; + size_t n_blocks = n_outer_blocks * n_inner_blocks; + while (n_blocks / gy > INT32_MAX) { + gy *= 2; } - return get_2d_grid_dims(out_shape, out_strides); + gx = cuda::ceil_div(n_blocks, gy); + + return dim3(gx, gy, 1); +} + +void col_reduce_looped( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + cu::ColReduceArgs args) { + // Allocate data for the output using in's layout to access them as + // contiguously as possible. + allocate_same_layout(out, in, axes); + + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + auto kernel = cu::col_reduce_looped; + kernel<<>>(indata, out.data(), args); + }); + }); + }); + }); } void col_reduce( @@ -237,42 +244,23 @@ void col_reduce( Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan) { + // Current col reduce options + // + // - col_reduce_looped + // + // It is a general strided reduce. Each threadblock computes the output for + // a subrow of the fast moving axis. For instance 32 elements. + // + // Notes: As in row reduce we opt to read as much in order as possible and + // leave transpositions as they are (contrary to our Metal backend). 
+ // + // Moreover we need different kernels for short rows and tuning + + // Make the args struct to help route to the best kernel cu::ColReduceArgs args(in, plan, axes); - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - using InType = cuda_type_t; - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using OutType = cu::ReduceResult::type; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - constexpr int N_READS = 4; - dim3 block_dims; - dim3 num_blocks = output_grid_for_col_reduce(out, args); - num_blocks.z = num_blocks.y; - num_blocks.y = num_blocks.x; - auto kernel = - cu::col_reduce_small; - size_t total = args.non_col_reductions * args.reduction_size; - if (total < 32) { - size_t stride_blocks = - cuda::ceil_div(args.reduction_stride, N_READS); - block_dims.x = std::min(stride_blocks, 32ul); - block_dims.y = std::min(total, 8ul); - num_blocks.x = cuda::ceil_div(stride_blocks, block_dims.x); - } else { - constexpr int BM = 32; - constexpr int BN = 32; - block_dims.x = BM * BN / N_READS; - num_blocks.x = cuda::ceil_div(args.reduction_stride, BN); - kernel = cu:: - col_reduce_looped; - } - kernel<<>>( - in.data(), out.data(), args); - }); - }); - }); - }); + // Fallback col reduce + col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu new file mode 100644 index 000000000..50fe109c4 --- /dev/null +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -0,0 +1,50 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void init_reduce(U* out, size_t size) { + auto index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = ReduceInit::value(); + } +} + +} // namespace cu + +void init_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + // Allocate if needed + if (out.data_shared_ptr() == nullptr) { + out.set_data(allocator::malloc(out.nbytes())); + } + + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + auto kernel = cu::init_reduce; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block(grid.x < 1024 ? 
grid.x : 1024, 1, 1); + grid.x = (grid.x + 1023) / 1024; + kernel<<>>(out.data(), out.size()); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index a673e052e..a7262bcc2 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -47,13 +47,11 @@ namespace mlx::core { throw std::invalid_argument("Unknown reduce type."); \ } -void segmented_reduce( +void all_reduce( cu::CommandEncoder& encoder, const array& in, array& out, - Reduce::ReduceType reduce_type, - const std::vector& axes, - const ReductionPlan& plan); + Reduce::ReduceType reduce_type); void row_reduce( cu::CommandEncoder& encoder, @@ -71,4 +69,10 @@ void col_reduce( const std::vector& axes, const ReductionPlan& plan); +void init_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 832787222..b40d2bd4e 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -3,48 +3,89 @@ #pragma once #include "mlx/backend/cuda/device/utils.cuh" +#include "mlx/backend/cuda/reduce/reduce_utils.cuh" namespace mlx::core::cu { // Reduce ops. struct And { - __device__ bool operator()(bool a, bool b) { + __device__ __forceinline__ bool operator()(bool a, bool b) { return a && b; } + + __device__ void atomic_update(bool* x, bool y) { + atomic_reduce(x, y); + } }; struct Or { - __device__ bool operator()(bool a, bool b) { + __device__ __forceinline__ bool operator()(bool a, bool b) { return a || b; } + + __device__ void atomic_update(bool* x, bool y) { + atomic_reduce(x, y); + } }; struct Sum { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a + b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } + + __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) { + atomicAdd(x, y); + } + + __device__ void atomic_update(int* x, int y) { + atomicAdd(x, y); + } + + __device__ void atomic_update(float* x, float y) { + atomicAdd(x, y); + } }; struct Prod { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a * b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } }; struct Min { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a < b ? a : b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } }; struct Max { template - __device__ T operator()(T a, T b) { + __device__ __forceinline__ T operator()(T a, T b) { return a > b ? a : b; } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } }; // Traits to get the result type of reduce op. @@ -120,7 +161,7 @@ template struct ReduceInit { static constexpr __host__ __device__ auto value() { if constexpr (cuda::std::is_same_v) { - return T{1, 1}; + return T{1, 0}; } else { return typename ReduceResult::type{1}; } diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh new file mode 100644 index 000000000..d4670503a --- /dev/null +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -0,0 +1,158 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include + +#include "mlx/backend/cuda/device/utils.cuh" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct uint_by_size; +template <> +struct uint_by_size<2> { + using type = uint16_t; +}; +template <> +struct uint_by_size<4> { + using type = uint32_t; +}; +template <> +struct uint_by_size<8> { + using type = unsigned long long int; +}; + +template +__device__ void atomic_reduce(T* x, T y) { + if constexpr (sizeof(T) == 1) { + using U = uint16_t; + U* x_int = (U*)((char*)x - ((size_t)x % 2)); + int shift = ((char*)x - (char*)x_int) * 8; + int mask = 0xff << shift; + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(static_cast((old_val >> shift) & 0xff), y); + new_val = (old_val & ~mask) | (result << shift); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } else { + using U = typename uint_by_size::type; + U* x_int = (U*)(x); + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(*((T*)&old_val), y); + new_val = *((U*)&result); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } +} + +// TODO: Should make a custom complex type +template +inline __device__ U __cast(T x) { + return static_cast(x); +} + +template <> +inline __device__ bool __cast(cuComplex x) { + return x.x != 0 && x.y != 0; +} + +template <> +inline __device__ cuComplex __cast(bool x) { + return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0); +} + +template +inline __device__ void +block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { + // First reduce in the current warp + for (int i = 0; i < N; i++) { + vals[i] = cg::reduce(warp, vals[i], op); + } + + // Reduce across warps + if (warp.meta_group_size() > 1) { + if (warp.thread_rank() == 0) { + for (int i = 0; i < N; i++) { + smem[warp.meta_group_rank() * N + i] = vals[i]; + } + } + block.sync(); + if (warp.thread_rank() < warp.meta_group_size()) { + for (int i = 0; i < N; i++) { + vals[i] = smem[warp.thread_rank() * N + i]; + } + } else { + for (int i = 0; i < N; i++) { + vals[i] = init; + } + } + for (int i = 0; i < N; i++) { + vals[i] = cg::reduce(warp, vals[i], op); + } + } +} + +} // namespace cu + +inline void allocate_same_layout( + array& out, + const array& in, + const std::vector& axes) { + if (in.flags().row_contiguous) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + + if (out.ndim() < in.ndim()) { + throw std::runtime_error( + "Reduction without keepdims only supported for row-contiguous inputs"); + } + + // Calculate the transpositions applied to in in order to apply them to out. 
+ std::vector axis_order(in.ndim()); + std::iota(axis_order.begin(), axis_order.end(), 0); + std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) { + return in.strides(left) > in.strides(right); + }); + + // Transpose the shape and calculate the strides + Shape out_shape(in.ndim()); + Strides out_strides(in.ndim(), 1); + for (int i = 0; i < in.ndim(); i++) { + out_shape[i] = out.shape(axis_order[i]); + } + for (int i = in.ndim() - 2; i >= 0; i--) { + out_strides[i] = out_shape[i + 1] * out_strides[i + 1]; + } + + // Reverse the axis order to get the final strides + Strides final_strides(in.ndim()); + for (int i = 0; i < in.ndim(); i++) { + final_strides[axis_order[i]] = out_strides[i]; + } + + // Calculate the resulting contiguity and do the memory allocation + auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides); + auto fl = in.flags(); + fl.row_contiguous = rc; + fl.col_contiguous = cc; + fl.contiguous = true; + out.set_data( + allocator::malloc(out.nbytes()), + data_size, + final_strides, + fl, + allocator::free); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index ae54a27d6..6a8a35311 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +#include + #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" @@ -55,84 +57,108 @@ struct RowReduceArgs { non_row_reductions *= reduce_shape[i]; } } + + // Convert shape and strides as if in was contiguous + void sort_access_pattern(const array& in, const std::vector& axes) { + auto shape_vec = in.shape(); + auto strides_vec = in.strides(); + std::tie(shape_vec, strides_vec) = + shapes_without_reduction_axes(shape_vec, strides_vec, axes); + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + decltype(shape_vec) sorted_shape; + decltype(strides_vec) sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + } }; -template -__global__ void row_reduce_small( - const T* in, - U* out, - size_t out_size, - const __grid_constant__ RowReduceArgs args) { - size_t out_idx = cg::this_grid().thread_rank(); - if (out_idx >= out_size) { - return; - } - - Op op; - - U total_val = ReduceInit::value(); - LoopedElemToLoc 2)> loop(args.reduce_ndim); - - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - - for (size_t n = 0; n < args.non_row_reductions; n++) { - for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { - U vals[N_READS]; - cub::LoadDirectBlocked( - r, - make_cast_iterator(in + loop.location()), - vals, - args.row_size, - ReduceInit::value()); - total_val = op(total_val, cub::ThreadReduce(vals, op)); - } - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); - } - - out[out_idx] = total_val; -} - -template -__global__ void row_reduce_small_warp( - const T* in, - U* out, - size_t out_size, - const __grid_constant__ RowReduceArgs args) { +template +__global__ void row_reduce_simple(T* in, U* out, 
size_t n_rows, int size) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - size_t out_idx = grid.thread_rank() / WARP_SIZE; - if (out_idx >= out_size) { - return; + const U init = cu::ReduceInit::value(); + ReduceOp op; + + T vals[M][N]; + U accs[M]; + for (int i = 0; i < M; i++) { + accs[i] = init; } - Op op; + const size_t start_row = + min(n_rows - M, static_cast(grid.block_rank() * M)); + const size_t full_blocks = size / (block.size() * N); + const size_t final_offset = full_blocks * (block.size() * N); + in += start_row * size; + out += start_row; - U total_val = ReduceInit::value(); - LoopedElemToLoc 2)> loop(args.reduce_ndim); - - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - - for (size_t n = warp.thread_rank(); n < args.non_row_reductions; - n += WARP_SIZE) { - for (int r = 0; r < cuda::ceil_div(args.row_size, N_READS); r++) { - U vals[N_READS]; - cub::LoadDirectBlocked( - r, - make_cast_iterator(in + loop.location()), - vals, - args.row_size, - ReduceInit::value()); - total_val = op(total_val, cub::ThreadReduce(vals, op)); + if (size % N == 0) { + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlockedVectorized( + block.thread_rank(), + in + k * size + r * (block.size() * N), + vals[k]); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], __cast(vals[k][j])); + } + } + } + } else { + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlocked( + block.thread_rank(), + in + k * size + r * (block.size() * N), + vals[k]); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], __cast(vals[k][j])); + } + } } - loop.next(WARP_SIZE, args.reduce_shape.data(), args.reduce_strides.data()); } - total_val = cg::reduce(warp, total_val, op); + if (final_offset < size) { + for (int k = 0; k < M; k++) { + cub::LoadDirectBlocked( + block.thread_rank(), + in + k * size + final_offset, + vals[k], + size, + __cast(init)); + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], __cast(vals[k][j])); + } + } + } - if (warp.thread_rank() == 0) { - out[out_idx] = total_val; + __shared__ U shared_accumulators[32 * M]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + if (grid.block_rank() * M + M <= n_rows) { + for (int i = 0; i < M; i++) { + out[i] = accs[i]; + } + } else { + short offset = grid.block_rank() * M + M - n_rows; + for (int i = offset; i < M; i++) { + out[i] = accs[i]; + } + } } } @@ -141,55 +167,165 @@ template < typename U, typename Op, int NDIM, - int BLOCK_DIM_X, + int BLOCK_DIM, int N_READS = 4> __global__ void row_reduce_looped( - const T* in, + T* in, U* out, size_t out_size, const __grid_constant__ RowReduceArgs args) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); - size_t out_idx = grid.thread_rank() / BLOCK_DIM_X; - if (out_idx >= out_size) { - return; - } + size_t out_idx = grid.block_rank(); Op op; - U total_val = ReduceInit::value(); + U total[1]; + U init = ReduceInit::value(); + total[0] = init; LoopedElemToLoc 2)> loop(args.reduce_ndim); + size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS); + size_t final_offset = full_blocks * BLOCK_DIM * N_READS; in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); for (size_t n = 0; n < args.non_row_reductions; n++) { - for (size_t r = 0; r < cuda::ceil_div(args.row_size, BLOCK_DIM_X * N_READS); - r++) { - 
U vals[N_READS]; - cub::LoadDirectBlocked( - r * BLOCK_DIM_X + block.thread_index().x, - make_cast_iterator(in + loop.location()), - vals, - args.row_size, - ReduceInit::value()); - total_val = op(total_val, cub::ThreadReduce(vals, op)); + for (size_t r = 0; r < full_blocks; r++) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized( + block.thread_rank(), + in + loop.location() + r * BLOCK_DIM * N_READS, + vals); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], __cast(vals[i])); + } } + if (final_offset < args.row_size) { + T vals[N_READS]; + cub::LoadDirectBlocked( + block.thread_rank(), + in + loop.location() + final_offset, + vals, + args.row_size - final_offset, + __cast(init)); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], __cast(vals[i])); + } + } + // TODO: Maybe block.sync() here? loop.next(args.reduce_shape.data(), args.reduce_strides.data()); } - typedef cub::BlockReduce BlockReduceT; - __shared__ typename BlockReduceT::TempStorage temp; - - total_val = BlockReduceT(temp).Reduce(total_val, op); + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, total, shared_accumulators, op, init); if (block.thread_rank() == 0) { - out[out_idx] = total_val; + out[out_idx] = total[0]; } } } // namespace cu +void row_reduce_simple( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + constexpr int N_READS = 8; + + // Allocate data for the output using in's layout to avoid elem_to_loc in the + // kernel. + allocate_same_layout(out, in, axes); + + // TODO: If out.size() < 1024 which will be a common case then write this in + // 2 passes. Something like 32 * out.size() and then do a warp reduce. + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); + + // Calculate the grid and block dims + size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); + + // Pick the kernel + auto kernel = cu::row_reduce_simple; + if (grid.x >= 1024) { + grid.x = (grid.x + 1) / 2; + kernel = cu::row_reduce_simple; + } + + // Launch + kernel<<>>( + indata, out.data(), out.size(), plan.shape.back()); + }); + }); + }); +} + +void row_reduce_looped( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + cu::RowReduceArgs args) { + constexpr int N_READS = 8; + + // Allocate data for the output using in's layout to access them as + // contiguously as possible. + allocate_same_layout(out, in, axes); + + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + using T = cuda_type_t; + using U = cu::ReduceResult::type; + + // Cub doesn't like const pointers for vectorized loads. 
(sigh) + T* indata = const_cast(in.data()); + + // Calculate the grid and block dims + args.sort_access_pattern(in, axes); + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + size_t reductions = (args.row_size + N_READS - 1) / N_READS; + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); + + // Pick the kernel + auto kernel = cu::row_reduce_looped; + MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { + MLX_SWITCH_BLOCK_DIM(threads, THREADS, { + kernel = cu::row_reduce_looped; + block.x = THREADS; + }); + }); + + // Launch + kernel<<>>( + indata, out.data(), out.size(), args); + }); + }); + }); +} + void row_reduce( cu::CommandEncoder& encoder, const array& in, @@ -197,54 +333,35 @@ void row_reduce( Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan) { + // Current row reduction options + // + // - row_reduce_simple + // + // That means that we are simply reducing across the fastest moving axis. + // We are reducing 1 or 2 rows per threadblock depending on the size of + // output. + // + // - row_reduce_looped + // + // It is a general row reduction. We are computing 1 output per + // threadblock. We read the fastest moving axis vectorized and loop over + // the rest of the axes. + // + // Notes: We opt to read as much in order as possible and leave + // transpositions as they are (contrary to our Metal backend). + + // Simple row reduce means that we have 1 axis that we are reducing over and + // it has stride 1. + if (plan.shape.size() == 1) { + row_reduce_simple(encoder, in, out, reduce_type, axes, plan); + return; + } + + // Make the args struct to help route to the best kernel cu::RowReduceArgs args(in, plan, axes); - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - using InType = cuda_type_t; - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using OutType = cu::ReduceResult::type; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - constexpr size_t N_READS = 4; - dim3 out_dims = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims, num_blocks; - auto kernel = - cu::row_reduce_small; - if (args.row_size <= 64) { - if ((args.non_row_reductions < 32 && args.row_size <= 8) || - (args.non_row_reductions <= 8)) { - block_dims.x = std::min(out_dims.x, 1024u); - num_blocks.x = cuda::ceil_div(out_dims.x, block_dims.x); - num_blocks.y = out_dims.y; - } else { - block_dims.x = WARP_SIZE; - num_blocks.y = out_dims.x; - num_blocks.z = out_dims.y; - kernel = - cu::row_reduce_small_warp; - } - } else { - size_t num_threads = cuda::ceil_div(args.row_size, N_READS); - num_threads = cuda::ceil_div(num_threads, WARP_SIZE) * WARP_SIZE; - MLX_SWITCH_BLOCK_DIM(num_threads, BLOCK_DIM_X, { - num_blocks.y = out_dims.x; - num_blocks.z = out_dims.y; - block_dims.x = BLOCK_DIM_X; - kernel = cu::row_reduce_looped< - InType, - OutType, - OP, - NDIM, - BLOCK_DIM_X, - N_READS>; - }); - } - kernel<<>>( - in.data(), out.data(), out.size(), args); - }); - }); - }); - }); + // Fallback row reduce + row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args)); } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu deleted file mode 100644 index 114d71809..000000000 --- a/mlx/backend/cuda/reduce/segmented_reduce.cu +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright © 2025 Apple Inc. 
- -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/cast_op.cuh" -#include "mlx/backend/cuda/reduce/reduce.cuh" - -#include -#include -#include - -namespace mlx::core { - -template -void cub_all_reduce(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceReduce::Reduce(temp.data(), size, args...)); -} - -template -void cub_segmented_reduce(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR(cub::DeviceSegmentedReduce::Reduce(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR( - cub::DeviceSegmentedReduce::Reduce(temp.data(), size, args...)); -} - -struct MultiplyOp { - int factor; - __device__ int operator()(int i) { - return i * factor; - } -}; - -void segmented_reduce( - cu::CommandEncoder& encoder, - const array& in, - array& out, - Reduce::ReduceType reduce_type, - const std::vector& axes, - const ReductionPlan& plan) { - encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using InType = cuda_type_t; - using OutType = cu::ReduceResult::type; - auto in_iter = cu::make_cast_iterator( - thrust::device_pointer_cast(in.data())); - auto out_ptr = thrust::device_pointer_cast(out.data()); - auto init = cu::ReduceInit::value(); - - if (plan.type == ContiguousAllReduce) { - cub_all_reduce( - encoder, in_iter, out_ptr, in.data_size(), OP(), init, stream); - } else if (plan.type == ContiguousReduce) { - auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), MultiplyOp{plan.shape.back()}); - cub_segmented_reduce( - encoder, - in_iter, - out_ptr, - out.size(), - offsets, - offsets + 1, - OP(), - init, - stream); - } else { - throw std::runtime_error("Unsupported plan in segmented_reduce."); - } - }); - }); - }); -} - -} // namespace mlx::core diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index fc001ae75..652e6da19 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -51,7 +51,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) { make_cast_iterator(in), vals, axis_size, - Limits::finite_min()); + Limits::min()); prevmax = maxval; maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); // Online normalizer calculation for softmax: @@ -79,7 +79,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) { block.sync(); maxval = warp.thread_rank() < warp.meta_group_size() ? 
local_max[warp.thread_rank()] - : Limits::finite_min(); + : Limits::min(); maxval = cg::reduce(warp, maxval, max_op); normalizer = normalizer * softmax_exp(prevmax - maxval); if (warp.thread_rank() == 0) { diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index bcb95dbb7..cba642ca1 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,7 +1,6 @@ cuda_skip = { "TestArray.test_api", "TestBF16.test_arg_reduction_ops", - "TestBF16.test_reduction_ops", "TestBlas.test_complex_gemm", "TestEinsum.test_ellipses", "TestEinsum.test_opt_einsum_test_cases", @@ -13,11 +12,7 @@ cuda_skip = { "TestLayers.test_upsample", "TestOps.test_complex_ops", "TestOps.test_dynamic_slicing", - "TestOps.test_softmax", - "TestReduce.test_axis_permutation_sums", "TestReduce.test_dtypes", - "TestReduce.test_expand_sums", - "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", From 33bf1a244b8cd0a58c9a9363c13b46c1714221d5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 29 Jun 2025 11:12:29 -0700 Subject: [PATCH 120/156] Fix module update in strict mode (#2321) * fix module update in strict mode * allow GELU to be pickled --- python/mlx/nn/layers/activations.py | 21 ++++++++++----------- python/mlx/nn/layers/base.py | 4 ++-- python/tests/test_nn.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 8eafd75d3..21994c0e6 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -546,7 +546,7 @@ class GELU(Module): See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the functional equivalents and information regarding error bounds. - + Args: approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. @@ -554,20 +554,19 @@ class GELU(Module): def __init__(self, approx="none"): super().__init__() - - if approx == "none": - self._act = gelu - elif approx == "precise" or approx == "tanh": - self._act = gelu_approx - elif approx == "fast": - self._act = gelu_fast_approx - else: + self._approx = approx + allowed = ["none", "precise", "tanh", "fast"] + if approx not in allowed: raise ValueError( - f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given" + f"The approximation should be in {allowed} but '{approx}' was given" ) def __call__(self, x): - return self._act(x) + if self._approx == "none": + return gelu(x) + elif self._approx in ["precise", "tanh"]: + return gelu_approx(x) + return gelu_fast_approx(x) @_make_activation_module(tanh) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index ce2ccb209..4a548c80d 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -404,7 +404,7 @@ class Module(dict): dst[k] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) - elif strict: + elif strict and new_value != {}: raise ValueError( f"Received invalid type: {type(new_value).__name__}." ) @@ -420,7 +420,7 @@ class Module(dict): dst[i] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) - elif strict: + elif strict and new_value != {}: raise ValueError( f"Received invalid type: {type(new_value).__name__}." 
) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7753224b3..53bcb3141 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -264,6 +264,16 @@ class TestBase(mlx_tests.MLXTestCase): m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) self.assertEqual(m.layers[1].weight.shape, (4, 3)) + # Using leaf_modules in the update should always work + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)] + self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0} + + m = MyModel() + m.update_modules(m.leaf_modules()) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): From 3d5e17e507b77ab08c6b04150ac77a51a350b2ce Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 1 Jul 2025 01:33:44 -0700 Subject: [PATCH 121/156] MLX_SWITCH macros to templates (#2320) --- mlx/backend/cuda/arg_reduce.cu | 50 ++-- mlx/backend/cuda/binary.cu | 94 ++++--- mlx/backend/cuda/binary_two.cu | 98 ++++--- mlx/backend/cuda/copy/copy.cuh | 9 - mlx/backend/cuda/copy/copy_contiguous.cu | 30 +- mlx/backend/cuda/copy/copy_general.cu | 78 +++--- mlx/backend/cuda/copy/copy_general_dynamic.cu | 81 +++--- mlx/backend/cuda/copy/copy_general_input.cu | 66 +++-- mlx/backend/cuda/kernel_utils.cuh | 86 +++--- mlx/backend/cuda/layer_norm.cu | 66 +++-- mlx/backend/cuda/logsumexp.cu | 15 +- mlx/backend/cuda/primitives.cu | 3 +- mlx/backend/cuda/reduce/all_reduce.cu | 18 +- mlx/backend/cuda/reduce/col_reduce.cu | 14 +- mlx/backend/cuda/reduce/init_reduce.cu | 9 +- mlx/backend/cuda/reduce/reduce.cuh | 62 ++--- mlx/backend/cuda/reduce/row_reduce.cu | 32 ++- mlx/backend/cuda/rms_norm.cu | 63 +++-- mlx/backend/cuda/rope.cu | 17 +- mlx/backend/cuda/softmax.cu | 21 +- mlx/backend/cuda/sort.cu | 14 +- mlx/backend/cuda/ternary.cu | 95 ++++--- mlx/backend/cuda/unary.cu | 6 +- mlx/backend/cuda/utils.cpp | 46 ++- mlx/dtype_utils.cpp | 40 ++- mlx/dtype_utils.h | 263 ++++++------------ mlx/utils.cpp | 9 +- 27 files changed, 693 insertions(+), 692 deletions(-) diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index c8a5a962a..90f8561c1 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -152,35 +152,29 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, { - using InType = cuda_type_t; + dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { + using T = cuda_type_t; constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims{BLOCK_DIM, 1, 1}; - auto kernel = &cu::arg_reduce_general< - InType, - cu::ArgMax, - BLOCK_DIM, - N_READS>; - if (reduce_type_ == ArgReduce::ArgMin) { - kernel = &cu::arg_reduce_general< - InType, - cu::ArgMin, - BLOCK_DIM, - N_READS>; - } - kernel<<>>( - in.data(), - out.data(), - out.size(), - const_param(shape), - const_param(in_strides), - const_param(out_strides), - ndim, - axis_stride, - axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + auto kernel = + cu::arg_reduce_general, block_dim(), N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = cu:: + arg_reduce_general, block_dim(), N_READS>; + } + 
kernel<<>>( + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 9c437cde9..8e476d30f 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -140,54 +140,64 @@ void binary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_binary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; auto bopt = get_binary_op_type(a, b); if (bopt == BinaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - bool large = a.data_size() > INT32_MAX || - b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = - &cu::binary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } }); - } else { - auto kernel = cu::binary_g; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { kernel = cu::binary_sv; @@ -197,7 +207,7 @@ void binary_op_gpu_inplace( kernel = cu::binary_vv; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu 
index 074c947da..0a68e5f1d 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -138,57 +138,67 @@ void binary_op_gpu_inplace( encoder.set_output_array(out_a); encoder.set_output_array(out_b); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { - MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, { + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_binary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; auto bopt = get_binary_op_type(a, b); if (bopt == BinaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, out_a); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - bool large = a.data_size() > INT32_MAX || - b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = - cu::binary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } }); - } else { - auto kernel = cu::binary_g; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { kernel = cu::binary_sv; @@ -202,7 +212,7 @@ void binary_op_gpu_inplace( out_a.data_size(), out_a.shape(), out_a.strides(), - LARGE); + large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index 789826507..e80fdec8c 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -10,15 +10,6 @@ namespace mlx::core { -#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) 
\ - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ - using InType = cuda_type_t; \ - using OutType = cuda_type_t; \ - __VA_ARGS__; \ - }); \ - }) - void copy_contiguous( cu::CommandEncoder& encoder, CopyType ctype, diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 5f4c9ca8f..15858ded0 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -36,19 +36,23 @@ void copy_contiguous( int64_t in_offset, int64_t out_offset) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; - auto kernel = cu::copy_s; - if (ctype == CopyType::Vector) { - kernel = cu::copy_v; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); - kernel<<>>( - in.data() + in_offset, - out.data() + out_offset, - out.data_size()); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel<<>>( + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 2dc08c60a..b2703e4bf 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -56,42 +56,48 @@ void copy_general( const Strides& strides_in, const Strides& strides_out) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg; - auto [num_blocks, block_dims] = - get_launch_args(kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) 
{ + auto kernel = + cu::copy_gg_nd; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 2e1cf4fba..68ad005d2 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -62,41 +62,52 @@ void copy_general_dynamic( const array& dynamic_offset_in, const array& dynamic_offset_out) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_dynamic_nd; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, - dynamic_offset_in.data(), - dynamic_offset_out.data()); - } + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::copy_gg_dynamic_nd< + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + dynamic_offset_in.data(), + dynamic_offset_out.data()); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index a3bb37e53..d83ba0854 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -51,35 +51,43 @@ void copy_general_input( const Shape& shape, const Strides& strides_in) { 
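// The copy and binary kernel rewrites in this patch all rely on the same
// idiom: a runtime value (a bool, a small int, a dtype) is wrapped in a tag
// type such as std::integral_constant and handed to a generic lambda, so the
// lambda body can use it as a compile-time constant, e.g. to pick an index
// type or a kernel template instantiation. A minimal self-contained sketch of
// that idiom follows; dispatch_bool mirrors the helper added to
// kernel_utils.cuh below, while select_index_type and its lambda are made up
// for illustration and are not part of the MLX API.
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

template <typename F>
void select_index_type(uint64_t n, F&& f) {
  dispatch_bool(n > uint64_t(INT32_MAX), [&](auto large) {
    // large() is a constant expression, so it can select a type.
    using IdxT = std::conditional_t<large(), int64_t, int32_t>;
    f(IdxT{});
  });
}

int main() {
  select_index_type(uint64_t(1) << 20, [](auto idx) {
    // Small sizes resolve to a 32-bit index type.
    std::printf("index type is %zu bytes\n", sizeof(idx));
  });
}
// The generated instantiations match what the old MLX_SWITCH_BOOL and
// MLX_SWITCH_1_2_3 macros produced; only the spelling changes from a macro
// body to a generic-lambda body.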
encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_g_nd; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); }); }); }); diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index b1fe875bd..b0058b618 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -6,6 +6,8 @@ #pragma once +#include + #include "mlx/array.h" #include "mlx/backend/cuda/device/utils.cuh" @@ -17,60 +19,46 @@ namespace mlx::core { -// Convert a number between 1~3 to constexpr. -#define MLX_SWITCH_1_2_3(N, NDIM, ...) \ - switch (N) { \ - case 1: { \ - constexpr int NDIM = 1; \ - __VA_ARGS__; \ - break; \ - } \ - case 2: { \ - constexpr int NDIM = 2; \ - __VA_ARGS__; \ - break; \ - } \ - case 3: { \ - constexpr int NDIM = 3; \ - __VA_ARGS__; \ - break; \ - } \ +template +void dispatch_1_2_3(int n, F&& f) { + switch (n) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; } +} -// Like MLX_SWITCH_ALL_TYPES but for booleans. -#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \ - if (BOOL) { \ - constexpr bool BOOL_ALIAS = true; \ - __VA_ARGS__; \ - } else { \ - constexpr bool BOOL_ALIAS = false; \ - __VA_ARGS__; \ +template +void dispatch_bool(bool v, F&& f) { + if (v) { + f(std::true_type{}); + } else { + f(std::false_type{}); } +} -// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2. -#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) 
\ - { \ - uint32_t _num_threads = NUM_THREADS; \ - if (_num_threads <= WARP_SIZE) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 2) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 4) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 8) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 16) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \ - __VA_ARGS__; \ - } else { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \ - __VA_ARGS__; \ - } \ +template +void dispatch_block_dim(int threads, F&& f) { + if (threads <= WARP_SIZE) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 2) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 4) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 8) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 16) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } +} // Maps CPU types to CUDA types. template diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index c71795fad..23f0b168f 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -259,21 +259,22 @@ void LayerNorm::eval_gpu( encoder.set_input_array(b); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::layer_norm; - kernel<<>>( - x.data(), - w.data(), - b.data(), - out.data(), - eps_, - axis_size, - w_stride, - b_stride); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + kernel<<>>( + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); + }); }); }); } @@ -357,22 +358,27 @@ void LayerNormVJP::eval_gpu( encoder.set_output_array(gx); encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - MLX_SWITCH_BOOL(has_w, HAS_W, { - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::layer_norm_vjp; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant(), + block_dim(), + N_READS>; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); }); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index f57f82ea8..5d6bf437d 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -144,14 +144,15 @@ void LogSumExp::eval_gpu(const std::vector& 
inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { constexpr int N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::logsumexp; - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + kernel<<>>( + in.data(), out.data(), axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index e32befc9c..715e5a232 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_output_array(out); encoder.launch_kernel([&, this](cudaStream_t stream) { - MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); using OutType = cuda_type_t; CTYPE step = static_cast(start_ + step_) - static_cast(start_); diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 5a7c28041..a6ccd5ae9 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -111,10 +111,11 @@ void all_reduce( encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(dt, CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; kernel<<>>( static_cast(indata), @@ -135,10 +136,11 @@ void all_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(dt, CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; kernel<<>>( static_cast(indata), out.data(), block_step, insize); diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 192a9b3e8..78f6b93bc 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -215,11 +215,12 @@ void col_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T 
= cuda_type_t; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) T* indata = const_cast(in.data()); @@ -229,7 +230,8 @@ void col_reduce_looped( constexpr int BN = 32; dim3 grid = output_grid_for_col_reduce(out, args, BN); int blocks = BM * BN / N_READS; - auto kernel = cu::col_reduce_looped; + auto kernel = + cu::col_reduce_looped; kernel<<>>(indata, out.data(), args); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 50fe109c4..296a4e611 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -33,10 +33,11 @@ void init_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::init_reduce; dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index a7262bcc2..d0eb3f5c5 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +#include + #include "mlx/backend/common/reduce.h" #include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" @@ -9,43 +11,35 @@ namespace mlx::core { -// Dispatch dynamic ndim to constexpr. -// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. -#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ - if (ndim == 1) { \ - constexpr uint32_t NDIM = 1; \ - __VA_ARGS__; \ - } else if (ndim == 2) { \ - constexpr uint32_t NDIM = 2; \ - __VA_ARGS__; \ - } else { \ - constexpr uint32_t NDIM = 5; \ - __VA_ARGS__; \ +template +void dispatch_reduce_ndim(int ndim, F&& f) { + if (ndim == 1) { + f(std::integral_constant{}); + } else if (ndim == 2) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } +} -// Dispatch reduce ops to constexpr. -#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) 
\ - if (REDUCE == Reduce::ReduceType::And) { \ - using OP = cu::And; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Or) { \ - using OP = cu::Or; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Sum) { \ - using OP = cu::Sum; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Prod) { \ - using OP = cu::Prod; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Max) { \ - using OP = cu::Max; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Min) { \ - using OP = cu::Min; \ - __VA_ARGS__; \ - } else { \ - throw std::invalid_argument("Unknown reduce type."); \ +template +void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { + if (reduce_type == Reduce::ReduceType::And) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Or) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Sum) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Prod) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Max) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Min) { + f(type_identity{}); + } else { + throw std::invalid_argument("Unknown reduce type."); } +} void all_reduce( cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 6a8a35311..4578dbad0 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -246,10 +246,11 @@ void row_reduce_simple( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) T* indata = const_cast(in.data()); @@ -293,10 +294,11 @@ void row_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. 
(sigh) T* indata = const_cast(in.data()); @@ -311,10 +313,16 @@ void row_reduce_looped( // Pick the kernel auto kernel = cu::row_reduce_looped; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - MLX_SWITCH_BLOCK_DIM(threads, THREADS, { - kernel = cu::row_reduce_looped; - block.x = THREADS; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim(), + threads_constant(), + N_READS>; + block.x = threads_constant(); }); }); diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 3c521b90d..7b87f2947 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -225,19 +225,20 @@ void RMSNorm::eval_gpu( encoder.set_input_array(w); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::rms_norm; - kernel<<>>( - x.data(), - w.data(), - out.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + kernel<<>>( + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); + }); }); }); } @@ -311,22 +312,28 @@ void RMSNormVJP::eval_gpu( encoder.set_output_array(gx); encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - MLX_SWITCH_BOOL(has_w, HAS_W, { - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::rms_norm_vjp; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant(), + block_dim(), + N_READS>; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); }); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 1d8307811..a7d7b27ce 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -310,12 +310,12 @@ void RoPE::eval_gpu( encoder.set_input_array(offset); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, { - using DataType = cuda_type_t; - MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { - MLX_SWITCH_BOOL(forward_, FORWARD, { + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; if (single && !with_freqs) { - auto kernel = cu::rope_single; + auto kernel = cu::rope_single; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -327,7 +327,8 
@@ void RoPE::eval_gpu( mat_size, dims); } else if (single) { - auto kernel = cu::rope_single_freqs; + auto kernel = + cu::rope_single_freqs; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -340,7 +341,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else if (with_freqs) { - auto kernel = cu::rope_freqs; + auto kernel = cu::rope_freqs; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; @@ -358,7 +359,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else { - auto kernel = cu::rope; + auto kernel = cu::rope; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 652e6da19..af9ddf214 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -142,17 +142,18 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { constexpr int N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::softmax; - if (precise) { - kernel = cu::softmax; - } - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + kernel<<>>( + in.data(), out.data(), axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 5cbffc0f4..2c5599bed 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -76,6 +76,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { temp.data(), size, args...)); } +struct OffsetTransform { + int nsort; + + int __device__ operator()(int i) { + return i * nsort; + } +}; + void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { array out = out_; auto& encoder = cu::get_command_encoder(s); @@ -106,12 +114,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { using Type = cuda_type_t; auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - [nsort] __device__(int i) { return i * nsort; }); + thrust::make_counting_iterator(0), OffsetTransform{nsort}); if (argsort) { // Indices in the sorted dimension. 
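// In the sort.cu hunk above, the extended __device__ lambda that generated
// segment offsets is replaced by the plain OffsetTransform functor. Fed to
// thrust::make_transform_iterator over a counting iterator, it lazily
// enumerates 0, nsort, 2*nsort, ..., which the segmented sort uses as row
// boundaries. A rough standalone illustration follows; the functor is marked
// __host__ __device__ here only so the offsets can be printed from host code
// (the one in the patch is __device__ only), and this is not the exact MLX
// call site.
#include <cstdio>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

struct OffsetTransform {
  int nsort;
  __host__ __device__ int operator()(int i) const {
    return i * nsort;
  }
};

int main() {
  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0), OffsetTransform{8});
  for (int i = 0; i < 4; ++i) {
    std::printf("segment %d starts at %d\n", i, offsets[i]); // 0, 8, 16, 24
  }
}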
array indices( diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index e33af3c80..1d6535100 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -92,58 +92,63 @@ void ternary_op_gpu_inplace( encoder.set_input_array(c); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, { - using DType = cuda_type_t; + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; auto topt = get_ternary_op_type(a, b, c); if (topt == TernaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - auto& c_strides = strides[2]; - bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - c.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::ternary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } }); - } else { - auto kernel = cu::ternary_g; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.data_size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::ternary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index e45144eda..4f9bac29f 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -79,8 +79,10 @@ void unary_op_gpu_inplace( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { - 
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 4a3d8be30..35731f6eb 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -25,22 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) { } const char* dtype_to_cuda_type(const Dtype& dtype) { - if (dtype == float16) { - return "__half"; + switch (dtype) { + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + case float16: + return "__half"; + case bfloat16: + return "__nv_bfloat16"; + case float32: + return "float"; + case float64: + return "double"; + case complex64: + return "cuComplex"; + default: + return "unknown"; } - if (dtype == bfloat16) { - return "__nv_bfloat16"; - } - if (dtype == complex64) { - return "cuComplex"; - } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (dtype == DTYPE) { \ - return #CPP_TYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return nullptr; } } // namespace mlx::core diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp index a4448536d..9f10e6a9a 100644 --- a/mlx/dtype_utils.cpp +++ b/mlx/dtype_utils.cpp @@ -5,16 +5,38 @@ namespace mlx::core { const char* dtype_to_string(Dtype arg) { - if (arg == bool_) { - return "bool"; + switch (arg) { + case bool_: + return "bool"; + case int8: + return "int8"; + case int16: + return "int16"; + case int32: + return "int32"; + case int64: + return "int64"; + case uint8: + return "uint8"; + case uint16: + return "uint16"; + case uint32: + return "uint32"; + case uint64: + return "uint64"; + case float16: + return "float16"; + case bfloat16: + return "bfloat16"; + case float32: + return "float32"; + case float64: + return "float64"; + case complex64: + return "complex64"; + default: + return "unknown"; } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (DTYPE == arg) { \ - return #DTYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return "(unknown)"; } } // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h index 55de890f2..27fe432f6 100644 --- a/mlx/dtype_utils.h +++ b/mlx/dtype_utils.h @@ -1,207 +1,106 @@ // Copyright © 2025 Apple Inc. -// Copyright © Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the BSD-style license found in -// https://github.com/pytorch/executorch/blob/main/LICENSE -// -// Forked from -// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h #pragma once -#include "mlx/dtype.h" +#include -#include +#include "mlx/dtype.h" +#include "mlx/utils.h" namespace mlx::core { // Return string representation of dtype. const char* dtype_to_string(Dtype arg); -// Macros that iterate across different subsets of Dtypes. 
-// -// For all of these macros, the final `_` parameter is the name of another macro -// that takes two parameters: the name of a C type, and the name of the -// corresponding Dtype enumerator. -// -// Note that these macros should use fully-qualified namespaces (starting with -// `::`) to ensure that they can be called safely in any arbitrary namespace. -#define MLX_FORALL_INT_TYPES(_) \ - _(uint8_t, uint8) \ - _(uint16_t, uint16) \ - _(uint32_t, uint32) \ - _(uint64_t, uint64) \ - _(int8_t, int8) \ - _(int16_t, int16) \ - _(int32_t, int32) \ - _(int64_t, int64) +#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \ + case DTYPE: \ + f(type_identity{}); \ + break -#define MLX_FORALL_FLOAT_TYPES(_) \ - _(float16_t, float16) \ - _(float, float32) \ - _(double, float64) \ - _(bfloat16_t, bfloat16) +#define MLX_INTERNAL_DTYPE_SWITCH_INTS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t) -// Calls the provided macro on every Dtype, providing the C type and the -// Dtype name to each call. -// -// @param _ A macro that takes two parameters: the name of a C type, and the -// name of the corresponding Dtype enumerator. -#define MLX_FORALL_DTYPES(_) \ - MLX_FORALL_INT_TYPES(_) \ - MLX_FORALL_FLOAT_TYPES(_) \ - _(bool, bool_) \ - _(complex64_t, complex64) +#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double) -// Maps Dtypes to C++ types. -template -struct DtypeToCppType; - -#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \ - template <> \ - struct DtypeToCppType { \ - using type = CPP_TYPE; \ - }; - -MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType) - -#undef SPECIALIZE_DtypeToCppType - -// Maps C++ types to Dtypes. +// This already exists in C++20 but in C++20 we can also just use templated +// lambdas which will make this so much nicer. template -struct CppTypeToDtype; +struct type_identity { + using type = T; +}; -#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ - template <> \ - struct CppTypeToDtype \ - : std::integral_constant {}; +#define MLX_GET_TYPE(x) typename decltype(x)::type +#define MLX_GET_VALUE(x) decltype(x)::value -MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) - -#undef SPECIALIZE_CppTypeToDtype - -// Helper macros for switch case macros (see below) -// -// These macros are not meant to be used directly. They provide an easy way to -// generate a switch statement that can handle subsets of Dtypes supported. - -#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ - case enum_type: { \ - using CTYPE_ALIAS = ::mlx::core::DtypeToCppType::type; \ - __VA_ARGS__; \ - break; \ +template +void dispatch_all_types(Dtype dt, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); } +} -#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) 
\ - switch (TYPE) { \ - __VA_ARGS__ \ - default: \ - throw std::invalid_argument(fmt::format( \ - "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ +template +void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + default: + std::ostringstream msg; + msg << tag << " Only integer types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); } +} -#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__) +template +void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only float types supported but " << dt << " was provided"; + throw std::invalid_argument(msg.str()); + } +} -#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) +template +void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only integer and float types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} -#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) - -// Switch case macros -// -// These macros provide an easy way to generate switch statements that apply a -// common lambda function to subsets of Dtypes supported by MLX. -// The lambda function can type specialize to the ctype associated with the -// Dtype being handled through an alias passed as the CTYPE_ALIAS argument. 
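// The usage examples in the macro documentation deleted here carry over
// directly to the function dispatchers added above: the Dtype switch becomes
// a call to dispatch_all_types with a generic lambda, and the concrete C++
// type is recovered from the tag with MLX_GET_TYPE. A small sketch of the new
// calling convention, mirroring the call sites elsewhere in this patch;
// item_size is a made-up example helper, not an MLX function.
#include <cstddef>

#include "mlx/dtype_utils.h"

namespace mlx::core {

size_t item_size(Dtype dt) {
  size_t size = 0;
  dispatch_all_types(dt, [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    size = sizeof(CTYPE);
  });
  return size;
}

} // namespace mlx::core
// Nested dispatches compose the same way: binary.cu above dispatches the
// input and output dtypes with two nested generic lambdas.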
-// -// Arguments: -// - ADDITIONAL: Additional Dtype case to add -// - TYPE: The Dtype to handle through the switch statement -// - NAME: A name for this operation which will be used in error messages -// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype. -// - ...: A statement to be applied to each Dtype case -// -// An example usage is: -// -// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, { -// output.data[0] = input.data[0]; -// }); -// -// Note that these can be nested as well: -// -// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, { -// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, { -// output.data[0] = input.data[0]; -// }); -// }); -// -// These macros are adapted from Dispatch.h in the ATen library. The primary -// difference is that the CTYPE_ALIAS argument is exposed to users, which is -// used to alias the ctype associated with the Dtype that is being handled. - -#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \ - switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) } - -#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) +template +void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only real numbers supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} } // namespace mlx::core diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 61b9da3a2..e53a7a97f 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -253,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { std::ostream& operator<<(std::ostream& os, array a) { a.eval(); - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array(os, a)); + dispatch_all_types(a.dtype(), [&](auto type_tag) { + print_array(os, a); + }); return os; } @@ -321,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) { } iinfo::iinfo(Dtype dtype) : dtype(dtype) { - MLX_SWITCH_INT_TYPES_CHECKED( - dtype, "[iinfo]", CTYPE, set_iinfo_limits(min, max)); + dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) { + set_iinfo_limits(min, max); + }); } } // namespace mlx::core From dd4f53db63020ede8b8abf6eec91a35e92dc73c1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 1 Jul 2025 07:30:00 -0700 Subject: [PATCH 122/156] use fp32 for testing, add more complex ops (#2322) --- mlx/backend/cuda/device/unary_ops.cuh | 54 +++++++++++++++++++++++---- mlx/backend/cuda/layer_norm.cu | 2 - mlx/backend/cuda/rms_norm.cu | 1 - mlx/backend/cuda/unary.cu | 35 ++++++++--------- python/tests/cuda_skip.py | 12 +----- python/tests/mlx_tests.py | 4 ++ 6 files changed, 68 insertions(+), 40 deletions(-) diff --git a/mlx/backend/cuda/device/unary_ops.cuh 
b/mlx/backend/cuda/device/unary_ops.cuh index efa9133b1..18d769c2a 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -27,6 +27,8 @@ struct ArcCos { __device__ T operator()(T x) { return acos(x); } + + __device__ cuComplex operator()(cuComplex x); }; struct ArcCosh { @@ -41,6 +43,8 @@ struct ArcSin { __device__ T operator()(T x) { return asin(x); } + + __device__ cuComplex operator()(cuComplex x); }; struct ArcSinh { @@ -55,6 +59,8 @@ struct ArcTan { __device__ T operator()(T x) { return atan(x); } + + __device__ cuComplex operator()(cuComplex x); }; struct ArcTanh { @@ -261,13 +267,6 @@ struct Round { } }; -struct Rsqrt { - template - __device__ T operator()(T x) { - return rsqrt(x); - } -}; - struct Sigmoid { template __device__ T operator()(T x) { @@ -333,6 +332,29 @@ struct Sqrt { __device__ T operator()(T x) { return sqrt(x); } + + __device__ cuComplex operator()(cuComplex x) { + auto xr = cuCrealf(x); + auto xi = cuCimagf(x); + if (xr == 0.0f && xi == 0.0f) { + return {0.0f, 0.0f}; + } + auto r = cuCrealf(Abs{}(x)); + auto a = sqrt((r + xr) / 2.0f); + auto b_abs = sqrt((r - xr) / 2.0f); + auto b = copysign(b_abs, xi); + return {a, b}; + } +}; + +struct Rsqrt { + template + __device__ T operator()(T x) { + return rsqrt(x); + } + __device__ cuComplex operator()(cuComplex x) { + return 1.0f / Sqrt{}(x); + } }; struct Tan { @@ -365,4 +387,22 @@ struct Tanh { } }; +__device__ cuComplex ArcCos::operator()(cuComplex x) { + auto i = cuComplex{0.0, 1.0}; + auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); + return {cuCimagf(y), -cuCrealf(y)}; +}; + +__device__ cuComplex ArcSin::operator()(cuComplex x) { + auto i = cuComplex{0.0f, 1.0f}; + auto y = Log{}(i * x + Sqrt{}(1.0f - x * x)); + return {cuCimagf(y), -cuCrealf(y)}; +}; + +__device__ cuComplex ArcTan::operator()(cuComplex x) { + auto i = cuComplex{0.0f, 1.0f}; + auto ix = i * x; + return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix)); +}; + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 23f0b168f..852cf43af 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -342,8 +342,6 @@ void LayerNormVJP::eval_gpu( encoder.add_temporary(gw_temp); } } - gw.set_data(allocator::malloc(gw.nbytes())); - gb.set_data(allocator::malloc(gb.nbytes())); // Finish with the gradient for b in case we had a b. 
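// A note on the complex overloads added to unary_ops.cuh above: they follow
// the standard principal-branch identities, written here in LaTeX form to
// make the code easier to check (this records what the code computes, nothing
// beyond it):
//
//   \arccos z = -i \log\left(z + i\sqrt{1 - z^2}\right)
//   \arcsin z = -i \log\left(iz + \sqrt{1 - z^2}\right)
//   \arctan z = \frac{1}{2i} \log\frac{1 + iz}{1 - iz}
//   \sqrt{x + iy} = \sqrt{\frac{r + x}{2}} + i\,\operatorname{sign}(y)\sqrt{\frac{r - x}{2}}, \quad r = |x + iy|
//
// and Rsqrt(z) = 1 / \sqrt{z}. The multiplication by -i appears in the code
// as returning {imag(y), -real(y)} for y = Log(...).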
if (gb.ndim() == 1 && gb.size() == axis_size) { diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 7b87f2947..7f5f9630d 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -304,7 +304,6 @@ void RMSNormVJP::eval_gpu( encoder.add_temporary(gw_temp); } } - gw.set_data(allocator::malloc(gw.nbytes())); encoder.set_input_array(x); encoder.set_input_array(w); diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 4f9bac29f..74251d1f6 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -20,38 +20,35 @@ namespace cu { template constexpr bool supports_unary_op() { if (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v) { return std::is_same_v; } - if (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { return std::is_same_v && is_floating_v; } - if (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { - return std::is_same_v && is_inexact_v; - } if (std::is_same_v) { return std::is_same_v && std::is_integral_v && !std::is_same_v; } - if (std::is_same_v || std::is_same_v || - std::is_same_v) { + if (std::is_same_v || std::is_same_v) { return std::is_same_v && !std::is_same_v; } if (std::is_same_v) { return std::is_same_v && std::is_same_v; } - if (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { - return std::is_same_v && - (is_floating_v || std::is_same_v); + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; } if (std::is_same_v || std::is_same_v) { return std::is_same_v && std::is_same_v; diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index cba642ca1..fce92bacb 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,25 +1,15 @@ cuda_skip = { - "TestArray.test_api", - "TestBF16.test_arg_reduction_ops", - "TestBlas.test_complex_gemm", - "TestEinsum.test_ellipses", - "TestEinsum.test_opt_einsum_test_cases", "TestLoad.test_load_f8_e4m3", - "TestLayers.test_group_norm", - "TestLayers.test_pooling", "TestLayers.test_quantized_embedding", - "TestLayers.test_sin_pe", - "TestLayers.test_upsample", - "TestOps.test_complex_ops", "TestOps.test_dynamic_slicing", "TestReduce.test_dtypes", - "TestUpsample.test_torch_upsample", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI "TestBlas.test_gather_matmul", "TestBlas.test_gather_matmul_grad", # Scan NYI + "TestArray.test_api", "TestAutograd.test_cumprod_grad", "TestOps.test_scans", "TestOps.test_logcumsumexp", diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index 65bd0e873..bc197b673 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -1,6 +1,10 @@ # Copyright © 2023 Apple Inc. 
import os + +# Use regular fp32 precision for tests +os.environ["MLX_ENABLE_TF32"] = "0" + import platform import unittest from typing import Any, Callable, List, Tuple, Union From 58f38603066b589341429efa96fc77619b82979e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 1 Jul 2025 12:12:16 -0700 Subject: [PATCH 123/156] patch bump (#2324) --- mlx/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/version.h b/mlx/version.h index 530d0620d..5ad66e3c2 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 26 -#define MLX_VERSION_PATCH 1 +#define MLX_VERSION_PATCH 2 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From cfb6a244ea39006febbc4f551518b53740f46b69 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 1 Jul 2025 21:27:23 -0700 Subject: [PATCH 124/156] allow parameters to be deleted (#2325) --- python/mlx/nn/layers/base.py | 6 ++++++ python/tests/test_nn.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 4a548c80d..e99943834 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -114,6 +114,12 @@ class Module(dict): super(Module, self).__setattr__(key, val) self.pop(key, None) + def __delattr__(self, name): + if (val := self.get(name, None)) is not None: + del self[name] + else: + super().__delattr__(name) + def load_weights( self, file_or_weights: Union[str, List[Tuple[str, mx.array]]], diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 53bcb3141..ae3fae4da 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -274,6 +274,11 @@ class TestBase(mlx_tests.MLXTestCase): m = MyModel() m.update_modules(m.leaf_modules()) + def test_parameter_deletion(self): + m = nn.Linear(32, 32) + del m.weight + self.assertFalse(hasattr(m, "weight")) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self): From e76e9b87f00cff6a2fd18791d3cc7323a7444375 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 2 Jul 2025 22:04:38 +0900 Subject: [PATCH 125/156] Fix compilation error from integral_constant (#2326) --- mlx/backend/cuda/layer_norm.cu | 2 +- mlx/backend/cuda/reduce/row_reduce.cu | 6 +++--- mlx/backend/cuda/rms_norm.cu | 2 +- mlx/backend/cuda/rope.cu | 12 +++++++----- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 852cf43af..9a9fbcb37 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu( using DataType = cuda_type_t; auto kernel = cu::layer_norm_vjp< DataType, - has_w_constant(), + has_w_constant.value, block_dim(), N_READS>; kernel<<>>( diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 4578dbad0..deb4a2f91 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -319,10 +319,10 @@ void row_reduce_looped( T, U, OP, - reduce_ndim(), - threads_constant(), + reduce_ndim.value, + threads_constant.value, N_READS>; - block.x = threads_constant(); + block.x = threads_constant.value; }); }); diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 7f5f9630d..fc8f4f490 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -320,7 +320,7 @@ void RMSNormVJP::eval_gpu( constexpr int N_READS = 4; auto kernel = cu::rms_norm_vjp< 
DataType, - has_w_constant(), + has_w_constant.value, block_dim(), N_READS>; kernel<<>>( diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index a7d7b27ce..bb9618fc4 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -315,7 +315,8 @@ void RoPE::eval_gpu( dispatch_bool(forward_, [&](auto forward) { using DataType = cuda_type_t; if (single && !with_freqs) { - auto kernel = cu::rope_single; + auto kernel = + cu::rope_single; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -327,8 +328,8 @@ void RoPE::eval_gpu( mat_size, dims); } else if (single) { - auto kernel = - cu::rope_single_freqs; + auto kernel = cu:: + rope_single_freqs; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -341,7 +342,8 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else if (with_freqs) { - auto kernel = cu::rope_freqs; + auto kernel = + cu::rope_freqs; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; @@ -359,7 +361,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else { - auto kernel = cu::rope; + auto kernel = cu::rope; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; From ec0d5db67b44916ee7706ef2ac624d642510bdac Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 2 Jul 2025 15:59:13 -0700 Subject: [PATCH 126/156] [CUDA] Switch to CUDA graphs (#2317) * cuda graph prototype fix signal bug + start to add dependencies capture more capture more ops remaining ops fix reduce and rope deps add concurrent context try update, but not working cosistent topology order use node api use node api directly to reduce overhead fix bug use kernels in unary cache graph format fix synchronization format * comment --- mlx/backend/common/matmul.h | 23 +- mlx/backend/cuda/arg_reduce.cu | 47 ++- mlx/backend/cuda/binary.cu | 154 +++++----- mlx/backend/cuda/binary_two.cu | 167 +++++----- mlx/backend/cuda/compiled.cpp | 21 +- mlx/backend/cuda/copy/copy_contiguous.cu | 37 +-- mlx/backend/cuda/copy/copy_general.cu | 82 ++--- mlx/backend/cuda/copy/copy_general_dynamic.cu | 83 ++--- mlx/backend/cuda/copy/copy_general_input.cu | 72 ++--- mlx/backend/cuda/device.cpp | 289 ++++++++++++++---- mlx/backend/cuda/device.h | 170 ++++++----- mlx/backend/cuda/eval.cpp | 28 +- mlx/backend/cuda/event.cu | 20 +- mlx/backend/cuda/indexing.cpp | 154 +++++----- mlx/backend/cuda/jit_module.cpp | 65 +--- mlx/backend/cuda/jit_module.h | 76 +++-- mlx/backend/cuda/kernel_utils.cuh | 9 +- mlx/backend/cuda/layer_norm.cu | 106 ++++--- mlx/backend/cuda/logsumexp.cu | 22 +- mlx/backend/cuda/matmul.cpp | 73 ++--- mlx/backend/cuda/primitives.cu | 30 +- mlx/backend/cuda/random.cu | 61 ++-- mlx/backend/cuda/reduce/all_reduce.cu | 51 ++-- mlx/backend/cuda/reduce/col_reduce.cu | 36 ++- mlx/backend/cuda/reduce/init_reduce.cu | 22 +- mlx/backend/cuda/reduce/row_reduce.cu | 108 ++++--- mlx/backend/cuda/rms_norm.cu | 93 +++--- mlx/backend/cuda/rope.cu | 151 ++++----- mlx/backend/cuda/softmax.cu | 28 +- mlx/backend/cuda/sort.cu | 157 +++++----- mlx/backend/cuda/ternary.cu | 127 ++++---- mlx/backend/cuda/unary.cu | 97 ++++-- mlx/backend/cuda/utils.cpp | 8 + mlx/backend/cuda/utils.h | 2 + mlx/linalg.cpp | 2 +- python/tests/test_load.py | 2 + 36 files changed, 1461 insertions(+), 1212 deletions(-) diff --git a/mlx/backend/common/matmul.h b/mlx/backend/common/matmul.h index 
2e0261a30..2faf256d1 100644 --- a/mlx/backend/common/matmul.h +++ b/mlx/backend/common/matmul.h @@ -12,16 +12,11 @@ namespace mlx::core { inline std::tuple collapse_batches( const array& a, const array& b) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - if (A_bshape != B_bshape) { - std::ostringstream msg; - msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << "."; - throw std::runtime_error(msg.str()); + if (a.ndim() == 2) { + return {{1}, {0}, {0}}; } + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; @@ -42,17 +37,11 @@ inline std::tuple collapse_batches( inline std::tuple collapse_batches(const array& a, const array& b, const array& c) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; - if (A_bshape != B_bshape || A_bshape != C_bshape) { - std::ostringstream msg; - msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << ", B " << c.shape() << "."; - throw std::runtime_error(msg.str()); + if (a.ndim() == 2) { + return {{1}, {0}, {0}, {0}}; } + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 90f8561c1..ad942a406 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -151,30 +151,29 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { - using T = cuda_type_t; - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - auto kernel = - cu::arg_reduce_general, block_dim(), N_READS>; - if (reduce_type_ == ArgReduce::ArgMin) { - kernel = cu:: - arg_reduce_general, block_dim(), N_READS>; - } - kernel<<>>( - in.data(), - out.data(), - out.size(), - const_param(shape), - const_param(in_strides), - const_param(out_strides), - ndim, - axis_stride, - axis_size); - }); + dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { + using T = cuda_type_t; + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + auto kernel = + cu::arg_reduce_general, block_dim(), N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = cu::arg_reduce_general, block_dim(), N_READS>; + } + encoder.add_kernel_node( + kernel, + num_blocks, + block_dim(), + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); }); }); } diff --git a/mlx/backend/cuda/binary.cu 
b/mlx/backend/cuda/binary.cu index 8e476d30f..d9b9fd8af 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -139,90 +139,92 @@ void binary_op_gpu_inplace( encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(a.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); - if (bopt == BinaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - out.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = - collapse_contiguous_dims(a, b, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); - }); - } else { - auto kernel = cu::binary_g; + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), out.data(), out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); - } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; - if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; - } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; - } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.data_size()); - }); - } + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + 
const_param(a_strides), + const_param(b_strides), + ndim); + } + }); } else { - throw std::runtime_error(fmt::format( - "Can not do binary op {} on inputs of {} with result of {}.", - op, - dtype_to_string(a.dtype()), - dtype_to_string(out.dtype()))); + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.data_size()); + }); } - }); + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 0a68e5f1d..9582b0378 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -137,98 +137,101 @@ void binary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_output_array(out_a); encoder.set_output_array(out_b); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(a.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { - using InType = cuda_type_t; - using OutType = cuda_type_t; + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); - if (bopt == BinaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - out_a.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = - collapse_contiguous_dims(a, b, out_a); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large()); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); - }); - } else { - auto kernel = cu::binary_g; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; auto [num_blocks, block_dims] = 
get_launch_args(kernel, out_a, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), out_a.data(), out_b.data(), out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); - } else { - dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; - if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; - } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; - } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, - out_a.data_size(), - out_a.shape(), - out_a.strides(), - large()); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.data_size()); - }); - } + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); } else { - throw std::runtime_error(fmt::format( - "Can not do binary op {} on inputs of {} with result of {}.", - op, - dtype_to_string(a.dtype()), - dtype_to_string(out_a.dtype()))); + dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out_a.data_size(), + out_a.shape(), + out_a.strides(), + large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.data_size()); + }); } - }); + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out_a.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 1aa7ecb92..21257e5dd 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/graph_utils.h" #include "mlx/primitives.h" @@ -178,6 +179,7 @@ void Compiled::eval_gpu( // Whether to use large index. bool large = compiled_use_large_index(inputs, outputs, contiguous); + cu::KernelArgs args; // Put inputs. int strides_index = 1; for (size_t i = 0; i < inputs.size(); ++i) { @@ -185,26 +187,26 @@ void Compiled::eval_gpu( continue; } const auto& x = inputs[i]; - mod.append_arg(x); + args.append(x); if (!contiguous && !is_scalar(x)) { - mod.append_arg(strides_vec[strides_index++]); + args.append_ptr(strides_vec[strides_index++].data()); } } // Put outputs. compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); for (auto& x : outputs) { - mod.append_arg(x); + args.append(x); } // Put shape and size. 
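// The shape is appended as a pointer into the KernelArgs storage, while the
// element count is appended by value; its integer width presumably depends
// on `large`, so the packed argument list matches the generated kernel's
// signature.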
if (!contiguous) { - mod.append_arg(shape); + args.append_ptr(shape.data()); } if (large) { - mod.append_arg(outputs[0].data_size()); + args.append(outputs[0].data_size()); } else { - mod.append_arg(outputs[0].data_size()); + args.append(outputs[0].data_size()); } // Launch kernel. @@ -222,9 +224,10 @@ void Compiled::eval_gpu( for (const auto& out : outputs) { encoder.set_output_array(out); } - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, outputs[0], large); - }); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 15858ded0..408350129 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -35,24 +35,25 @@ void copy_contiguous( array& out, int64_t in_offset, int64_t out_offset) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - auto kernel = cu::copy_s; - if (ctype == CopyType::Vector) { - kernel = cu::copy_v; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - in.data() + in_offset, - out.data() + out_offset, - out.data_size()); - }); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index b2703e4bf..5c7f9f954 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -55,50 +55,54 @@ void copy_general( const Shape& shape, const Strides& strides_in, const Strides& strides_out) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto ndim_constant) { - auto kernel = - cu::copy_gg_nd; - auto [num_blocks, block_dims] = get_launch_args( - kernel, data_size, shape, out.strides(), large()); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); - }); - } else { // ndim >= 4 - 
auto kernel = cu::copy_gg; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + auto kernel = + cu::copy_gg_nd; auto [num_blocks, block_dims] = get_launch_args( kernel, data_size, shape, out.strides(), large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } - }); - }); + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); }); }); } diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 68ad005d2..1b643111f 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -61,54 +61,55 @@ void copy_general_dynamic( const Strides& strides_out, const array& dynamic_offset_in, const array& dynamic_offset_out) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::copy_gg_dynamic_nd< - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + copy_gg_dynamic_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, + 
const_param(shape), + const_param(strides_in), + const_param(strides_out), dynamic_offset_in.data(), dynamic_offset_out.data()); - } - }); - }); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); }); }); } diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index d83ba0854..1ac7712e6 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -50,45 +50,49 @@ void copy_general_input( int64_t offset_out, const Shape& shape, const Strides& strides_in) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = - cu::copy_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } - }); - }); + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); }); }); } diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index ba31c0e45..fff752fe5 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -2,38 +2,23 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/worker.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/utils.h" #include #include #include +#include namespace mlx::core { +// Can be tuned with MLX_MAX_OPS_PER_BUFFER +// This should be less than 255 +constexpr int default_max_nodes_per_graph = 20; + +constexpr int max_graph_cache_size = 100; + namespace cu { -DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} - -void DeviceStream::synchronize() 
{ - cudaStreamSynchronize(stream_); -} - -cudaStream_t DeviceStream::schedule_cuda_stream() { - // TODO: Return a stream that maximizes parallelism. - return stream_; -} - -cudaStream_t DeviceStream::last_cuda_stream() { - return stream_; -} - -CommandEncoder& DeviceStream::get_encoder() { - if (!encoder_) { - encoder_ = std::make_unique(*this); - } - return *encoder_; -} - Device::Device(int device) : device_(device) { CHECK_CUDA_ERROR(cudaDeviceGetAttribute( &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_)); @@ -67,49 +52,253 @@ void Device::make_current() { } } -DeviceStream& Device::get_stream(Stream s) { - auto it = streams_.find(s.index); - if (it == streams_.end()) { - it = streams_.try_emplace(s.index, *this).first; +CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); + CHECK_CUDA_ERROR(cudaStreamBeginCaptureToGraph( + enc.stream(), graph, NULL, NULL, 0, cudaStreamCaptureModeGlobal)); +} + +CommandEncoder::CaptureContext::~CaptureContext() { + CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); + size_t num_nodes; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes)); + if (num_nodes == 1) { + cudaGraphNode_t captured_node; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes)); + CUDA_KERNEL_NODE_PARAMS params; + CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms)); + cudaGraphNode_t node; + CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms)); + enc.insert_graph_dependencies(GraphNode{node, 'K'}); + } else { + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph)); + enc.insert_graph_dependencies(GraphNode{node, 'G'}); + } + CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); +} + +CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc) + : enc(enc) { + enc.in_concurrent_ = true; +} + +CommandEncoder::ConcurrentContext::~ConcurrentContext() { + enc.in_concurrent_ = false; + + // Use an empty graph node for synchronization + CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)}; + enc.empty_node_count_++; + CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0)); + + // Insert the concurrent -> empty node dependencies + for (auto& from : enc.concurrent_nodes_) { + enc.from_nodes_.push_back(from.node); + enc.to_nodes_.push_back(empty.node); + enc.graph_key_ += from.id; + enc.graph_key_ += from.node_type; + enc.graph_key_ += empty.id; + enc.graph_key_ += empty.node_type; + } + + // Insert the input -> concurrent node dependencies without updating output + // nodes + auto outputs = std::move(enc.active_outputs_); + enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_)); + + // Update output node to be the empty node + for (auto o : outputs) { + enc.node_map_.emplace(o, empty).first->second = empty; + } +} + +void CommandEncoder::insert_graph_dependencies(GraphNode node) { + if (node.node_type == 'G') { + graph_node_count_++; + } + node.id = std::to_string(node_count_++); + if (in_concurrent_) { + concurrent_nodes_.push_back(std::move(node)); + } else { + std::vector nodes; + nodes.push_back(std::move(node)); + insert_graph_dependencies(std::move(nodes)); + } +} + +void CommandEncoder::insert_graph_dependencies(std::vector nodes) { + std::vector deps; + { + // Dependencies must be added in the same order to produce a consistent + // topology + std::unordered_set set_deps; + for (auto d : active_deps_) { + if 
(auto it = node_map_.find(d); it != node_map_.end()) { + auto [_, inserted] = set_deps.insert(it->second.node); + if (inserted) { + deps.push_back(it->second); + } + } + } + } + active_deps_.clear(); + + for (auto o : active_outputs_) { + for (auto& node : nodes) { + node_map_.emplace(o, node).first->second = node; + } + } + active_outputs_.clear(); + + for (auto& from : deps) { + for (auto& to : nodes) { + from_nodes_.push_back(from.node); + to_nodes_.push_back(to.node); + graph_key_ += from.id; + graph_key_ += from.node_type; + graph_key_ += to.id; + graph_key_ += to.node_type; + } + } +} + +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; } return it->second; } -CommandEncoder::CommandEncoder(DeviceStream& s) - : device_(s.device()), stream_(s) {} +CommandEncoder::CommandEncoder(Device& d) : stream_(d) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); +} + +void clear_graphs(std::unordered_map& graphs) { + for (auto& [_, graph_exec] : graphs) { + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + } + graphs.clear(); +} + +CommandEncoder::~CommandEncoder() { + clear_graphs(graph_cache_); +} void CommandEncoder::add_completed_handler(std::function task) { worker_.add_task(std::move(task)); } -void CommandEncoder::end_encoding() { - if (!temporaries_.empty()) { - add_completed_handler([temporaries = std::move(temporaries_)]() {}); - } +void CommandEncoder::set_input_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); +} - // There is no kernel running, run completion handlers immediately. - if (!has_gpu_work_) { - worker_.consume_in_this_thread(); - return; - } - has_gpu_work_ = false; +void CommandEncoder::set_output_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); + active_outputs_.push_back(id); +} - // Put completion handlers in a batch. - worker_.end_batch(); - - // Signaling kernel completion is expensive, delay until enough batches. - // TODO: This number is arbitrarily picked, profile for a better stragety. 
- if (worker_.uncommited_batches() > 8) { +void CommandEncoder::maybe_commit() { + if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) { commit(); } } +void CommandEncoder::add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + cudaKernelNodeParams kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDim = grid_dim; + kernel_params.blockDim = block_dim; + kernel_params.kernelParams = params; + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + +void CommandEncoder::add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDimX = grid_dim.x; + kernel_params.gridDimY = grid_dim.y; + kernel_params.gridDimZ = grid_dim.z; + kernel_params.blockDimX = block_dim.x; + kernel_params.blockDimY = block_dim.y; + kernel_params.blockDimZ = block_dim.z; + kernel_params.kernelParams = params; + CUgraphNode node; + CHECK_CUDA_ERROR( + cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + void CommandEncoder::commit() { - worker_.commit(stream_.last_cuda_stream()); + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + if (node_count_ > 0) { + if (!from_nodes_.empty()) { + CHECK_CUDA_ERROR(cudaGraphAddDependencies( + graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); + } + // TODO smarter cache policy + if (graph_cache_.size() > max_graph_cache_size) { + clear_graphs(graph_cache_); + } + + graph_key_ += "."; + graph_key_ += std::to_string(node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(graph_node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(empty_node_count_); + auto [it, _] = graph_cache_.emplace(graph_key_, nullptr); + auto& graph_exec = it->second; + + if (graph_exec != NULL) { + cudaGraphExecUpdateResultInfo update_result; + cudaGraphExecUpdate(graph_exec, graph_, &update_result); + if (update_result.result != cudaGraphExecUpdateSuccess) { + cudaGetLastError(); + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + graph_exec = NULL; + } + } + if (graph_exec == NULL) { + CHECK_CUDA_ERROR( + cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); + } + CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); + + // Reset state + node_count_ = 0; + graph_node_count_ = 0; + from_nodes_.clear(); + to_nodes_.clear(); + graph_key_.clear(); + node_map_.clear(); + CHECK_CUDA_ERROR(cudaGraphDestroy(graph_)); + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); + } + + // Put completion handlers in a batch. 
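// The worker presumably signals completion on `stream_`, so these handlers
// (and the temporaries captured above) only run after the graph launched
// here has finished executing on the GPU.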
+ worker_.end_batch(); + worker_.commit(stream_); } void CommandEncoder::synchronize() { - stream().synchronize(); + cudaStreamSynchronize(stream_); auto p = std::make_shared>(); std::future f = p->get_future(); add_completed_handler([p = std::move(p)]() { p->set_value(); }); @@ -127,12 +316,8 @@ Device& device(mlx::core::Device device) { return it->second; } -DeviceStream& get_stream(Stream s) { - return device(s.device).get_stream(s); -} - CommandEncoder& get_command_encoder(Stream s) { - return get_stream(s).get_encoder(); + return device(s.device).get_command_encoder(s); } } // namespace cu diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 744f77f62..4ebdae55c 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -7,41 +7,108 @@ #include "mlx/stream.h" #include +#include #include #include namespace mlx::core::cu { -class Device; -class CommandEncoder; - -class DeviceStream { +class CommandEncoder { public: - explicit DeviceStream(Device& device); + struct CaptureContext { + CaptureContext(CommandEncoder& enc); + ~CaptureContext(); + cudaGraph_t graph; + CommandEncoder& enc; + }; + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc); + ~ConcurrentContext(); + CommandEncoder& enc; + }; - DeviceStream(const DeviceStream&) = delete; - DeviceStream& operator=(const DeviceStream&) = delete; + explicit CommandEncoder(Device& d); + ~CommandEncoder(); - // Wait until kernels in the stream complete. - void synchronize(); + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; - // Return a cuda stream for launching kernels. - cudaStream_t schedule_cuda_stream(); - - // Return the last cuda stream used. - cudaStream_t last_cuda_stream(); - - CommandEncoder& get_encoder(); - - Device& device() { - return device_; + CaptureContext capture_context() { + return CaptureContext{*this}; + } + ConcurrentContext concurrent_context() { + return ConcurrentContext{*this}; } + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void + add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... 
params) { + constexpr size_t num = sizeof...(Params); + void* ptrs[num]; + size_t i = 0; + ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( + std::forward(params)), + ...); + add_kernel_node((void*)func, grid_dim, block_dim, ptrs); + } + + void add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params); + + void + add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void maybe_commit(); + void commit(); + + CudaStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + private: - Device& device_; + struct GraphNode { + cudaGraphNode_t node; + // K = kernel + // E = empty + // G = subgraph + char node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + CudaStream stream_; - std::unique_ptr encoder_; + cudaGraph_t graph_; + Worker worker_; + char node_count_{0}; + char graph_node_count_{0}; + char empty_node_count_{0}; + bool in_concurrent_{false}; + std::vector from_nodes_; + std::vector to_nodes_; + std::string graph_key_; + std::vector concurrent_nodes_; + std::vector> temporaries_; + std::unordered_map graph_cache_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; }; class Device { @@ -55,7 +122,7 @@ class Device { // Make this device the current cuda device, required by some cuda calls. void make_current(); - DeviceStream& get_stream(Stream s); + CommandEncoder& get_command_encoder(Stream s); int cuda_device() const { return device_; @@ -75,67 +142,10 @@ class Device { int compute_capability_major_; int compute_capability_minor_; cublasLtHandle_t lt_; - std::unordered_map streams_; -}; - -class CommandEncoder { - public: - explicit CommandEncoder(DeviceStream& stream); - - CommandEncoder(const CommandEncoder&) = delete; - CommandEncoder& operator=(const CommandEncoder&) = delete; - - void set_input_array(const array& arr) {} - void set_output_array(const array& arr) {} - - void add_temporary(const array& arr) { - temporaries_.push_back(arr.data_shared_ptr()); - } - - void add_completed_handler(std::function task); - void end_encoding(); - void commit(); - - // Schedule a cuda stream for |fun| to launch kernels, and check error - // afterwards. - template - void launch_kernel(F&& fun) { - launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); - } - - template - void launch_kernel(cudaStream_t stream, F&& fun) { - device_.make_current(); - fun(stream); - check_cuda_error("kernel launch", cudaGetLastError()); - has_gpu_work_ = true; - } - - Device& device() { - return device_; - } - - DeviceStream& stream() { - return stream_; - } - - bool has_gpu_work() const { - return has_gpu_work_; - } - - // Wait until kernels and completion handlers are finished - void synchronize(); - - private: - Device& device_; - DeviceStream& stream_; - Worker worker_; - bool has_gpu_work_{false}; - std::vector> temporaries_; + std::unordered_map encoders_; }; Device& device(mlx::core::Device device); -DeviceStream& get_stream(Stream s); CommandEncoder& get_command_encoder(Stream s); // Return an execution policy that does not sync for result. 
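To make the new CommandEncoder interface concrete, here is a minimal sketch of how a launch is recorded with it from inside the CUDA backend. `copy_kernel`, the float element type, and the launch geometry are illustrative assumptions rather than code from this patch; only the encoder calls (`get_command_encoder`, `set_input_array`, `set_output_array`, `add_kernel_node`) come from the interface above.

// Sketch only: a stand-in kernel plus the recording pattern used throughout
// this patch. Arguments are copied by value into a cudaGraph kernel node;
// nothing runs until CommandEncoder::commit() instantiates and launches the
// graph.
__global__ void copy_kernel(const float* in, float* out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i];
  }
}

void record_copy(const array& in, array& out, Stream s) {
  auto& enc = cu::get_command_encoder(s);
  enc.set_input_array(in);   // depend on the last node that wrote `in`
  enc.set_output_array(out); // later readers of `out` depend on this node
  size_t n = out.size();
  dim3 block(256);
  dim3 grid((n + block.x - 1) / block.x);
  enc.add_kernel_node(
      copy_kernel, grid, block, in.data<float>(), out.data<float>(), n);
}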
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index 21b019cd8..40beb12d2 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -37,22 +37,20 @@ void eval(array& arr) { } auto& encoder = cu::get_command_encoder(arr.primitive().stream()); - if (encoder.has_gpu_work()) { - // Keep used buffers alive until kernel finishes running. - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input. - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); } - encoder.end_encoding(); + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + encoder.maybe_commit(); } void finalize(Stream s) { diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index 9fc5c641b..afa032a83 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -61,7 +61,9 @@ void CudaEvent::wait(Stream s) { if (s.device == mlx::core::Device::cpu) { scheduler::enqueue(s, [*this]() mutable { wait(); }); } else { - wait(cu::get_stream(s).last_cuda_stream()); + auto& enc = cu::get_command_encoder(s); + enc.commit(); + wait(enc.stream()); } } @@ -74,7 +76,9 @@ void CudaEvent::record(Stream s) { if (s.device == mlx::core::Device::cpu) { throw std::runtime_error("CudaEvent can not wait on cpu stream."); } else { - record(cu::get_stream(s).last_cuda_stream()); + auto& enc = cu::get_command_encoder(s); + enc.commit(); + record(enc.stream()); } } @@ -136,11 +140,9 @@ void SharedEvent::wait(Stream s, uint64_t value) { scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); } else { auto& encoder = get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), - [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.commit(); + wait(encoder.stream(), value); encoder.add_completed_handler([ac = ac_]() {}); - encoder.end_encoding(); } } @@ -162,11 +164,9 @@ void SharedEvent::signal(Stream s, uint64_t value) { scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); } else { auto& encoder = get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), - [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.commit(); + signal(encoder.stream(), value); encoder.add_completed_handler([ac = ac_]() {}); - encoder.end_encoding(); } } diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 65a175fbd..4b03a604e 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -3,13 +3,16 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include "cuda_jit_sources.h" +#include #include +#include #include #include @@ -22,7 +25,7 @@ 
namespace { constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; void append_indices_arg( - cu::JitModule& mod, + cu::KernelArgs& args, const std::vector& inputs, int nidx, int idx_ndim) { @@ -30,7 +33,7 @@ void append_indices_arg( for (int i = 0; i < nidx; ++i) { indices[i] = inputs[i + 1].data(); } - mod.append_arg(std::move(indices)); + args.append(std::move(indices)); std::vector indices_shape(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( @@ -38,7 +41,7 @@ void append_indices_arg( idx_ndim, indices_shape.data() + i * idx_ndim); } - mod.append_arg(std::move(indices_shape)); + args.append(std::move(indices_shape)); std::vector indices_strides(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( @@ -46,7 +49,7 @@ void append_indices_arg( idx_ndim, indices_strides.data() + i * idx_ndim); } - mod.append_arg(std::move(indices_strides)); + args.append(std::move(indices_strides)); } } // namespace @@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { return std::make_pair(jit_source_gather, std::move(kernel_names)); }); - mod.append_arg(src); - mod.append_arg(out); + cu::KernelArgs args; + args.append(src); + args.append(out); if (large) { - mod.append_arg(out.size()); + args.append(out.size()); } else { - mod.append_arg(out.size()); + args.append(out.size()); } - mod.append_ndim_arg(src.shape()); - mod.append_ndim_arg(src.strides()); - mod.append_arg(src.ndim()); - mod.append_ndim_arg(slice_sizes_); - mod.append_arg(slice_size); - mod.append_arg(axes_); - append_indices_arg(mod, inputs, nidx, idx_ndim); + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); + args.append_ndim(slice_sizes_); + args.append(slice_size); + args.append(axes_); + append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::gather<{}, {}, {}, {}, {}>", @@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, out, large); - }); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { return std::make_pair(jit_source_scatter, std::move(kernel_names)); }); - mod.append_arg(upd); - mod.append_arg(out); + cu::KernelArgs args; + args.append(upd); + args.append(out); if (large) { - mod.append_arg(upd.size()); + args.append(upd.size()); } else { - mod.append_arg(upd.size()); + args.append(upd.size()); } - mod.append_ndim_arg(upd.shape()); - mod.append_ndim_arg(upd.strides()); - mod.append_arg(upd.ndim()); + args.append_ndim(upd.shape()); + args.append_ndim(upd.strides()); + args.append(upd.ndim()); if (large) { - mod.append_arg(upd_post_idx_size); + args.append(upd_post_idx_size); } else { - mod.append_arg(upd_post_idx_size); + args.append(upd_post_idx_size); } - mod.append_ndim_arg(out.shape()); - mod.append_ndim_arg(out.strides()); - mod.append_arg(out.ndim()); - mod.append_arg(axes_); - append_indices_arg(mod, inputs, nidx, idx_ndim); + args.append_ndim(out.shape()); + args.append_ndim(out.strides()); + args.append(out.ndim()); + args.append(axes_); + append_indices_arg(args, inputs, nidx, 
idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", @@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, upd, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { } size_t idx_size_axis = idx.shape(axis_); - mod.append_arg(src); - mod.append_arg(idx); - mod.append_arg(out); + cu::KernelArgs args; + args.append(src); + args.append(idx); + args.append(out); if (large) { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } - mod.append_arg(remove_index(idx.shape(), axis_)); - mod.append_arg(remove_index(src.strides(), axis_)); - mod.append_arg(remove_index(idx.strides(), axis_)); - mod.append_arg(axis_); - mod.append_arg(src.shape(axis_)); - mod.append_arg(src.strides(axis_)); - mod.append_arg(idx.strides(axis_)); + args.append(remove_index(idx.shape(), axis_)); + args.append(remove_index(src.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(src.shape(axis_)); + args.append(src.strides(axis_)); + args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", @@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, idx, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { } size_t idx_size_axis = idx.shape(axis_); - mod.append_arg(upd); - mod.append_arg(idx); - mod.append_arg(out); + cu::KernelArgs args; + args.append(upd); + args.append(idx); + args.append(out); if (large) { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } - mod.append_arg(remove_index(idx.shape(), axis_)); - mod.append_arg(remove_index(upd.strides(), axis_)); - mod.append_arg(remove_index(idx.strides(), axis_)); - mod.append_arg(axis_); - mod.append_arg(out.shape(axis_)); - mod.append_arg(upd.strides(axis_)); - mod.append_arg(idx.strides(axis_)); + args.append(remove_index(idx.shape(), axis_)); + 
args.append(remove_index(upd.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(out.shape(axis_)); + args.append(upd.strides(axis_)); + args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", @@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, idx, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index af8f7dc75..5bc56b25e 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) { } } -#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd)) - -void check_cu_error(const char* name, CUresult err) { - if (err != CUDA_SUCCESS) { - const char* err_str = "Unknown error"; - cuGetErrorString(err, &err_str); - throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); - } -} - // Return the location of the CUDA toolkit. const std::string& cuda_home() { static std::string home = []() -> std::string { @@ -280,60 +270,13 @@ JitModule::JitModule( // Load kernels. for (const auto& [name, mangled] : ptx_kernels) { CUfunction kernel; - CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); + CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); kernels_[name] = kernel; } } JitModule::~JitModule() { - CHECK_CU_ERROR(cuModuleUnload(module_)); -} - -void JitModule::launch_kernel( - CUstream stream, - const std::string& kernel_name, - const array& arr, - bool large, - int work_per_thread) { - CUfunction kernel = get_kernel(kernel_name); - size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); - int _, block_dim; - CHECK_CU_ERROR( - cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); - if (block_dim > nthreads) { - block_dim = nthreads; - } - Dims num_blocks{1, 1, 1}; - if (large) { - num_blocks = - get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread); - std::get<0>(num_blocks) = - (std::get<0>(num_blocks) + block_dim - 1) / block_dim; - } else { - std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim; - } - launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1}); -} - -void JitModule::launch_kernel( - CUstream stream, - CUfunction kernel, - Dims num_blocks, - Dims block_dims) { - CHECK_CU_ERROR(cuLaunchKernel( - kernel, - std::get<0>(num_blocks), - std::get<1>(num_blocks), - std::get<2>(num_blocks), - std::get<0>(block_dims), - std::get<1>(block_dims), - std::get<2>(block_dims), - 0, - stream, - args_.data(), - nullptr)); - args_.clear(); - storage_.clear(); + CHECK_CUDA_ERROR(cuModuleUnload(module_)); } CUfunction JitModule::get_kernel(const std::string& kernel_name) { @@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) { return it->second; } -void JitModule::append_ptr_arg(const void* v) { - args_.push_back(const_cast(v)); -} - JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, diff --git a/mlx/backend/cuda/jit_module.h 
b/mlx/backend/cuda/jit_module.h index bbfaa45b0..57da7c87e 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -4,6 +4,7 @@ #include "mlx/array.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/config.h" #include @@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair< /* kernel names */ std::vector>; using KernelBuilder = std::function; -class JitModule { - public: - JitModule( - Device& device, - const std::string& module_name, - const KernelBuilder& builder); - ~JitModule(); +struct KernelArgs { + void** args() { + return args_.data(); + } - JitModule(const JitModule&) = delete; - JitModule& operator=(const JitModule&) = delete; - - void append_arg(const array& a) { - append_arg(reinterpret_cast(a.data())); + void append(const array& a) { + append(reinterpret_cast(a.data())); } template - void append_arg(T val) { + void append(T val) { storage_.emplace_back(val); - append_ptr_arg(&storage_.back()); + append_ptr(&storage_.back()); } template - void append_arg(std::vector vec) { + void append(std::vector vec) { if (vec.empty()) { // The nullptr can not be used as arg, pass something not null. - append_arg(std::monostate{}); + append(std::monostate{}); } else { - append_ptr_arg(vec.data()); + append_ptr(vec.data()); storage_.emplace_back(std::move(vec)); } } // Make sure the arg is copied to an array with size of NDIM. template - void append_ndim_arg(const std::vector& vec) { + void append_ndim(std::vector vec) { if (vec.size() > NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", NDIM)); } - std::vector copied(NDIM); - std::copy(vec.begin(), vec.end(), copied.data()); - append_arg(std::move(copied)); + vec.resize(NDIM); + append(std::move(vec)); } - // Launch kernel with |kernel_name| that each thread works on - // |work_per_thread| elements of |arr|. 
- void launch_kernel( - CUstream stream, - const std::string& kernel_name, - const array& arr, - bool large, - int work_per_thread = 1); - - void launch_kernel( - CUstream stream, - CUfunction kernel, - Dims num_blocks, - Dims block_dims); - - CUfunction get_kernel(const std::string& kernel_name); + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } private: - void append_ptr_arg(const void* v); - - CUmodule module_{nullptr}; - std::unordered_map kernels_; std::vector args_; // The cuLaunchKernel API requires passing pointers to arguments so store @@ -105,6 +82,23 @@ class JitModule { std::deque storage_; }; +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + CUfunction get_kernel(const std::string& kernel_name); + + private: + CUmodule module_{nullptr}; + std::unordered_map kernels_; +}; + JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index b0058b618..eeaf916c1 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -12,6 +12,7 @@ #include "mlx/backend/cuda/device/utils.cuh" #include +#include #include #include #include @@ -120,7 +121,13 @@ std::pair get_grid_and_block(int dim0, int dim1, int dim2); template inline uint max_occupancy_block_dim(T kernel) { int _, block_dim; - CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + if constexpr (std::is_same_v) { + CHECK_CUDA_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + } else { + CHECK_CUDA_ERROR( + cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + } return block_dim; } diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 9a9fbcb37..5fbf949d7 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -258,23 +258,23 @@ void LayerNorm::eval_gpu( encoder.set_input_array(w); encoder.set_input_array(b); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::layer_norm; - kernel<<>>( - x.data(), - w.data(), - b.data(), - out.data(), - eps_, - axis_size, - w_stride, - b_stride); - }); + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); }); }); } @@ -289,21 +289,25 @@ void LayerNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
- auto check_input = [&s](const array& x) -> std::pair { + auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { - return {x, false}; + copied = false; + return x; } + copied = true; array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - return {x_copy, true}; + return x_copy; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[3].is_donatable(); - auto [x, copied] = check_input(inputs[0]); + bool copied; + auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; const array& b = inputs[2]; - auto [g, g_copied] = check_input(inputs[3]); + bool g_copied; + auto g = check_input(inputs[3], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -334,8 +338,10 @@ void LayerNormVJP::eval_gpu( // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + bool g_in_gw = false; if (has_w) { if (!g_in_gx && donate_g) { + g_in_gw = true; gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); @@ -343,41 +349,47 @@ void LayerNormVJP::eval_gpu( } } - // Finish with the gradient for b in case we had a b. - if (gb.ndim() == 1 && gb.size() == axis_size) { + // The gradient for b in case we had a b. + bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size); + if (has_gb) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); } + // Insert dependency if `g` was donated + if ((g_in_gx || g_in_gw) && has_gb) { + encoder.set_input_array(gb); + } encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); - encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { - dispatch_bool(has_w, [&](auto has_w_constant) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::layer_norm_vjp< - DataType, - has_w_constant.value, - block_dim(), - N_READS>; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); - }); + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index 5d6bf437d..afc52826f 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -143,16 +143,18 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = 
cu::logsumexp; - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); }); }); } diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index c32cecc03..e11c68b7d 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -42,7 +42,8 @@ class MatMul { int64_t ldb, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride) { + int64_t b_batch_stride) + : handle_(device.lt_handle()) { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; auto scale_type = dtype_to_cuda_type(dtype); @@ -147,7 +148,7 @@ class MatMul { if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { int ret = 0; CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( - encoder.device().lt_handle(), + handle_, matmul_desc_, a_desc_, b_desc_, @@ -172,25 +173,24 @@ class MatMul { workspace_ptr = workspace.data(); } - encoder.launch_kernel([&](cudaStream_t stream) { - CHECK_CUBLAS_ERROR(cublasLtMatmul( - encoder.device().lt_handle(), - matmul_desc_, - &alpha, - a, - a_desc_, - b, - b_desc_, - &beta, - c ? c : out, - c ? c_desc_ : out_desc_, - out, - out_desc_, - &heuristic_.algo, - workspace_ptr, - heuristic_.workspaceSize, - stream)); - }); + auto capture = encoder.capture_context(); + CHECK_CUBLAS_ERROR(cublasLtMatmul( + handle_, + matmul_desc_, + &alpha, + a, + a_desc_, + b, + b_desc_, + &beta, + c ? c : out, + c ? c_desc_ : out_desc_, + out, + out_desc_, + &heuristic_.algo, + workspace_ptr, + heuristic_.workspaceSize, + encoder.stream())); } private: @@ -259,6 +259,7 @@ class MatMul { return desc; } + cublasLtHandle_t handle_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr}; @@ -273,7 +274,7 @@ class MatMul { namespace { std::tuple -check_transpose(std::vector& copies, const Stream& s, const array& arr) { +check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (sty == 1 && stx == arr.shape(-1)) { @@ -283,7 +284,7 @@ check_transpose(std::vector& copies, const Stream& s, const array& arr) { } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); + enc.add_temporary(arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy); } } @@ -317,13 +318,8 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays - std::vector copies; - auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); - auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); - - for (auto& temp : copies) { - encoder.add_temporary(temp); - } + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -348,7 +344,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Invoke cublasLt cu::MatMul matmul( - encoder.device(), + cu::device(s.device), a.dtype(), a_transposed, M, @@ -373,6 
+369,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, @@ -405,14 +402,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays - std::vector copies; - auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); - auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); - auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre); - - for (auto& temp : copies) { - encoder.add_temporary(temp); - } + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -440,7 +432,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Invoke cublasLt cu::MatMul matmul( - encoder.device(), + cu::device(s.device), a.dtype(), a_transposed, M, @@ -478,6 +470,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 715e5a232..18fa45a33 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -24,23 +24,21 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { return; } - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); + auto& encoder = cu::get_command_encoder(stream()); encoder.set_output_array(out); - encoder.launch_kernel([&, this](cudaStream_t stream) { - dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - using OutType = cuda_type_t; - CTYPE step = - static_cast(start_ + step_) - static_cast(start_); - thrust::transform( - cu::thrust_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(out.data_size()), - thrust::device_pointer_cast(out.data()), - cu::Arange{ - static_cast(start_), static_cast(step)}); - }); + auto capture = encoder.capture_context(); + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(encoder.stream()), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); }); } diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 0cb550d56..7221af356 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -156,34 +156,39 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(keys); 
encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dim3 grid_dims{num_keys, half_size + odd}; - int64_t total = grid_dims.x * grid_dims.y; - int32_t threads_y = 1; - while ((total / threads_y) >= (1U << 31)) { - threads_y *= 2; - } - int32_t threads_x = cuda::ceil_div(total, threads_y); - auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); - if (keys.flags().row_contiguous) { - cu::rbitsc<<>>( - keys.data(), - out.data(), - grid_dims, - odd, - bytes_per_key); - } else { - cu::rbits<<>>( - keys.data(), - out.data(), - grid_dims, - odd, - bytes_per_key, - keys.ndim(), - const_param(keys.shape()), - const_param(keys.strides())); - } - }); + dim3 grid_dims{num_keys, half_size + odd}; + int64_t total = grid_dims.x * grid_dims.y; + int32_t threads_y = 1; + while ((total / threads_y) >= (1U << 31)) { + threads_y *= 2; + } + int32_t threads_x = cuda::ceil_div(total, threads_y); + auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); + auto& stream = encoder.stream(); + if (keys.flags().row_contiguous) { + encoder.add_kernel_node( + cu::rbitsc, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key); + } else { + encoder.add_kernel_node( + cu::rbits, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key, + keys.ndim(), + const_param(keys.shape()), + const_param(keys.strides())); + } } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index a6ccd5ae9..3419d61cb 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -110,19 +110,20 @@ void all_reduce( intermediate.set_data(allocator::malloc(intermediate.nbytes())); encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(dt, [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::all_reduce; - kernel<<>>( - static_cast(indata), - intermediate.data(), - block_step, - insize); - }); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + intermediate.data(), + block_step, + insize); }); }); @@ -135,16 +136,20 @@ void all_reduce( } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(dt, [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::all_reduce; - kernel<<>>( - static_cast(indata), out.data(), block_step, insize); - }); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + out.data(), + block_step, + insize); }); }); } diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 
78f6b93bc..910fa0379 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -214,26 +214,24 @@ void col_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); - - constexpr int N_READS = 4; - constexpr int BM = 32; - constexpr int BN = 32; - dim3 grid = output_grid_for_col_reduce(out, args, BN); - int blocks = BM * BN / N_READS; - auto kernel = - cu::col_reduce_looped; - kernel<<>>(indata, out.data(), args); - }); + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + auto kernel = + cu::col_reduce_looped; + encoder.add_kernel_node( + kernel, grid, blocks, indata, out.data(), args); }); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 296a4e611..649d80190 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -32,18 +32,16 @@ void init_reduce( } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::init_reduce; - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); - grid.x = (grid.x + 1023) / 1024; - kernel<<>>(out.data(), out.size()); - }); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::init_reduce; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); + grid.x = (grid.x + 1023) / 1024; + encoder.add_kernel_node(kernel, grid, block, out.data(), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index deb4a2f91..e57f18668 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -245,34 +245,32 @@ void row_reduce_simple( // 2 passes. Something like 32 * out.size() and then do a warp reduce. 
encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Calculate the grid and block dims - size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - int threads = std::min(1024UL, reductions); - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - dim3 block(threads, 1, 1); + // Calculate the grid and block dims + size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); - // Pick the kernel - auto kernel = cu::row_reduce_simple; - if (grid.x >= 1024) { - grid.x = (grid.x + 1) / 2; - kernel = cu::row_reduce_simple; - } + // Pick the kernel + auto kernel = cu::row_reduce_simple; + if (grid.x >= 1024) { + grid.x = (grid.x + 1) / 2; + kernel = cu::row_reduce_simple; + } - // Launch - kernel<<>>( - indata, out.data(), out.size(), plan.shape.back()); - }); + int size = plan.shape.back(); + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), size); }); }); } @@ -293,43 +291,39 @@ void row_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Cub doesn't like const pointers for vectorized loads. 
(sigh) - T* indata = const_cast(in.data()); + // Calculate the grid and block dims + args.sort_access_pattern(in, axes); + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + size_t reductions = (args.row_size + N_READS - 1) / N_READS; + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); - // Calculate the grid and block dims - args.sort_access_pattern(in, axes); - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - size_t reductions = (args.row_size + N_READS - 1) / N_READS; - int threads = std::min(1024UL, reductions); - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - dim3 block(threads, 1, 1); - - // Pick the kernel - auto kernel = cu::row_reduce_looped; - dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { - dispatch_block_dim(threads, [&](auto threads_constant) { - kernel = cu::row_reduce_looped< - T, - U, - OP, - reduce_ndim.value, - threads_constant.value, - N_READS>; - block.x = threads_constant.value; - }); + // Pick the kernel + auto kernel = cu::row_reduce_looped; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim.value, + threads_constant.value, + N_READS>; + block.x = threads_constant.value; }); - - // Launch - kernel<<>>( - indata, out.data(), out.size(), args); }); + + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index fc8f4f490..5ee1d3386 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -224,21 +224,21 @@ void RMSNorm::eval_gpu( encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::rms_norm; - kernel<<>>( - x.data(), - w.data(), - out.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); }); }); } @@ -253,20 +253,24 @@ void RMSNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
- auto check_input = [&s](const array& x) -> std::pair { + auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { - return {x, false}; + copied = false; + return x; } + copied = true; array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - return {x_copy, true}; + return x_copy; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[2].is_donatable(); - auto [x, copied] = check_input(inputs[0]); + bool copied; + auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; - auto [g, g_copied] = check_input(inputs[2]); + bool g_copied; + auto g = check_input(inputs[2], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -310,30 +314,31 @@ void RMSNormVJP::eval_gpu( encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); - encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { - dispatch_bool(has_w, [&](auto has_w_constant) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - auto kernel = cu::rms_norm_vjp< - DataType, - has_w_constant.value, - block_dim(), - N_READS>; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); - }); + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index bb9618fc4..517cddfe0 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -308,76 +308,89 @@ void RoPE::eval_gpu( auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(donated ? out : in); encoder.set_input_array(offset); + if (with_freqs) { + encoder.set_input_array(inputs[2]); + } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { - dispatch_bool(traditional_, [&](auto traditional) { - dispatch_bool(forward_, [&](auto forward) { - using DataType = cuda_type_t; - if (single && !with_freqs) { - auto kernel = - cu::rope_single; - uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - scale_, - std::log2(base_), - mat_size, - dims); - } else if (single) { - auto kernel = cu:: - rope_single_freqs; - uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - kernel<<>>( - (donated ? 
out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), - scale_, - mat_size, - dims, - inputs[2].strides(0)); - } else if (with_freqs) { - auto kernel = - cu::rope_freqs; - uint3 dims = - make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); - dims.z = (dims.z + 3) / 4; - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), - scale_, - std::log2(base_), - strides, - out_strides, - in.size() / mat_size, - dims, - inputs[2].strides(0)); - } else { - auto kernel = cu::rope; - uint3 dims = - make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); - dims.z = (dims.z + 3) / 4; - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - scale_, - std::log2(base_), - strides, - out_strides, - in.size() / mat_size, - dims); - } - }); + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; + if (single && !with_freqs) { + auto kernel = + cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = + cu::rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = + cu::rope_freqs; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? 
out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims); + } }); }); }); diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index af9ddf214..fd807bd8d 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -141,19 +141,21 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::softmax; - if (precise) { - kernel = cu::softmax; - } - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); }); }); } diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 2c5599bed..379c55706 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) { return out; } -template -void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR( - cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( - temp.data(), size, args...)); -} - -template -void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR( - cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( - temp.data(), size, args...)); -} - struct OffsetTransform { int nsort; @@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using Type = cuda_type_t; - auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), OffsetTransform{nsort}); - if (argsort) { - // Indices in the sorted dimension. 
- array indices( - allocator::malloc(out.nbytes()), in.shape(), out.dtype()); - encoder.add_temporary(indices); - thrust::transform( - cu::thrust_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(indices.data_size()), - thrust::device_pointer_cast(indices.data()), - ModOp{static_cast(nsort)}); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + auto& stream = encoder.stream(); + if constexpr (!std::is_same_v) { + using Type = cuda_type_t; + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), OffsetTransform{nsort}); + if (argsort) { + // Indices in the sorted dimension. + array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(indices); - // In argsort though we don't need the result of sorted values, the - // API requires us to provide an array to store it. - array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); - encoder.add_temporary(discard); + // In argsort though we don't need the result of sorted values, the + // API requires us to provide an array to store it. + array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + encoder.add_temporary(discard); - segmented_sort_pairs( - encoder, - in.data(), - discard.data(), - indices.data(), - out.data(), - in.data_size(), - in.data_size() / nsort, - offsets, - offsets + 1, - stream); - } else { - segmented_sort( - encoder, - in.data(), - out.data(), - in.data_size(), - in.data_size() / nsort, - offsets, - offsets + 1, - stream); - } + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + nullptr, + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(indices.data_size()), + thrust::device_pointer_cast(indices.data()), + ModOp{static_cast(nsort)}); + + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + temp.data(), + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + nullptr, + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + temp.data(), + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); } - }); + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } }); if (!is_segmented_sort) { diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 1d6535100..aa6523f27 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -91,73 +91,80 @@ void ternary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_input_array(c); 
encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(out.dtype(), [&](auto type_tag) { - using DType = cuda_type_t; + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; - auto topt = get_ternary_op_type(a, b, c); - if (topt == TernaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - auto& c_strides = strides[2]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = - cu::ternary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides)); - }); - } else { - auto kernel = cu::ternary_g; + auto topt = get_ternary_op_type(a, b, c); + if (topt == TernaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), c.data(), out.data(), - out.data_size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides), - ndim); - } - }); - } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::ternary_v; - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.data_size()); - }); - } - }); + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } + }); + } else { + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::ternary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size()); + }); + } }); } diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 74251d1f6..3f1a62d24 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -9,14 +9,38 @@ #include 
"mlx/dtype_utils.h" #include "mlx/primitives.h" +#include #include -#include -#include namespace mlx::core { namespace cu { +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(in[index]); + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim); + out[index] = Op{}(in[idx]); + } +} + template constexpr bool supports_unary_op() { if (std::is_same_v || std::is_same_v || @@ -71,38 +95,61 @@ void unary_op_gpu_inplace( if (in.size() == 0) { return; } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_unary_op()) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_unary_op()) { + dispatch_bool(large, [&](auto large) { + using IdxT = std::conditional_t; using InType = cuda_type_t; using OutType = cuda_type_t; - auto policy = cu::thrust_policy(stream); - auto in_ptr = thrust::device_pointer_cast(in.data()); - auto out_ptr = thrust::device_pointer_cast(out.data()); - if (in.flags().contiguous) { - thrust::transform( - policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); + using IdxT = std::conditional_t; + if (contig) { + auto kernel = cu::unary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size()); } else { auto [shape, strides] = collapse_contiguous_dims(in); - auto [in_begin, in_end] = cu::make_general_iterators( - in_ptr, in.size(), shape, strides); - thrust::transform(policy, in_begin, in_end, out_ptr, Op()); + auto kernel = cu::unary_g; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(strides), + shape.size()); } - } else { - throw std::runtime_error(fmt::format( - "Can not do unary op {} on input of {} with output of {}.", - op, - dtype_to_string(in.dtype()), - dtype_to_string(out.dtype()))); - } - }); + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 35731f6eb..cc05428a8 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -24,6 +24,14 @@ void check_cuda_error(const char* name, cudaError_t err) 
{ } } +void check_cuda_error(const char* name, CUresult err) { + if (err != CUDA_SUCCESS) { + const char* err_str = "Unknown error"; + cuGetErrorString(err, &err_str); + throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); + } +} + const char* dtype_to_cuda_type(const Dtype& dtype) { switch (dtype) { case bool_: diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 6d98cdcd5..bfb02c5b6 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -4,6 +4,7 @@ #pragma once +#include #include namespace mlx::core { @@ -33,6 +34,7 @@ class CudaStream { // Throw exception if the cuda API does not succeed. void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 144f9a880..ff3208e1e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -688,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { perm = expand_dims(perm, -1, s); take_axis -= 1; } - auto pb = take_along_axis(b, perm, take_axis); + auto pb = take_along_axis(b, perm, take_axis, s); auto y = solve_triangular(luf[1], pb, /* upper = */ false, s); return solve_triangular(luf[2], y, /* upper = */ true, s); } diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 35f7016c5..840d3b471 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -391,9 +391,11 @@ class TestLoad(mlx_tests.MLXTestCase): scale = mx.array(2.0) y = mx.load(save_file) mx.eval(y) + mx.synchronize() load_only = mx.get_peak_memory() y = mx.load(save_file) * scale mx.eval(y) + mx.synchronize() load_with_binary = mx.get_peak_memory() self.assertEqual(load_only, load_with_binary) From 8917022deb609511778dc93c193480967e43b777 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 2 Jul 2025 19:37:58 -0700 Subject: [PATCH 127/156] fix graphs for older cuda (#2328) --- mlx/backend/cuda/device.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index fff752fe5..4129563af 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -54,8 +54,8 @@ void Device::make_current() { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); - CHECK_CUDA_ERROR(cudaStreamBeginCaptureToGraph( - enc.stream(), graph, NULL, NULL, 0, cudaStreamCaptureModeGlobal)); + CHECK_CUDA_ERROR( + cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); } CommandEncoder::CaptureContext::~CaptureContext() { From 0e0d9ac522aea5b25f7663de7a4b00db27bf85c0 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 6 Jul 2025 00:33:29 +0900 Subject: [PATCH 128/156] [CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329) --- mlx/backend/cuda/device.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 4129563af..638d68727 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -15,7 +15,12 @@ namespace mlx::core { // This should be less than 255 constexpr int default_max_nodes_per_graph = 20; -constexpr int max_graph_cache_size = 100; +int cuda_graph_cache_size() { + static int cache_size = []() { + return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100); + }(); + return cache_size; 
+} namespace cu { @@ -252,10 +257,6 @@ void CommandEncoder::commit() { CHECK_CUDA_ERROR(cudaGraphAddDependencies( graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); } - // TODO smarter cache policy - if (graph_cache_.size() > max_graph_cache_size) { - clear_graphs(graph_cache_); - } graph_key_ += "."; graph_key_ += std::to_string(node_count_); @@ -281,6 +282,11 @@ void CommandEncoder::commit() { } CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); + // TODO smarter cache policy + if (graph_cache_.size() > cuda_graph_cache_size()) { + clear_graphs(graph_cache_); + } + // Reset state node_count_ = 0; graph_node_count_ = 0; From f5299f72cd20d258eec96cb9b81277226a2afbcd Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 7 Jul 2025 06:06:01 -0700 Subject: [PATCH 129/156] Fix layernorm race condition (#2340) --- mlx/backend/metal/kernels/layer_norm.metal | 1 + 1 file changed, 1 insertion(+) diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 06b8be55f..ea77b53dc 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -31,6 +31,7 @@ inline void threadgroup_sum( for (int i = 0; i < N; i++) { x[i] = simd_sum(x[i]); } + threadgroup_barrier(mem_flags::mem_threadgroup); if (simd_lane_id == 0) { for (int i = 0; i < N; i++) { xs[N * simd_group_id + i] = x[i]; From 19facd4b2084318b54d878032f88d9518310d123 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 7 Jul 2025 22:06:45 +0900 Subject: [PATCH 130/156] Build with all cpu cores by default (#2336) --- .circleci/config.yml | 28 +++++++--------------------- docs/src/install.rst | 8 ++++---- python/mlx/extension.py | 6 +----- setup.py | 6 +----- 4 files changed, 13 insertions(+), 35 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 205a930af..293cdce79 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -41,7 +41,7 @@ jobs: pip install --upgrade pip pip install --upgrade cmake pip install -r docs/requirements.txt - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v + pip install . -v - when: condition: not: << parameters.upload-docs >> @@ -97,10 +97,8 @@ jobs: name: Install Python package command: | CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python3 setup.py build_ext --inplace CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python3 setup.py develop - run: name: Generate package stubs @@ -157,8 +155,7 @@ jobs: name: Install Python package command: | source env/bin/activate - DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ + DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \ pip install -e . -v - run: name: Generate package stubs @@ -208,8 +205,7 @@ jobs: name: Run Python tests with JIT command: | source env/bin/activate - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ + CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ pip install -e . 
-v LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \ METAL_DEBUG_ERROR_MODE=0 \ @@ -228,8 +224,7 @@ jobs: sudo apt-get install libblas-dev liblapack-dev liblapacke-dev python -m venv env source env/bin/activate - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ pip install -e ".[dev]" - run: name: Run Python tests @@ -278,7 +273,6 @@ jobs: command: | source env/bin/activate env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \ - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ pip install . -v - run: name: Generate package stubs @@ -290,9 +284,7 @@ jobs: name: Build Python package command: | source env/bin/activate - << parameters.build_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ - python -m build -w + << parameters.build_env >> python -m build -w - when: condition: << parameters.build_env >> steps: @@ -340,14 +332,10 @@ jobs: pip install patchelf pip install build pip install twine - << parameters.extra_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - pip install . -v + << parameters.extra_env >> pip install . -v pip install typing_extensions python setup.py generate_stubs - << parameters.extra_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ - python -m build --wheel + << parameters.extra_env >> python -m build --wheel auditwheel show dist/* auditwheel repair dist/* --plat manylinux_2_31_x86_64 - run: @@ -383,12 +371,10 @@ jobs: pip install build pip install twine << parameters.extra_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ pip install ".[dev]" -v python setup.py generate_stubs << parameters.extra_env >> \ - CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ python -m build --wheel bash python/scripts/repair_cuda.sh diff --git a/docs/src/install.rst b/docs/src/install.rst index 22de94f90..a50b6a71d 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -88,20 +88,20 @@ Then simply build and install MLX using pip: .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 pip install . + pip install . For developing, install the package with development dependencies, and use an editable install: .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]" + pip install -e ".[dev]" Once the development dependencies are installed, you can build faster with: .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace + python setup.py build_ext --inplace Run the tests with: @@ -262,7 +262,7 @@ When building either the Python or C++ APIs make sure to pass the cmake flag .. code-block:: shell - CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]" + CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]" To build the C++ package run: diff --git a/python/mlx/extension.py b/python/mlx/extension.py index 8c0d60655..c426d5953 100644 --- a/python/mlx/extension.py +++ b/python/mlx/extension.py @@ -53,11 +53,7 @@ class CMakeBuild(build_ext): # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level # across all generators. if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. 
- build_args += [f"-j{self.parallel}"] + build_args += [f"-j{os.cpu_count()}"] build_temp = Path(self.build_temp) / ext.name if not build_temp.exists(): diff --git a/setup.py b/setup.py index 35f2e68ef..770718e25 100644 --- a/setup.py +++ b/setup.py @@ -97,11 +97,7 @@ class CMakeBuild(build_ext): # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level # across all generators. if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. - build_args += [f"-j{self.parallel}"] + build_args += [f"-j{os.cpu_count()}"] build_temp = Path(self.build_temp) / ext.name if not build_temp.exists(): From 9d10239af7f7d1a0648af0b64ff6c393b85b2566 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 8 Jul 2025 00:44:14 +0900 Subject: [PATCH 131/156] [CUDA] Do vectorized store/load in binary ops (#2330) --- mlx/backend/cuda/binary.cu | 112 +++++++++++++++++++++++++----- mlx/backend/cuda/device/utils.cuh | 21 ++++++ 2 files changed, 116 insertions(+), 17 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index d9b9fd8af..0585dc76a 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -17,35 +17,106 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[0], b[0]); + int remaining = size - index * N_READS; + if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = Op{}(a[0], b[0]); + } + } else { + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a[0], b[0]); + } + + store_vector(out, index, out_vec); } } -template +template __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[0], b[index]); + int remaining = size - index * N_READS; + if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = Op{}(a[0], b[offset]); + } + } else { + auto b_vec = load_vector(b, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a[0], b_vec.val[i]); + } + + store_vector(out, index, out_vec); } } -template +template __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[index], b[0]); + int remaining = size - index * N_READS; + if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = Op{}(a[offset], b[0]); + } + } else { + auto a_vec = load_vector(a, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b[0]); + } + + store_vector(out, index, out_vec); } } -template +template __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[index], b[index]); + int remaining = size - index * N_READS; 
+ if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = Op{}(a[offset], b[offset]); + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -198,16 +269,23 @@ void binary_op_gpu_inplace( } else { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; + kernel = cu::binary_sv; } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; + kernel = cu::binary_vs; } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; + kernel = cu::binary_vv; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 6e8abdd7c..89b609c45 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -28,6 +28,27 @@ namespace mlx::core::cu { using Shape = cuda::std::array; using Strides = cuda::std::array; +// Vectorized load/store. +template +struct alignas(sizeof(T) * N) AlignedVector { + T val[N]; +}; + +template +inline __device__ AlignedVector load_vector( + const T* ptr, + uint32_t offset) { + auto* from = reinterpret_cast*>(ptr); + return from[offset]; +} + +template +inline __device__ void +store_vector(T* ptr, uint32_t offset, const AlignedVector& vec) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; +} + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// From a4fcc893cd4caad05c97ed038e083b9c8395580c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 7 Jul 2025 09:29:23 -0700 Subject: [PATCH 132/156] auto build linux release (#2341) --- .circleci/config.yml | 10 ++++++++++ python/src/fast.cpp | 36 +++++++++++++++++++++++++++--------- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 293cdce79..be5f7aac5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -492,6 +492,16 @@ workflows: branches: ignore: /.*/ upload-docs: true + - build_linux_release: + filters: + tags: + only: /^v.*/ + branches: + ignore: /.*/ + matrix: + parameters: + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + extra_env: ["PYPI_RELEASE=1"] prb: when: diff --git a/python/src/fast.cpp b/python/src/fast.cpp index c94f99e1a..8adba2a25 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -175,11 +175,12 @@ void init_fast(nb::module_& parent_module) { * `Grouped Query Attention `_ * `Multi-Query Attention `_ - Note: The softmax operation is performed in ``float32`` regardless of - the input precision. + .. note:: - Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` - and ``v`` inputs should not be pre-tiled to match ``q``. 
+ * The softmax operation is performed in ``float32`` regardless of + the input precision. + * For Grouped Query Attention and Multi-Query Attention, the ``k`` + and ``v`` inputs should not be pre-tiled to match ``q``. In the following the dimensions are given by: @@ -195,13 +196,30 @@ void init_fast(nb::module_& parent_module) { k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (Union[None, str, array], optional): A causal, boolean or additive - mask to apply to the query-key scores. The mask can have at most 4 - dimensions and must be broadcast-compatible with the shape - ``[B, N, T_q, T_kv]``. If an additive mask is given its type must - promote to the promoted type of ``q``, ``k``, and ``v``. + mask (Union[None, str, array], optional): The mask to apply to the + query-key scores. The mask can be an array or a string indicating + the mask type. The only supported string type is ``"causal"``. If + the mask is an array it can be a boolean or additive mask. The mask + can have at most 4 dimensions and must be broadcast-compatible with + the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its + type must promote to the promoted type of ``q``, ``k``, and ``v``. Returns: array: The output array. + + Example: + + .. code-block:: python + + B = 2 + N_q = N_kv = 32 + T_q = T_kv = 1000 + D = 128 + + q = mx.random.normal(shape=(B, N_q, T_q, D)) + k = mx.random.normal(shape=(B, N_kv, T_kv, D)) + v = mx.random.normal(shape=(B, N_kv, T_kv, D)) + scale = D ** -0.5 + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") )pbdoc"); m.def( From 4a9b29a8753ad65e2156bfe0d99d305fb48c4fcc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 7 Jul 2025 17:59:53 -0700 Subject: [PATCH 133/156] MoE backward improvements (#2335) --- mlx/backend/cpu/masked_mm.cpp | 170 +++++++++++ mlx/backend/cuda/primitives.cu | 1 + mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/indexing.cpp | 14 +- mlx/backend/metal/jit/includes.h | 1 + mlx/backend/metal/jit_kernels.cpp | 37 +++ mlx/backend/metal/kernels.h | 14 + mlx/backend/metal/kernels/CMakeLists.txt | 2 + .../steel/gemm/kernels/steel_gemm_segmented.h | 266 ++++++++++++++++++ .../gemm/kernels/steel_gemm_segmented.metal | 43 +++ mlx/backend/metal/matmul.cpp | 162 +++++++++++ mlx/backend/metal/nojit_kernels.cpp | 16 ++ mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_gpu/primitives.cpp | 1 + mlx/ops.cpp | 48 ++++ mlx/ops.h | 6 + mlx/primitives.cpp | 235 ++++++++++++---- mlx/primitives.h | 10 + python/src/ops.cpp | 22 ++ python/tests/cuda_skip.py | 4 + python/tests/test_blas.py | 93 ++++++ python/tests/test_quantized.py | 43 +++ 22 files changed, 1130 insertions(+), 60 deletions(-) create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 0be7c79ce..fbee6118f 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" @@ -52,6 +53,58 @@ inline void mask_matrix( } } +template +inline void segmented_mm( + const T* a, + const T* b, + const uint32_t* 
segments, + T* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides, + size_t num_segments, + const Shape& segments_shape, + const Strides& segments_strides) { + int ndim = a_shape.size(); + Shape a_copy = a_shape; + Shape b_copy = b_shape; + int32_t M = a_copy[ndim - 2]; + int32_t N = b_copy[ndim - 1]; + for (int i = 0; i < num_segments; i++) { + uint32_t k_start = + segments[elem_to_loc(2 * i, segments_shape, segments_strides)]; + uint32_t k_end = + segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)]; + if (k_end <= k_start) { + std::fill_n(out + i * M * N, M * N, T(0)); + continue; + } + a_copy[ndim - 1] = k_end - k_start; + b_copy[ndim - 2] = k_end - k_start; + matmul( + a + k_start * a_strides[ndim - 1], + b + k_start * b_strides[ndim - 2], + out + i * M * N, + a_transposed, + b_transposed, + lda, + ldb, + N, + 1.0, + 0.0, + 1, + a_copy, + a_strides, + b_copy, + b_strides); + } +} + } // namespace void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { @@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { encoder.add_temporaries(std::move(temps)); } +void SegmentedMM::eval_cpu(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& encoder = cpu::get_command_encoder(stream()); + auto check_transpose = [&s, &encoder](const array& x) { + auto stx = x.strides()[x.ndim() - 2]; + auto sty = x.strides()[x.ndim() - 1]; + if (stx == x.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, x); + } else if (stx == 1 && sty == x.shape(-2)) { + return std::make_tuple(true, sty, x); + } else { + array xc(x.shape(), x.dtype(), nullptr, {}); + copy(x, xc, CopyType::General, s); + encoder.add_temporary(xc); + int64_t stx = x.shape(-1); + return std::make_tuple(false, stx, xc); + } + }; + + auto [a_transposed, lda, a] = check_transpose(inputs[0]); + auto [b_transposed, ldb, b] = check_transpose(inputs[1]); + auto& segments = inputs[2]; + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(segments); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + segments = array::unsafe_weak_copy(segments), + out_ptr = out.data(), + a_transposed = a_transposed, + b_transposed = b_transposed, + lda = lda, + ldb = ldb]() { + switch (a.dtype()) { + case float64: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case float32: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case float16: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + segments.size() / 2, + segments.shape(), + segments.strides()); + break; + case bfloat16: + segmented_mm( + a.data(), + b.data(), + segments.data(), + static_cast(out_ptr), + a_transposed, + b_transposed, + lda, + ldb, + a.shape(), + a.strides(), + b.shape(), + b.strides(), + 
segments.size() / 2, + segments.shape(), + segments.strides()); + break; + default: + throw std::invalid_argument( + "Segmented mm supports only real float types."); + } + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 18fa45a33..a8496b958 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -83,6 +83,7 @@ NO_GPU_MULTI(LUF) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(Scan) +NO_GPU(SegmentedMM) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index d0c872451..ccdd83202 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -63,6 +63,7 @@ if(MLX_METAL_JIT) make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(steel/gemm/kernels/steel_gemm_gather) make_jit_source(steel/gemm/kernels/steel_gemm_splitk) + make_jit_source(steel/gemm/kernels/steel_gemm_segmented) make_jit_source( steel/conv/conv kernels/steel/utils.h diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a601b1e..13ce88a62 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -575,9 +575,17 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); - compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); - compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + if (ndim > 1) { + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); + } else { + // The following will be ignored in the kernel but we still have to set + // some value so that metal validation passes. 
+ compute_encoder.set_vector_bytes(idx.shape(), 3); + compute_encoder.set_vector_bytes(upd.strides(), 4); + compute_encoder.set_vector_bytes(idx.strides(), 5); + } compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8); diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 27ae22d05..b380a8374 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -34,6 +34,7 @@ const char* steel_gemm_fused(); const char* steel_gemm_masked(); const char* steel_gemm_splitk(); const char* steel_gemm_gather(); +const char* steel_gemm_segmented(); const char* conv(); const char* steel_conv(); const char* steel_conv_general(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 467380c3a..fd0e0db09 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_segmented(), + get_template_definition( + lib_name, + "segmented_mm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 1de5fa47c..794c67bdc 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int wn, bool rhs); +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 3ee88ca46..4069d8c21 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -71,6 +71,7 @@ set(STEEL_HEADERS steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_gather.h steel/gemm/kernels/steel_gemm_masked.h + steel/gemm/kernels/steel_gemm_segmented.h steel/gemm/kernels/steel_gemm_splitk.h steel/utils/type_traits.h steel/utils/integral_constant.h) @@ -120,6 +121,7 @@ if(NOT MLX_METAL_JIT) build_kernel(steel/gemm/kernels/steel_gemm_gather ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS}) build_kernel(gemv_masked steel/utils.h) endif() diff --git 
a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h new file mode 100644 index 000000000..b915eb343 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h @@ -0,0 +1,266 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Move the pointers to the output tile + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Move the pointers to the start of the segment + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + // We accept either contiguous (above) or weird strides where the beginning + // of the next one is the previous one. Basically the last two strides are + // both 1! + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? 
k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Matrix level alignment so only check K + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + // Tile aligned do the same as above + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal new file mode 100644 index 000000000..a7515c359 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal @@ -0,0 +1,43 @@ +// Copyright © 2024 Apple Inc. 
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h" + +#define instantiate_segmented_mm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_segmented_mm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + segmented_mm, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_segmented_mm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_segmented_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_segmented_mm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) +// clang-format on + +instantiate_segmented_mm_shapes_helper(float16, half, float16, half); +instantiate_segmented_mm_shapes_helper( + bfloat16, + bfloat16_t, + bfloat16, + bfloat16_t); +instantiate_segmented_mm_shapes_helper(float32, float, float32, float); diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index be7f3e2f8..55b8be3a9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1864,4 +1864,166 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s); } +void segmented_mm( + const array& a_, + const array& b_, + const array& segments_, + array& out, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + auto check_segments_layout = [&d, &s](const array& x) { + // Contiguous so return early + if (x.flags().row_contiguous) { + return std::make_tuple(true, x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 2; i++) { + rc &= + (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1); + } + rc &= x.strides(x.ndim() - 1) == 1; + if (x.ndim() > 1) { + rc &= x.strides(x.ndim() - 2) == 1; + } + + if (rc) { + return std::make_tuple(false, x); + } + + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return std::make_tuple(true, x_copy); + }; + + // Copy if needed + std::vector copies; + auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false); + auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false); + auto [segments_contiguous, segments] = check_segments_layout(segments_); + d.add_temporaries(std::move(copies), s.index); + + // Determine dispatch kernel + int bm = 64, bn = 64, bk = 16; + int wm = 2, wn = 2; + size_t batch_size_out = out.size() / M / N; + + char devc = d.get_architecture().back(); + GEMM_TPARAM_MACRO(devc) + + const bool align_M = (M % bm) == 0; + const bool align_N = (N 
% bn) == 0; + + // Define the kernel name + std::string base_name; + base_name.reserve(128); + concatenate( + base_name, + "steel_segmented_mm_", + transpose_a ? 't' : 'n', + transpose_b ? 't' : 'n', + "_", + type_to_name(a), + "_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + + metal::MTLFCList func_consts = { + {&segments_contiguous, MTL::DataType::DataTypeBool, 199}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_segments_contiguous_", + segments_contiguous ? 't' : 'n', + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_segmented_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); + compute_encoder.set_compute_pipeline_state(kernel); + + // Prepare the matmul params + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ static_cast(lda), + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ 0, + /* const int64_t batch_stride_d = */ M * N, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ 0, + /* const int batch_ndim = */ 0}; + + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = + MTL::Size(params.tiles_n, params.tiles_m, batch_size_out); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(segments, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void SegmentedMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& segments = inputs[2]; + + out.set_data(allocator::malloc(out.nbytes())); + + // Extract shapes from inputs. 
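For reference, the ``segments_contiguous`` constant set above corresponds to the two segment layouts that ``check_segments_layout`` accepts without copying: an explicit row-contiguous ``(..., 2)`` array of ``[start, end)`` offsets, or overlapping stride-1 windows over a single offsets vector. A minimal Python sketch of both (the sizes here are arbitrary and only for illustration):

    import mlx.core as mx

    K = 1024
    a = mx.random.normal((64, K))
    b = mx.random.normal((K, 64))

    # Layout 1: explicit (n, 2) pairs of [start, end) offsets into K.
    pairs = mx.array([[0, 256], [256, 512], [512, 768], [768, K]], dtype=mx.uint32)
    c1 = mx.segmented_mm(a, b, pairs)  # shape (4, 64, 64)

    # Layout 2: overlapping windows over a length n + 1 offsets vector,
    # built with as_strided so the pairs are never materialized.
    offsets = mx.array([0, 256, 512, 768, K], dtype=mx.uint32)
    c2 = mx.segmented_mm(a, b, mx.as_strided(offsets, (4, 2), (1, 1)))
    mx.allclose(c1, c2)  # evaluates to True, same result either way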
+ int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + segmented_mm(a, b, segments, out, M, N, K, d, s); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index b0375e37f..32d3e75f7 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 1a180bfe0..09e6c4ef3 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -105,6 +105,7 @@ NO_CPU(Scan) NO_CPU(Scatter) NO_CPU(ScatterAxis) NO_CPU(Select) +NO_CPU(SegmentedMM) NO_CPU(Sigmoid) NO_CPU(Sign) NO_CPU(Sin) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 409aa2c89..dfe5b57f1 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -121,6 +121,7 @@ NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) +NO_GPU(SegmentedMM) NO_GPU(Sigmoid) NO_GPU(Sign) NO_GPU(Sin) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2b861428f..7161a39b2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4649,6 +4649,54 @@ array gather_mm( return axes.empty() ? out : squeeze(out, axes, s); } +array segmented_mm( + array a, + array b, + array segments, + StreamOrDevice s /* = {} */) { + if (a.ndim() != 2 || b.ndim() != 2) { + throw std::invalid_argument("[segmented_mm] Batched matmul not supported"); + } + + if (segments.ndim() < 1 || segments.shape().back() != 2) { + std::ostringstream msg; + msg << "[segmented_mm] The segments should have shape (..., 2) but " + << segments.shape() << " was provided."; + throw std::invalid_argument(msg.str()); + } + + // Type promotion + auto out_type = result_type(a, b); + if (!issubdtype(out_type, floating)) { + std::ostringstream msg; + msg << "[segmented_mm] Only real floating point types are supported but " + << a.dtype() << " and " << b.dtype() + << " were provided which results in " << out_type + << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } + + if (!issubdtype(segments.dtype(), integer)) { + throw std::invalid_argument( + "[segmented_mm] Got segments with invalid dtype. 
Segments must be integral."); + } + + a = astype(a, out_type, s); + b = astype(b, out_type, s); + segments = astype(segments, uint32, s); + + Shape out_shape = segments.shape(); + out_shape.pop_back(); + out_shape.push_back(a.shape(0)); + out_shape.push_back(b.shape(1)); + + return array( + std::move(out_shape), + out_type, + std::make_shared(to_stream(s)), + {std::move(a), std::move(b), std::move(segments)}); +} + array diagonal( const array& a, int offset /* = 0 */, diff --git a/mlx/ops.h b/mlx/ops.h index af3cdb5bd..596d6d287 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1406,6 +1406,12 @@ array gather_mm( bool sorted_indices = false, StreamOrDevice s = {}); +/** + * Compute a matrix product but segment the inner dimension and write the + * result separately for each segment. + */ +array segmented_mm(array a, array b, array segments, StreamOrDevice s = {}); + /** Extract a diagonal or construct a diagonal array */ array diagonal( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 5f2bfdda4..b2b7306dd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -109,6 +109,70 @@ std::tuple vmap_ternary_op( return {a, b, c, to_ax}; } +// Calculate the gradient wrt to the weights of the following calculation +// +// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted) +// +// Note the transpose above. This function returns the gradient for w.T so if w +// was used instead then one needs to transpose the returned gradient. +// +// We define it as a separate function to reuse it for gather_mm and +// gather_qmm. +array gather_mm_grad( + const array& x, + const array& dy, + const array& lhs_indices, + const array& rhs_indices, + bool sorted, + Shape batch_shape, + const Stream& s) { + int M = x.shape(-2); + int K = x.shape(-1); + int N = dy.shape(-1); + int num_segments = std::accumulate( + batch_shape.begin(), batch_shape.end(), 1, std::multiplies()); + batch_shape.push_back(N); + batch_shape.push_back(K); + + // If the indices are sorted then it means that we can do the whole gradient + // computation via a segmented matmul. We just need to calculate the segments + // using the indices. + if (sorted) { + auto segments = zeros({num_segments}, uint32, s); + segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s); + segments = cumsum(segments, 0, false, true, s); + segments = concatenate({array({0}, {1}, uint32), segments}, 0, s); + segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s); + + return reshape( + segmented_mm( + swapaxes(flatten(dy, 0, -2, s), 0, 1, s), + flatten(x, 0, -2, s), + segments, + s), + std::move(batch_shape), + s); + } + + // Otherwise we need to gather matmul the dy and then scatter add it to the + // correct locations. + else { + // TODO: If the lhs indices wasn't provided, this is always a sorted matmul + // so we should add that check. 
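To make the sorted branch above concrete, the following is a rough Python-level equivalent of how the weight gradient of a sorted ``gather_mm`` collapses into a single ``segmented_mm`` call. The token count, expert count and dimensions are made up for illustration; the actual computation is the C++ shown above.

    import mlx.core as mx

    T, M, K, N, E = 16, 1, 32, 64, 4   # tokens, rows per token, dims, experts
    x = mx.random.normal((T, M, K))
    dy = mx.random.normal((T, M, N))
    rhs = mx.sort(mx.random.randint(0, E, shape=(T,)))  # sorted expert ids

    # Rows per expert -> cumulative offsets -> [start, end) segments.
    counts = mx.zeros((E,), dtype=rhs.dtype).at[rhs].add(M)
    offsets = mx.concatenate([mx.zeros((1,), dtype=rhs.dtype), mx.cumsum(counts)])
    segments = mx.as_strided(offsets, (E, 2), (1, 1))

    # dW[e] = dy_e^T @ x_e for every expert in one segmented matmul.
    dwt = mx.segmented_mm(mx.flatten(dy, 0, -2).T, mx.flatten(x, 0, -2), segments)
    # dwt: (E, N, K), the per-expert gradient with respect to w.T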
+ auto dw = gather_mm( + swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s); + return reshape( + scatter_add( + zeros({num_segments, N, K}, dw.dtype(), s), + rhs_indices, + expand_dims(dw, -3, s), + 0, + s), + std::move(batch_shape), + s); + } +} + } // namespace std::vector Primitive::jvp( @@ -3169,8 +3233,9 @@ std::vector QuantizedMatmul::vjp( "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); } else { if (!dsb) { - auto fc = flatten(cotangents[0], 0, -2, stream()); - auto fx = flatten(primals[0], 0, -2, stream()); + int ndim = primals[1].ndim(); + auto fc = flatten(cotangents[0], 0, -ndim, stream()); + auto fx = flatten(primals[0], 0, -ndim, stream()); auto dw = transpose_ ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream()) : matmul(swapaxes(fx, -1, -2, stream()), fc, stream()); @@ -3181,7 +3246,6 @@ std::vector QuantizedMatmul::vjp( vjps.push_back(sum(*dsb, -1, false, stream())); } else { // scales - auto s = stream(); auto wq = dequantize( primals[1], ones_like(primals[2], stream()), @@ -3253,34 +3317,42 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + int M = cotan.shape(-2); + int N = cotan.shape(-1); + int K = x.shape(-1); + bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == x.size(); + std::optional dsb = std::nullopt; for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { - vjps.push_back(reshape( - scatter_add( - flatten(zeros_like(x, stream()), 0, -3, stream()), - lhs_indices, - expand_dims( - gather_qmm( - cotan, - w, - scales, - biases, - std::nullopt, - rhs_indices, - !transpose_, - group_size_, - bits_, - sorted, - stream()), - -3, - stream()), - 0, - stream()), - x.shape(), - stream())); + auto g = gather_qmm( + cotan, + w, + scales, + biases, + std::nullopt, + rhs_indices, + !transpose_, + group_size_, + bits_, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(x, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + x.shape(), + stream())); + } } // gradient wrt to the indices is undefined @@ -3290,9 +3362,49 @@ std::vector GatherQMM::vjp( } // gradient wrt to w_q, scales or biases - else { + else if (arg == 1) { throw std::runtime_error( - "GatherQMM::vjp no gradient wrt the quantized matrix yet."); + "GatherQMM::vjp no gradient wrt the quantized weights."); + } else { + if (!dsb) { + auto shape = w.shape(); + shape.pop_back(); + shape.pop_back(); + dsb = unflatten( + gather_mm_grad( + x, + cotan, + lhs_indices, + rhs_indices, + sorted, + std::move(shape), + stream()), + -1, + {-1, group_size_}, + stream()); + } + if (arg == 3) { + vjps.push_back(sum(*dsb, -1, false, stream())); + } else { + vjps.push_back( + sum(multiply( + *dsb, + unflatten( + dequantize( + w, + ones_like(scales, stream()), + zeros_like(biases, stream()), + group_size_, + bits_, + stream()), + -1, + {-1, group_size_}, + stream()), + stream()), + -1, + false, + stream())); + } } } return vjps; @@ -5064,6 +5176,8 @@ std::vector GatherMM::vjp( std::vector vjps; auto& cotan = cotangents[0]; + auto& a = primals[0]; + auto& b = primals[1]; auto& lhs_indices = primals[2]; auto& rhs_indices = primals[3]; @@ -5072,39 +5186,46 @@ std::vector GatherMM::vjp( int K = primals[0].shape(-1); bool sorted = left_sorted_ || right_sorted_; + bool no_broadcast = rhs_indices.size() * M * K == primals[0].size(); for (auto arg : argnums) { if (arg == 0) { 
- // M X N * (K X N).T -> M X K - auto base = zeros_like(primals[0], stream()); - auto bt = swapaxes(primals[1], -1, -2, stream()); - - auto base_shape = base.shape(); - base = reshape(base, {-1, M, K}, stream()); - - // g : (out_batch_shape) + (M, K) - auto g = - gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); - + auto g = gather_mm( + cotan, + swapaxes(b, -1, -2, stream()), + std::nullopt, + rhs_indices, + sorted, + stream()); + if (sorted && no_broadcast) { + vjps.push_back(g); + } else { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(a, stream()), 0, -3, stream()), + lhs_indices, + expand_dims(g, -3, stream()), + 0, + stream()), + a.shape(), + stream())); + } } else if (arg == 1) { - // (M X K).T * M X N -> K X N - auto base = zeros_like(primals[1], stream()); - auto at = swapaxes(primals[0], -1, -2, stream()); - - auto base_shape = base.shape(); - base = reshape(base, {-1, K, N}, stream()); - - // g : (out_batch_shape) + (K, N) - auto g = - gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream()); - g = expand_dims(g, -3, stream()); - auto gacc = scatter_add(base, rhs_indices, g, 0, stream()); - - vjps.push_back(reshape(gacc, base_shape, stream())); + auto shape = b.shape(); + shape.pop_back(); + shape.pop_back(); + vjps.push_back(swapaxes( + gather_mm_grad( + a, + cotan, + lhs_indices, + rhs_indices, + sorted, + std::move(shape), + stream()), + -1, + -2, + stream())); } else { throw std::invalid_argument( "[GatherMM] Cannot calculate VJP with respect to indices."); diff --git a/mlx/primitives.h b/mlx/primitives.h index 4b18430ca..f4f157298 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive { bool right_sorted_; }; +class SegmentedMM : public UnaryPrimitive { + public: + explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(SegmentedMM) +}; + class BroadcastAxes : public UnaryPrimitive { public: explicit BroadcastAxes(Stream stream, std::vector ignore_axes = {}) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a1e77d681..d047f64cb 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) { array: The result of the multiplication of ``x`` with ``w`` after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); + m.def( + "segmented_mm", + &mx::segmented_mm, + nb::arg(), + nb::arg(), + "segments"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Perform a matrix multiplication but segment the inner dimension and + save the result for each segment separately. + + Args: + a (array): Input array of shape ``MxK``. + b (array): Input array of shape ``KxN``. + segments (array): The offsets into the inner dimension for each segment. + + Returns: + array: The result per segment of shape ``MxN``. 
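A minimal usage sketch of the new op as exposed in Python (the shapes and segment boundaries here are arbitrary):

    import mlx.core as mx

    a = mx.random.normal((10, 100))   # M x K
    b = mx.random.normal((100, 10))   # K x N
    # Three segments over the inner dimension K: [0, 30), [30, 60), [60, 100).
    segments = mx.array([[0, 30], [30, 60], [60, 100]], dtype=mx.uint32)
    c = mx.segmented_mm(a, b, segments)   # shape (3, 10, 10)
    # Because the segments partition K exactly, summing them recovers a @ b.
    mx.allclose(c.sum(axis=0), a @ b)     # -> True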
+ )pbdoc"); m.def( "tensordot", [](const mx::array& a, diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index fce92bacb..17eb80eee 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -8,6 +8,9 @@ cuda_skip = { # Gather matmul NYI "TestBlas.test_gather_matmul", "TestBlas.test_gather_matmul_grad", + "TestBlas.test_gather_mm_sorted", + # Segmented matmul NYI + "TestBlas.test_segmented_mm", # Scan NYI "TestArray.test_api", "TestAutograd.test_cumprod_grad", @@ -76,6 +79,7 @@ cuda_skip = { "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", + "TestQuantized.test_gather_qmm_grad", "TestQuantized.test_non_multiples", "TestQuantized.test_qmm", "TestQuantized.test_qmm_jvp", diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index eb45df124..5e096d9c5 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1163,6 +1163,99 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(r.shape, t.shape) self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def test_gather_mm_sorted(self): + def gather_mm_ref(a, b, rhs): + b = b[rhs] + return a @ b + + def gather_mm_test(a, b, rhs): + return mx.gather_mm(a, b, rhs_indices=rhs, sorted_indices=True) + + a = mx.random.normal((100, 1, 100)) + b = mx.random.normal((8, 100, 100)) + rhs = mx.sort(mx.random.randint(0, 8, shape=(100,))) + + c1 = gather_mm_ref(a, b, rhs) + c2 = gather_mm_test(a, b, rhs) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + cotan = mx.random.normal(c1.shape) + c1, dc1 = mx.vjp( + lambda a, b: gather_mm_ref(a, b, rhs), + [a, b], + [cotan], + ) + c2, dc2 = mx.vjp( + lambda a, b: gather_mm_test(a, b, rhs), + [a, b], + [cotan], + ) + self.assertTrue(mx.allclose(c1[0], c2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[0], dc2[0], atol=1e-4)) + self.assertTrue(mx.allclose(dc1[1], dc2[1], atol=1e-4)) + + def test_segmented_mm(self): + def segmented_mm_ref(a, b, s): + s = s.tolist() + c = [] + for s1, s2 in s: + c.append(a[:, s1:s2] @ b[s1:s2, :]) + return mx.stack(c, axis=0) + + shapes = [ + (10, 10, 10), + (10, 10, 1000), + (1000, 1000, 1000), + ] + all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]] + + for M, N, K in shapes: + for s in all_segments: + segments = [] + for i in range(len(s) - 1): + segments.append([s[i], s[i + 1]]) + segments = mx.array(segments) + segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32)) + a = mx.random.normal((M, K)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a, b, segments) + c2 = mx.segmented_mm(a, b, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((K, N)) + c1 = segmented_mm_ref(a.T, b, segments) + c2 = mx.segmented_mm(a.T, b, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((M, K)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a, b.T, segments) + c2 = mx.segmented_mm(a, b.T, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + a = mx.random.normal((K, M)) + b = mx.random.normal((N, K)) + c1 = segmented_mm_ref(a.T, b.T, segments) + c2 = mx.segmented_mm(a.T, b.T, segments) + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + with self.assertRaises(ValueError): + a = mx.ones((2, 10, 10)) + s = mx.array([[0, 5], [5, 10]]).astype(mx.uint32) + mx.segmented_mm(a, a, s) + + a = mx.ones((10, 1000)) + s = mx.random.randint(0, 16, shape=(1000,)) + s = mx.zeros(16, dtype=s.dtype).at[s].add(1) + s = mx.sort(s) + s = 
mx.cumsum(s) + s = mx.concatenate([mx.array([0]), s]) + s = mx.as_strided(s, (16, 2), (1, 1)) + s = mx.reshape(s, (2, 2, 4, 2)) + c = mx.segmented_mm(a, a.T, s) + self.assertEqual(c.shape, (2, 2, 4, 10, 10)) + def test_gemv_gemm_same_precision(self): mx.random.seed(0) N = 256 diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index f402bd444..2c62c6307 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + def test_gather_qmm_grad(self): + def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort): + if lhs is not None: + x = x[lhs] + if rhs is not None: + w = w[rhs] + s = s[rhs] + b = b[rhs] + return mx.quantized_matmul(x, w, s, b, transpose=trans) + + def gather_qmm(x, w, s, b, lhs, rhs, trans, sort): + return mx.gather_qmm( + x, + w, + s, + b, + transpose=trans, + lhs_indices=lhs, + rhs_indices=rhs, + sorted_indices=sort, + ) + + x = mx.random.normal((16, 1, 256)) + w, s, b = mx.quantize(mx.random.normal((4, 256, 256))) + indices = mx.sort(mx.random.randint(0, 4, shape=(16,))) + cotan = mx.random.normal((16, 1, 256)) + + (o1,), (dx1, ds1, db1) = mx.vjp( + lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + (o2,), (dx2, ds2, db2) = mx.vjp( + lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + + self.assertTrue(mx.allclose(o1, o2, atol=1e-4)) + self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4)) + self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3)) + self.assertTrue(mx.allclose(db1, db2, atol=1e-3)) + def test_vjp_scales_biases(self): mx.random.seed(0) x = mx.random.normal(shape=(2, 2, 512)) From 2ca533b27943f3ed916b209ace38fad6bfe89def Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 8 Jul 2025 12:00:43 +0900 Subject: [PATCH 134/156] Fix compilation with CUDA 11 (#2331) --- mlx/backend/cuda/arg_reduce.cu | 1 + mlx/backend/cuda/device.cpp | 25 +++++---- mlx/backend/cuda/device/cast_op.cuh | 69 +++++++++++++++++++++++- mlx/backend/cuda/device/utils.cuh | 12 ++--- mlx/backend/cuda/reduce/all_reduce.cu | 6 +-- mlx/backend/cuda/reduce/col_reduce.cu | 9 ++-- mlx/backend/cuda/reduce/reduce_ops.cuh | 12 +++-- mlx/backend/cuda/reduce/reduce_utils.cuh | 16 ------ mlx/backend/cuda/reduce/row_reduce.cu | 15 +++--- mlx/backend/cuda/rms_norm.cu | 4 +- mlx/backend/cuda/softmax.cu | 2 +- 11 files changed, 115 insertions(+), 56 deletions(-) diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index ad942a406..67ef5d968 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/iterators/strided_iterator.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 638d68727..99ccfdb4a 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -264,19 +264,26 @@ void CommandEncoder::commit() { graph_key_ += std::to_string(graph_node_count_); graph_key_ += "."; graph_key_ += std::to_string(empty_node_count_); - auto [it, _] = graph_cache_.emplace(graph_key_, nullptr); - auto& graph_exec = it->second; - if (graph_exec != NULL) { - cudaGraphExecUpdateResultInfo update_result; - cudaGraphExecUpdate(graph_exec, graph_, &update_result); - if (update_result.result != cudaGraphExecUpdateSuccess) { - cudaGetLastError(); + cudaGraphExec_t& graph_exec = graph_cache_[graph_key_]; + + if (graph_exec != nullptr) { + cudaGraphExecUpdateResult update_result; +#if CUDART_VERSION >= 12000 + cudaGraphExecUpdateResultInfo info; + cudaGraphExecUpdate(graph_exec, graph_, &info); + update_result = info.result; +#else + cudaGraphNode_t error_node; + cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result); +#endif // CUDART_VERSION >= 12000 + if (update_result != cudaGraphExecUpdateSuccess) { + cudaGetLastError(); // reset error CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); - graph_exec = NULL; + graph_exec = nullptr; } } - if (graph_exec == NULL) { + if (graph_exec == nullptr) { CHECK_CUDA_ERROR( cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); } diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh index f15270432..8da19ddf8 100644 --- a/mlx/backend/cuda/device/cast_op.cuh +++ b/mlx/backend/cuda/device/cast_op.cuh @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include namespace mlx::core::cu { @@ -17,6 +19,26 @@ struct CastOp { } }; +// Castings between complex and boolean. +// TODO: Should make a custom complex type. +template <> +struct CastOp { + static constexpr bool is_castable = true; + + __device__ bool operator()(cuComplex x) { + return x.x != 0 && x.y != 0; + } +}; + +template <> +struct CastOp { + static constexpr bool is_castable = true; + + __device__ cuComplex operator()(bool x) { + return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0); + } +}; + // Converting a complex number to real number discards the imaginary part. template struct CastOp< @@ -45,6 +67,7 @@ struct CastOp< } }; +// Do nothing when no casting is needed. template struct CastOp< SrcT, @@ -57,9 +80,53 @@ struct CastOp< } }; +// In CUDA 11 the half types do not define conversions between some types, +// provide fallbacks here. 
+#if CUDART_VERSION < 12000 +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t< + !cuda::std::is_convertible_v && + !cuda::std::is_same_v && + (cuda::std::is_same_v || + cuda::std::is_same_v)>> { + static constexpr bool is_castable = true; + + __device__ DstT operator()(SrcT x) { + return DstT(static_cast(x)); + } +}; + +template +struct CastOp< + SrcT, + DstT, + cuda::std::enable_if_t< + !cuda::std::is_convertible_v && + !cuda::std::is_same_v && + !cuda::std::is_same_v && + !cuda::std::is_same_v && + (cuda::std::is_same_v || + cuda::std::is_same_v)>> { + static constexpr bool is_castable = true; + + __device__ DstT operator()(SrcT x) { + return DstT(static_cast(x)); + } +}; +#endif // CUDART_VERSION < 12000 + +// Helper to deduce the SrcT. +template +inline __host__ __device__ auto cast_to(SrcT x) { + return CastOp{}(x); +} + // Return an iterator that cast the value to DstT using CastOp. template -__host__ __device__ auto make_cast_iterator(Iterator it) { +inline __host__ __device__ auto make_cast_iterator(Iterator it) { using SrcT = typename cuda::std::iterator_traits::value_type; if constexpr (std::is_same_v) { return it; diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 89b609c45..83e149165 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -99,20 +99,20 @@ struct Limits< return cuda::std::numeric_limits::infinity(); } static constexpr __host__ __device__ T min() { -#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 - return -cuda::std::numeric_limits::infinity(); -#else +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 return -cuda::std::numeric_limits::infinity(); +#else + return -cuda::std::numeric_limits::infinity(); #endif } static constexpr __host__ __device__ T finite_max() { return cuda::std::numeric_limits::max(); } static constexpr __host__ __device__ T finite_min() { -#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000 - return cuda::std::numeric_limits::lowest(); -#else +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 return cuda::std::numeric_limits::lowest(); +#else + return cuda::std::numeric_limits::lowest(); #endif } }; diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 3419d61cb..166a11a79 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { for (; i + block.size() * N <= check; i += block.size() * N) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { - accs[0] = op(accs[0], __cast(vals[j])); + accs[0] = op(accs[0], cast_to(vals[j])); } } if (i < check) { cub::LoadDirectBlocked( - block.thread_rank(), in + i, vals, check - i, __cast(init)); + block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { - accs[0] = op(accs[0], __cast(vals[i])); + accs[0] = op(accs[0], cast_to(vals[i])); } } diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 910fa0379..fec5ca76b 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -3,7 +3,6 @@ #include #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include @@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { T vals[N_READS]; 
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], __cast(vals[i])); + totals[i] = op(totals[i], cast_to(vals[i])); } loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } @@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { T vals[N_READS]; cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], __cast(vals[i])); + totals[i] = op(totals[i], cast_to(vals[i])); } loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } @@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { in + loop.location(), vals, args.reduction_stride - tile_x * BN, - __cast(ReduceInit::value())); + cast_to(ReduceInit::value())); for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], __cast(vals[i])); + totals[i] = op(totals[i], cast_to(vals[i])); } loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index b40d2bd4e..bc4dce33e 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -2,6 +2,8 @@ #pragma once +#include "mlx/backend/cuda/device/atomic_ops.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/reduce/reduce_utils.cuh" @@ -40,15 +42,15 @@ struct Sum { } __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) { - atomicAdd(x, y); + atomic_add(x, y); } __device__ void atomic_update(int* x, int y) { - atomicAdd(x, y); + atomic_add(x, y); } __device__ void atomic_update(float* x, float y) { - atomicAdd(x, y); + atomic_add(x, y); } }; @@ -152,7 +154,7 @@ struct ReduceInit { if constexpr (cuda::std::is_same_v) { return T{0, 0}; } else { - return typename ReduceResult::type{0}; + return cast_to::type>(0); } } }; @@ -163,7 +165,7 @@ struct ReduceInit { if constexpr (cuda::std::is_same_v) { return T{1, 0}; } else { - return typename ReduceResult::type{1}; + return cast_to::type>(1); } } }; diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index d4670503a..ccd7ae48d 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -55,22 +55,6 @@ __device__ void atomic_reduce(T* x, T y) { } } -// TODO: Should make a custom complex type -template -inline __device__ U __cast(T x) { - return static_cast(x); -} - -template <> -inline __device__ bool __cast(cuComplex x) { - return x.x != 0 && x.y != 0; -} - -template <> -inline __device__ cuComplex __cast(bool x) { - return x ? 
make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0); -} - template inline __device__ void block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index e57f18668..61838ddd3 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -3,7 +3,6 @@ #include #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include @@ -113,7 +112,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) { in + k * size + r * (block.size() * N), vals[k]); for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], __cast(vals[k][j])); + accs[k] = op(accs[k], cast_to(vals[k][j])); } } } @@ -125,7 +124,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) { in + k * size + r * (block.size() * N), vals[k]); for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], __cast(vals[k][j])); + accs[k] = op(accs[k], cast_to(vals[k][j])); } } } @@ -138,9 +137,9 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) { in + k * size + final_offset, vals[k], size, - __cast(init)); + cast_to(init)); for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], __cast(vals[k][j])); + accs[k] = op(accs[k], cast_to(vals[k][j])); } } } @@ -199,7 +198,7 @@ __global__ void row_reduce_looped( in + loop.location() + r * BLOCK_DIM * N_READS, vals); for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], __cast(vals[i])); + total[0] = op(total[0], cast_to(vals[i])); } } if (final_offset < args.row_size) { @@ -209,9 +208,9 @@ __global__ void row_reduce_looped( in + loop.location() + final_offset, vals, args.row_size - final_offset, - __cast(init)); + cast_to(init)); for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], __cast(vals[i])); + total[0] = op(total[0], cast_to(vals[i])); } } // TODO: Maybe block.sync() here? diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 5ee1d3386..964bd7d98 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -74,7 +74,7 @@ __global__ void rms_norm( for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) { auto index = r * BLOCK_DIM + block.thread_rank(); T xn[N_READS]; - cub::LoadDirectBlocked(index, x, xn, axis_size, 0); + cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to(0)); for (int i = 0; i < N_READS; ++i) { float t = static_cast(xn[i]); normalizer += t * t; @@ -130,7 +130,7 @@ __global__ void rms_norm_vjp( T wn[N_READS] = {}; T gn[N_READS] = {}; auto index = r * BLOCK_DIM + block.thread_rank(); - cub::LoadDirectBlocked(index, x, xn, axis_size, 0); + cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to(0)); cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); for (int i = 0; i < N_READS; i++) { diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index fd807bd8d..56f67d7f3 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -43,7 +43,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) { // Thread reduce. 
AccT prevmax; AccT maxval = Limits::finite_min(); - AccT normalizer = 0; + AccT normalizer = cast_to(0); for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { AccT vals[N_READS]; cub::LoadDirectBlocked( From fb4e8b896b7a7dd73450c71f12a8818895909fe7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 8 Jul 2025 14:26:07 -0700 Subject: [PATCH 135/156] patch bump (#2343) --- mlx/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/version.h b/mlx/version.h index 5ad66e3c2..c01135177 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 26 -#define MLX_VERSION_PATCH 2 +#define MLX_VERSION_PATCH 3 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From 8b9a3f3ceab86a6e79dde498b20c0142328bf4a4 Mon Sep 17 00:00:00 2001 From: jhavukainen <104022140+jhavukainen@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:26:27 -0700 Subject: [PATCH 136/156] Align mlx::core::max op nan propagation with NumPy (#2339) * Make max op NaN propagation rules align with numpy * Adding benchmarks and testing for max op nanpropagation * Pre-commit formatting * Fix max complex64 nan propagation and add test * Improve the cpp unittest * Only check nans on non-integral types in simd_reduce_impl. * Cleanup using namespace alias * Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16. * Make the max nanpropagation test more meaningful for integer types * Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR. --- benchmarks/cpp/single_ops.cpp | 11 +++++ benchmarks/python/single_ops.py | 8 ++++ mlx/backend/cpu/reduce.cpp | 14 ++++-- mlx/backend/metal/kernels/reduction/ops.h | 40 +++++++++++++++- python/tests/cuda_skip.py | 2 + python/tests/test_reduce.py | 57 +++++++++++++++++++++++ tests/ops_tests.cpp | 4 ++ 7 files changed, 131 insertions(+), 5 deletions(-) diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp index 5b327be58..6eac366bc 100644 --- a/benchmarks/cpp/single_ops.cpp +++ b/benchmarks/cpp/single_ops.cpp @@ -192,6 +192,17 @@ void time_reductions() { auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; TIME(argmin_along_1); + + auto indices = mx::array({1}); + auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1}); + std::vector axes{0}; + auto b = scatter(a, {indices}, updates, axes); + mx::eval(b); + + auto max_along_0 = [&b]() { return mx::max(b, 0, false); }; + TIME(max_along_0); + auto max_along_1 = [&b]() { return mx::max(b, 1, false); }; + TIME(max_along_1); } void time_gather_scatter() { diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py index 3160a1833..5d2906fe7 100644 --- a/benchmarks/python/single_ops.py +++ b/benchmarks/python/single_ops.py @@ -51,6 +51,13 @@ def time_maximum(): time_fn(mx.maximum, a, b) +def time_max(): + a = mx.random.uniform(shape=(32, 1024, 1024)) + a[1, 1] = mx.nan + mx.eval(a) + time_fn(mx.max, a, 0) + + def time_negative(): a = mx.random.uniform(shape=(10000, 1000)) mx.eval(a) @@ -108,6 +115,7 @@ if __name__ == "__main__": time_add() time_matmul() + time_max() time_maximum() time_exp() time_negative() diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index ce25feb11..87e3aa857 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -325,7 +325,15 @@ struct MaxReduce { }; template - T 
operator()(simd::Simd x) { + std::enable_if_t, T> operator()(simd::Simd x) { + return simd::max(x); + }; + + template + std::enable_if_t, T> operator()(simd::Simd x) { + if (simd::any(x != x)) { + return static_cast(NAN); + } return simd::max(x); }; }; @@ -527,10 +535,10 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int8: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); + reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index 68ed11986..57ddffef8 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -186,7 +186,15 @@ struct Max { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_max(val); } @@ -198,7 +206,35 @@ struct Max { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a > b ? a : b; } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? 
a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } }; diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 17eb80eee..afd48bd03 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -3,6 +3,8 @@ cuda_skip = { "TestLayers.test_quantized_embedding", "TestOps.test_dynamic_slicing", "TestReduce.test_dtypes", + "TestReduce.test_nanpropagation", + "TestReduce.test_nanpropagation_complex64", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 2b899c099..d757f1527 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -153,6 +153,63 @@ class TestReduce(mlx_tests.MLXTestCase): x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9) check(x, (1, 3, 5, 7, 9)) + def test_nanpropagation(self): + dtypes = [ + "uint8", + "uint16", + "uint32", + "int8", + "int16", + "int32", + "float16", + "float32", + ] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + x = (mx.random.normal((4, 4)) * 10).astype(getattr(mx, dtype)) + indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2) + for idx in indices: + x[idx[0], idx[1]] = mx.nan + x_np = np.array(x) + + for op in ["max"]: + for axis in [0, 1]: + out = getattr(mx, op)(x, axis=axis) + ref = getattr(np, op)(x_np, axis=axis) + self.assertTrue(np.array_equal(out, ref, equal_nan=True)) + + def test_nanpropagation_complex64(self): + complex_array_1 = mx.array( + [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_2 = mx.array( + [1 + 1j, 2 + 2j, 3 + mx.nan * 1j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_3 = mx.array( + [1 + 1j, 2 + mx.nan * 1j, 3 + 3j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + complex_array_4 = mx.array( + [mx.nan + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=mx.complex64 + ).reshape(2, 2) + + np_arrays = [ + np.array(complex_array_1), + np.array(complex_array_2), + np.array(complex_array_3), + np.array(complex_array_4), + ] + + for mx_arr, np_arr in zip( + [complex_array_1, complex_array_2, complex_array_3, complex_array_4], + np_arrays, + ): + for axis in [0, 1]: + for op in ["max"]: + out = getattr(mx, op)(mx_arr, axis=axis) + ref = getattr(np, op)(np_arr, axis=axis) + self.assertTrue(np.array_equal(out, ref, equal_nan=True)) + if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 8833424a6..1a9781c7c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1024,6 +1024,10 @@ TEST_CASE("test reduction ops") { x = array({true, true, true, false, true, false}, {2, 3}); CHECK(array_equal(min(x, 1), array({true, false})).item()); CHECK(array_equal(min(x, 0), array({false, true, false})).item()); + + x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item()); + CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item()); } // Test logsumexp From e14ee124917e824a25a4cd3803197de4f31697fc Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 9 Jul 2025 14:37:14 -0700 Subject: [PATCH 137/156] add zero for argsort vjp (#2345) --- mlx/primitives.cpp | 20 ++++++++++++++++++-- mlx/primitives.h | 1 + 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b2b7306dd..eb5d9d6b3 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -620,10 +620,11 @@ std::vector 
ArgReduce::vjp( } std::vector ArgReduce::jvp( + const std::vector& primals, const std::vector&, - const std::vector& tangents, const std::vector&) { - return {zeros_like(tangents[0], stream())}; + auto shape = output_shapes(primals)[0]; + return {zeros(shape, uint32, stream())}; } std::pair, std::vector> ArgSort::vmap( @@ -647,6 +648,21 @@ bool ArgSort::is_equivalent(const Primitive& other) const { return axis_ == r_other.axis_; } +std::vector ArgSort::vjp( + const std::vector& primals, + const std::vector&, + const std::vector&, + const std::vector&) { + return {zeros_like(primals[0], stream())}; +} + +std::vector ArgSort::jvp( + const std::vector& primals, + const std::vector&, + const std::vector&) { + return {zeros(primals[0].shape(), uint32, stream())}; +} + std::vector AsType::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index f4f157298..3d3202aaa 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -378,6 +378,7 @@ class ArgSort : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() + DEFINE_GRADS() DEFINE_PRINT(ArgSort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; From 85873cb162d0802c925350cbc68e1410bce3f1ad Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 10 Jul 2025 10:48:43 +0900 Subject: [PATCH 138/156] [CUDA] Do vectorized store/load in contiguous elementwise ops (#2342) * Do vectorized store/load in unary ops * Do vectorized store/load in binary_two ops * Do vectorized store/load in copy ops * Do vectorized store/load in ternary ops * Use int32_t for IdxT * binary => binary_two in binary_two.cu * Fix tests on large arrays * Use uint as index type * Contig uses uint as index and non-contig uses int --- mlx/backend/cuda/binary.cu | 46 ++----- mlx/backend/cuda/binary_two.cu | 156 +++++++++++++++++------ mlx/backend/cuda/copy/copy_contiguous.cu | 49 +++++-- mlx/backend/cuda/ternary.cu | 34 ++++- mlx/backend/cuda/unary.cu | 34 ++++- 5 files changed, 223 insertions(+), 96 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 0585dc76a..fc5b8c496 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -20,15 +20,10 @@ namespace cg = cooperative_groups; template __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[0], b[0]); + if ((index + 1) * N_READS > size) { + for (int i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[0], b[0]); } } else { AlignedVector out_vec; @@ -44,15 +39,10 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[0], b[offset]); + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[0], b[i]); } } else { auto b_vec = load_vector(b, index); @@ -70,15 +60,10 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { template __global__ void 
binary_vs(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[offset], b[0]); + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[0]); } } else { auto a_vec = load_vector(a, index); @@ -96,15 +81,10 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; - if (remaining <= 0) { - return; - } - if (remaining < N_READS) { - for (int i = 0; i < remaining; ++i) { - IdxT offset = index * N_READS + i; - out[offset] = Op{}(a[offset], b[offset]); + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[i]); } } else { auto a_vec = load_vector(a, index); @@ -267,7 +247,7 @@ void binary_op_gpu_inplace( } }); } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; // TODO: Choose optimized value based on type size. constexpr int N_READS = 4; diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 9582b0378..4b6e24581 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -17,52 +17,119 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void -binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[0], b[0]); - out_a[0] = out[0]; - out_b[0] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[0], b[0]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a[0], b[0]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } -template +template __global__ void -binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[0], b[index]); - out_a[index] = out[0]; - out_b[index] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[0], b[i]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto b_vec = load_vector(b, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a[0], b_vec.val[i]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } -template +template __global__ void -binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if 
(index < size) { - auto out = Op{}(a[index], b[0]); - out_a[index] = out[0]; - out_b[index] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[i], b[0]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto a_vec = load_vector(a, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec.val[i], b[0]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } -template +template __global__ void -binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { +binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto out = Op{}(a[index], b[index]); - out_a[index] = out[0]; - out_b[index] = out[1]; + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + auto out = Op{}(a[i], b[i]); + out_a[i] = out[0]; + out_b[i] = out[1]; + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + + AlignedVector out_a_vec; + AlignedVector out_b_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec.val[i], b_vec.val[i]); + out_a_vec.val[i] = out[0]; + out_b_vec.val[i] = out[1]; + } + + store_vector(out_a, index, out_a_vec); + store_vector(out_b, index, out_b_vec); } } template -__global__ void binary_g_nd( +__global__ void binary_two_g_nd( const In* a, const In* b, Out* out_a, @@ -82,7 +149,7 @@ __global__ void binary_g_nd( } template -__global__ void binary_g( +__global__ void binary_two_g( const In* a, const In* b, Out* out_a, @@ -103,7 +170,7 @@ __global__ void binary_g( } template -constexpr bool supports_binary_op() { +constexpr bool supports_binary_two_op() { if (std::is_same_v) { return std::is_same_v && (std::is_integral_v || is_floating_v); @@ -114,7 +181,7 @@ constexpr bool supports_binary_op() { } // namespace cu template -void binary_op_gpu_inplace( +void binary_two_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, std::string_view op, @@ -141,7 +208,7 @@ void binary_op_gpu_inplace( dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { using CTYPE_IN = MLX_GET_TYPE(in_type_tag); using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { + if constexpr (cu::supports_binary_two_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; @@ -161,8 +228,12 @@ void binary_op_gpu_inplace( int ndim = shape.size(); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu:: - binary_g_nd; + auto kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large()); encoder.add_kernel_node( @@ -179,7 +250,7 @@ void binary_op_gpu_inplace( const_param(b_strides)); }); } else { - auto kernel = cu::binary_g; + auto kernel = cu::binary_two_g; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large()); encoder.add_kernel_node( @@ -198,22 +269,25 @@ void binary_op_gpu_inplace( } }); } else { - dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; + // TODO: Choose optimized value based on type size. 
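// The ss/sv/vs/vv variants above differ only in which operand is a broadcast scalar; each one
// handles N_READS contiguous elements per thread through AlignedVector loads/stores and falls
// back to a scalar tail loop when the size is not a multiple of N_READS.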
+ constexpr int N_READS = 4; + auto kernel = cu::binary_two_ss; if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; + kernel = cu::binary_two_sv; } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; + kernel = cu::binary_two_vs; } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; + kernel = cu::binary_two_vv; } auto [num_blocks, block_dims] = get_launch_args( kernel, out_a.data_size(), out_a.shape(), out_a.strides(), - large()); + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, @@ -237,7 +311,7 @@ void binary_op_gpu_inplace( } template -void binary_op_gpu( +void binary_two_op_gpu( const std::vector& inputs, std::vector& outputs, std::string_view op, @@ -247,7 +321,7 @@ void binary_op_gpu( auto bopt = get_binary_op_type(a, b); set_binary_op_output_data(a, b, outputs[0], bopt); set_binary_op_output_data(a, b, outputs[1], bopt); - binary_op_gpu_inplace(inputs, outputs, op, s); + binary_two_op_gpu_inplace(inputs, outputs, op, s); } void DivMod::eval_gpu( @@ -255,7 +329,7 @@ void DivMod::eval_gpu( std::vector& outputs) { nvtx3::scoped_range r("DivMod::eval_gpu"); auto& s = outputs[0].primitive().stream(); - binary_op_gpu(inputs, outputs, get_primitive_string(this), s); + binary_two_op_gpu(inputs, outputs, get_primitive_string(this), s); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 408350129..4e9eaccb7 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -10,19 +10,43 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_s(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = CastOp{}(in[0]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = cast_to(in[0]); + } + } else { + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = cast_to(in[0]); + } + + store_vector(out, index, out_vec); } } -template +template __global__ void copy_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = CastOp{}(in[index]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = cast_to(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = cast_to(in_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -41,12 +65,19 @@ void copy_contiguous( using InType = cuda_type_t; using OutType = cuda_type_t; using IdxT = std::conditional_t; - auto kernel = cu::copy_s; + // TODO: Choose optimized value based on type size. 
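// copy_s fills the output from a single scalar while copy_v converts element-wise; both write
// N_READS elements per thread and handle the ragged tail with a scalar loop.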
+ constexpr int N_READS = 4; + auto kernel = cu::copy_s; if (ctype == CopyType::Vector) { - kernel = cu::copy_v; + kernel = cu::copy_v; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index aa6523f27..eb69442c2 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -15,12 +15,27 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[index], b[index], c[index]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(a[i], b[i], c[i]); + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + auto c_vec = load_vector(c, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -149,11 +164,18 @@ void ternary_op_gpu_inplace( } }); } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { using IdxT = std::conditional_t; - auto kernel = cu::ternary_v; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::ternary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 3f1a62d24..1fe1b557b 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -18,11 +18,24 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void unary_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(in[index]); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = Op{}(in_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -112,14 +125,20 @@ void unary_op_gpu_inplace( using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op()) { dispatch_bool(large, [&](auto large) { - using IdxT = std::conditional_t; using InType = cuda_type_t; using OutType = cuda_type_t; - using IdxT = std::conditional_t; if (contig) { - auto kernel = cu::unary_v; + using IdxT = std::conditional_t; + // TODO: Choose optimized value based on type size. 
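// The contiguous path indexes with unsigned 32-bit offsets whenever the flat size fits in
// uint32, while the strided path below keeps signed index types for the general
// (non-contiguous) address computation.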
+ constexpr int N_READS = 4; + auto kernel = cu::unary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large, + N_READS); encoder.add_kernel_node( kernel, num_blocks, @@ -128,6 +147,7 @@ void unary_op_gpu_inplace( out.data(), out.data_size()); } else { + using IdxT = std::conditional_t; auto [shape, strides] = collapse_contiguous_dims(in); auto kernel = cu::unary_g; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); From 8c7bc30ce4739d5d40ca7ce59147a6c71016bc4b Mon Sep 17 00:00:00 2001 From: jhavukainen <104022140+jhavukainen@users.noreply.github.com> Date: Thu, 10 Jul 2025 06:20:43 -0700 Subject: [PATCH 139/156] Align mlx::core::min op nan propagation with NumPy (#2346) --- benchmarks/cpp/single_ops.cpp | 5 +++ benchmarks/python/single_ops.py | 8 +++++ mlx/backend/cpu/reduce.cpp | 10 +++++- mlx/backend/metal/kernels/reduction/ops.h | 41 +++++++++++++++++++++-- python/tests/test_reduce.py | 4 +-- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp index 6eac366bc..1f93a78d7 100644 --- a/benchmarks/cpp/single_ops.cpp +++ b/benchmarks/cpp/single_ops.cpp @@ -203,6 +203,11 @@ void time_reductions() { TIME(max_along_0); auto max_along_1 = [&b]() { return mx::max(b, 1, false); }; TIME(max_along_1); + + auto min_along_0 = [&b]() { return mx::min(b, 0, false); }; + TIME(min_along_0); + auto min_along_1 = [&b]() { return mx::min(b, 1, false); }; + TIME(min_along_1); } void time_gather_scatter() { diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py index 5d2906fe7..939faf305 100644 --- a/benchmarks/python/single_ops.py +++ b/benchmarks/python/single_ops.py @@ -58,6 +58,13 @@ def time_max(): time_fn(mx.max, a, 0) +def time_min(): + a = mx.random.uniform(shape=(32, 1024, 1024)) + a[1, 1] = mx.nan + mx.eval(a) + time_fn(mx.min, a, 0) + + def time_negative(): a = mx.random.uniform(shape=(10000, 1000)) mx.eval(a) @@ -115,6 +122,7 @@ if __name__ == "__main__": time_add() time_matmul() + time_min() time_max() time_maximum() time_exp() diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index 87e3aa857..8febbd050 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -350,7 +350,15 @@ struct MinReduce { }; template - T operator()(simd::Simd x) { + std::enable_if_t, T> operator()(simd::Simd x) { + return simd::min(x); + }; + + template + std::enable_if_t, T> operator()(simd::Simd x) { + if (simd::any(x != x)) { + return static_cast(NAN); + } return simd::min(x); }; }; diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index 57ddffef8..11d8e83ac 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -164,7 +164,15 @@ struct Min { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_min(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_min(val); } @@ -176,11 +184,38 @@ struct Min { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a < b ? a : b; } -}; + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a < b ? 
a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a < b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real < b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + }; +}; template struct Max { DEFINE_SIMD_REDUCE() diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index d757f1527..9efd6c5c7 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase): x[idx[0], idx[1]] = mx.nan x_np = np.array(x) - for op in ["max"]: + for op in ["max", "min"]: for axis in [0, 1]: out = getattr(mx, op)(x, axis=axis) ref = getattr(np, op)(x_np, axis=axis) @@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase): np_arrays, ): for axis in [0, 1]: - for op in ["max"]: + for op in ["max", "min"]: out = getattr(mx, op)(mx_arr, axis=axis) ref = getattr(np, op)(np_arr, axis=axis) self.assertTrue(np.array_equal(out, ref, equal_nan=True)) From 8fb3e7a26c35d768aceac6eb20a4ebf13740b8b2 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 10 Jul 2025 23:24:02 +0900 Subject: [PATCH 140/156] [CUDA] Set current device before cudaGraphLaunch (#2351) --- mlx/backend/cuda/device.cpp | 19 ++++++++++--------- mlx/backend/cuda/device.h | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 99ccfdb4a..f7c8ecdc0 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -57,6 +57,14 @@ void Device::make_current() { } } +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; + } + return it->second; +} + CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); CHECK_CUDA_ERROR( @@ -168,15 +176,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector nodes) { } } -CommandEncoder& Device::get_command_encoder(Stream s) { - auto it = encoders_.find(s.index); - if (it == encoders_.end()) { - it = encoders_.try_emplace(s.index, *this).first; - } - return it->second; -} - -CommandEncoder::CommandEncoder(Device& d) : stream_(d) { +CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); } @@ -287,6 +287,7 @@ void CommandEncoder::commit() { CHECK_CUDA_ERROR( cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); } + device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); // TODO smarter cache policy diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 4ebdae55c..8ac840cbb 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -93,6 +93,7 @@ class CommandEncoder { void insert_graph_dependencies(GraphNode node); void insert_graph_dependencies(std::vector nodes); + Device& device_; CudaStream stream_; cudaGraph_t graph_; Worker worker_; From afb9817599ac8b3c0399274ebd7bc6d30dbb30b8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 10 Jul 2025 23:24:21 +0900 Subject: [PATCH 141/156] [CUDA] Put version in ptx 
cache dir path (#2352) --- mlx/backend/cuda/jit_module.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 5bc56b25e..e6dbd35da 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/device.h" +#include "mlx/version.h" #include "cuda_jit_sources.h" @@ -53,10 +54,11 @@ const std::string& cuda_home() { const std::filesystem::path& ptx_cache_dir() { static std::filesystem::path cache = []() -> std::filesystem::path { std::filesystem::path cache; - if (auto c = std::getenv("MLX_PTX_CACHE"); c) { + if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) { cache = c; } else { - cache = std::filesystem::temp_directory_path() / "mlx" / "ptx"; + cache = + std::filesystem::temp_directory_path() / "mlx" / version() / "ptx"; } if (!std::filesystem::exists(cache)) { std::error_code error; From 0eb035b4b1922a8b3c5f76092a42b83447851a93 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 10 Jul 2025 11:14:42 -0700 Subject: [PATCH 142/156] Fix type promotion in Adam with bias correction (#2350) --- python/mlx/optimizers/optimizers.py | 6 ++++-- python/tests/test_optimizers.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index 09857dd0a..26b732ebd 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -526,8 +526,10 @@ class Adam(Optimizer): state["v"] = v if bias_correction: - numerator = lr / (1 - b1**step) * m - denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps + c1 = (lr / (1 - b1**step)).astype(gradient.dtype) + c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype) + numerator = c1 * m + denominator = mx.sqrt(v) * c2 + eps return parameter - numerator / denominator else: return parameter - lr * m / (mx.sqrt(v) + eps) diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index e07fc8456..8f9e33679 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase): ) ) + # Test for correct gradient type propagation + params = tree_map(lambda x: x.astype(mx.float16), params) + grads = tree_map(lambda x: x.astype(mx.float16), grads) + optim = opt.Adam(1e-2, bias_correction=True) + new_params = optim.apply_gradients(grads, params) + self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params)) + @unittest.skipIf(not has_torch, "requires Torch") def test_adamw_matches_pytorch(self): mx.random.seed(0) From b6eec20260ea4bfcb75a2d98d3c2129e92c817f9 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 10 Jul 2025 16:28:50 -0700 Subject: [PATCH 143/156] Fix edge check in qmm_n QuantizedLoader (#2355) --- mlx/backend/metal/kernels/quantized.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index fea6f1460..0a40cec00 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -643,14 +643,14 @@ struct QuantizedBlockLoader { return; } - if (reduction_dim == 1 && bi >= src_tile_dim.y) { + if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } - if (reduction_dim == 0 && bi >= src_tile_dim.x) { + if (reduction_dim == 0 && bi >= 
src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } From 8347575ba1bf1ace79b80f9d74fb648ec48ff963 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 11 Jul 2025 08:54:12 +0900 Subject: [PATCH 144/156] [CUDA] Implement Scan kernel (#2347) * Contiguous scan * Strided scan * Enable tests * Fix failing logaddexp test * Use cexpf in Metal --- mlx/backend/cuda/CMakeLists.txt | 6 + mlx/backend/cuda/device/binary_ops.cuh | 63 ++- mlx/backend/cuda/device/cexpf.cuh | 138 +++++++ mlx/backend/cuda/device/unary_ops.cuh | 32 +- mlx/backend/cuda/device/utils.cuh | 17 - mlx/backend/cuda/jit_module.cpp | 2 + mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/reduce/reduce_utils.cuh | 1 + mlx/backend/cuda/scan.cu | 467 +++++++++++++++++++++++ mlx/backend/metal/kernels/cexpf.h | 134 +++++++ mlx/backend/metal/kernels/unary_ops.h | 4 +- python/tests/cuda_skip.py | 5 - tests/ops_tests.cpp | 9 + 13 files changed, 815 insertions(+), 64 deletions(-) create mode 100644 mlx/backend/cuda/device/cexpf.cuh create mode 100644 mlx/backend/cuda/scan.cu create mode 100644 mlx/backend/metal/kernels/cexpf.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 8130d396f..87f4cb4ae 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -35,6 +35,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu @@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") +# Enable calling host constexpr functions from device. This is needed because +# the constexpr version of isnan is host only. +target_compile_options( + mlx PRIVATE "$<$:--expt-relaxed-constexpr>") + # CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive. # Explicitly pass this flag to suppress the warning, it is safe to set it to # true but the warning wouldn't be suppressed. diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index dc4f8e7bb..644786a92 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -1,10 +1,7 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/device/cucomplex_math.cuh" -#include "mlx/backend/cuda/device/fp16_math.cuh" -#include "mlx/backend/cuda/device/utils.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" -#include #include namespace mlx::core::cu { @@ -114,36 +111,38 @@ struct LessEqual { struct LogAddExp { template __device__ T operator()(T x, T y) { - if (isnan(x) || isnan(y)) { - return cuda::std::numeric_limits::quiet_NaN(); + if constexpr (cuda::std::is_same_v) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) || + isnan(cuCimagf(y))) { + return { + cuda::std::numeric_limits::quiet_NaN(), + cuda::std::numeric_limits::quiet_NaN()}; + } + auto max = cuCrealf(x) > cuCrealf(y) ? x : y; + auto min = cuCrealf(x) < cuCrealf(y) ? 
x : y; + auto min_real = cuCrealf(min); + auto max_real = cuCrealf(max); + if (!isfinite(min_real) && (min_real == max_real)) { + if (min_real < 0) { + return min; + } else { + return Log{}(Exp{}(min) + Exp{}(max)); + } + } else { + return Log1p{}(Exp{}(min - max)) + max; + } + } else { + if (isnan(x) || isnan(y)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + T maxval = max(x, y); + T minval = min(x, y); + return (minval == -cuda::std::numeric_limits::infinity() || + maxval == cuda::std::numeric_limits::infinity()) + ? maxval + : T(float(maxval) + log1p(expf(minval - maxval))); } - T maxval = max(x, y); - T minval = min(x, y); - return (minval == -cuda::std::numeric_limits::infinity() || - maxval == cuda::std::numeric_limits::infinity()) - ? maxval - : T(float(maxval) + log1p(expf(minval - maxval))); }; - - __device__ cuComplex operator()(cuComplex x, cuComplex y) { - if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) || - isnan(cuCimagf(y))) { - return { - cuda::std::numeric_limits::quiet_NaN(), - cuda::std::numeric_limits::quiet_NaN()}; - } - float inf = cuda::std::numeric_limits::infinity(); - auto maxval = x > y ? x : y; - auto minval = x < y ? x : y; - if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf) - return maxval; - float m = exp(cuCrealf(minval) - cuCrealf(maxval)); - cuComplex dexp{ - m * cos(cuCimagf(minval) - cuCimagf(maxval)), - m * sin(cuCimagf(minval) - cuCimagf(maxval)), - }; - return maxval + log1p(dexp); - } }; struct Maximum { diff --git a/mlx/backend/cuda/device/cexpf.cuh b/mlx/backend/cuda/device/cexpf.cuh new file mode 100644 index 000000000..61c94c00f --- /dev/null +++ b/mlx/backend/cuda/device/cexpf.cuh @@ -0,0 +1,138 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2008-2013 NVIDIA Corporation +// Copyright © 2013 Filipe RNC Maia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h + +// TODO: We should use thrust::exp but the thrust header in old CUDA versions +// can not be used in JIT. 
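// For z = x + iy, cexpf(z) = exp(x) * (cos(y) + i*sin(y)). The naive
//   make_cuFloatComplex(expf(x) * cosf(y), expf(x) * sinf(y))
// overflows inside expf(x) once x exceeds roughly 88.7f even when the final product is still
// finite, so frexp_expf / ldexp_cexpf below factor out powers of two and reapply them after
// the multiply.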
+ +#pragma once + +#include +#include + +namespace mlx::core::cu::detail { + +using ieee_float_shape_type = union { + float value; + uint32_t word; +}; + +inline __device__ void get_float_word(uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline __device__ void get_float_word(int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline __device__ void set_float_word(float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} + +inline __device__ float frexp_expf(float x, int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + + float exp_x; + uint32_t hx; + + exp_x = expf(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} + +inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + + x = cuCrealf(z); + y = cuCimagf(z); + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + + return cuComplex{ + cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2}; +} + +inline __device__ cuComplex cexpf(const cuComplex& z) { + float x, y, exp_x; + uint32_t hx, hy; + + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + + x = cuCrealf(z); + y = cuCimagf(z); + + get_float_word(hy, y); + hy &= 0x7fffffff; + + /* cexp(x + I 0) = exp(x) + I 0 */ + if (hy == 0) { + return cuComplex{expf(x), y}; + } + get_float_word(hx, x); + /* cexp(0 + I y) = cos(y) + I sin(y) */ + if ((hx & 0x7fffffff) == 0) { + return cuComplex{cosf(y), sinf(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ + return cuComplex{y - y, y - y}; + } else if (hx & 0x80000000) { + /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ + return cuComplex{0.0, 0.0}; + } else { + /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ + return cuComplex{x, y - y}; + } + } + + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + /* + * x is between 88.7 and 192, so we must scale to avoid + * overflow in expf(x). 
+ */ + return ldexp_cexpf(z, 0); + } else { + /* + * Cases covered here: + * - x < exp_ovfl and exp(x) won't overflow (common case) + * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 + * - x = +-Inf (generated by exp()) + * - x = NaN (spurious inexact exception from y) + */ + exp_x = expf(x); + return cuComplex{exp_x * cosf(y), exp_x * sinf(y)}; + } +} + +} // namespace mlx::core::cu::detail diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index 18d769c2a..8716d3a8c 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -2,6 +2,8 @@ #pragma once +#include "mlx/backend/cuda/device/cexpf.cuh" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/utils.cuh" @@ -150,8 +152,7 @@ struct Exp { template __device__ T operator()(T x) { if constexpr (cuda::std::is_same_v) { - auto m = exp(cuCrealf(x)); - return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))}; + return detail::cexpf(x); } else { return exp(x); } @@ -228,8 +229,25 @@ struct Log10 { struct Log1p { template - __device__ T operator()(T x) { - return log1p(x); + __device__ T operator()(T z) { + if constexpr (cuda::std::is_same_v) { + float x = cuCrealf(z); + float y = cuCimagf(z); + float zabs = cuCrealf(Abs{}(z)); + float theta = atan2f(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1pf(r), theta}; + } else { + float z0 = hypotf(x + 1, y); + return {logf(z0), theta}; + } + } else { + return log1p(z); + } } }; @@ -387,19 +405,19 @@ struct Tanh { } }; -__device__ cuComplex ArcCos::operator()(cuComplex x) { +inline __device__ cuComplex ArcCos::operator()(cuComplex x) { auto i = cuComplex{0.0, 1.0}; auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); return {cuCimagf(y), -cuCrealf(y)}; }; -__device__ cuComplex ArcSin::operator()(cuComplex x) { +inline __device__ cuComplex ArcSin::operator()(cuComplex x) { auto i = cuComplex{0.0f, 1.0f}; auto y = Log{}(i * x + Sqrt{}(1.0f - x * x)); return {cuCimagf(y), -cuCrealf(y)}; }; -__device__ cuComplex ArcTan::operator()(cuComplex x) { +inline __device__ cuComplex ArcTan::operator()(cuComplex x) { auto i = cuComplex{0.0f, 1.0f}; auto ix = i * x; return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix)); diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 83e149165..af022c141 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> { } }; -inline __device__ cuComplex log1p(cuComplex in) { - float x = cuCrealf(in); - float y = cuCimagf(in); - float zabs = sqrt(x * x + y * y); - float theta = atan2f(y, x + 1); - if (zabs < 0.5f) { - float r = x * (2 + x) + y * y; - if (r == 0) { // handle underflow - return {x, theta}; - } - return {0.5f * log1pf(r), theta}; - } else { - auto z0 = sqrt((x + 1) * (x + 1) + y * y); - return {log(z0), theta}; - } -} - } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index e6dbd35da..834e4a3d1 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "cast_op.cuh", + INCLUDE_PREFIX "cexpf.cuh", INCLUDE_PREFIX "config.h", INCLUDE_PREFIX 
"cucomplex_math.cuh", INCLUDE_PREFIX "fp16_math.cuh", @@ -177,6 +178,7 @@ constexpr const char* g_headers[] = { jit_source_atomic_ops, jit_source_binary_ops, jit_source_cast_op, + jit_source_cexpf, jit_source_config, jit_source_cucomplex_math, jit_source_fp16_math, diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index a8496b958..3a3f8ff54 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -82,7 +82,6 @@ NO_GPU(Load) NO_GPU_MULTI(LUF) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) -NO_GPU(Scan) NO_GPU(SegmentedMM) NO_GPU_MULTI(SVD) NO_GPU(Inverse) diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index ccd7ae48d..d993bacbb 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -4,6 +4,7 @@ #include +#include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device/utils.cuh" #include diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu new file mode 100644 index 000000000..7a26ee161 --- /dev/null +++ b/mlx/backend/cuda/scan.cu @@ -0,0 +1,467 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +struct ScanResult { + using type = T; +}; + +template <> +struct ScanResult { + using type = int32_t; +}; + +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return Limits::min(); + } +}; + +template +inline __device__ void +load_values(int index, const T* in, U (&values)[N_READS], int size, U init) { + int remaining = size - index * N_READS; + if constexpr (reverse) { + in += remaining - N_READS; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + values[N_READS - i - 1] = + (N_READS - i - 1 < remaining) ? cast_to(in[i]) : init; + } + } else { + for (int i = 0; i < N_READS; ++i) { + values[N_READS - i - 1] = cast_to(in[i]); + } + } + } else { + in += index * N_READS; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + values[i] = (i < remaining) ? 
cast_to(in[i]) : init; + } + } else { + for (int i = 0; i < N_READS; ++i) { + values[i] = cast_to(in[i]); + } + } + } +} + +template +inline __device__ void +store_values(int index, T* out, T (&values)[N_READS], int size) { + int start = index * N_READS + offset; + int remaining = size - start; + if constexpr (reverse) { + out += remaining - N_READS; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + if (N_READS - i - 1 < remaining) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + for (int i = 0; i < N_READS; ++i) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + out += start; + if (remaining < N_READS) { + for (int i = 0; i < N_READS; ++i) { + if (i < remaining) { + out[i] = values[i]; + } + } + } else { + for (int i = 0; i < N_READS; ++i) { + out[i] = values[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + in += grid.block_rank() * axis_size; + out += grid.block_rank() * axis_size; + + __shared__ U warp_sums[WARP_SIZE]; + + Op op; + U init = ReduceInit::value(); + U prefix = init; + + // Scan per block. + for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) { + int32_t index = r * block.size() + block.thread_rank(); + U values[N_READS]; + load_values(index, in, values, axis_size, init); + + // Compute an inclusive scan per thread. + for (int i = 1; i < N_READS; ++i) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums. + U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op); + if (warp.thread_rank() == 0) { + prev_thread_sum = init; + } + + // Write wrap's sum to shared memory. + if (warp.thread_rank() == WARP_SIZE - 1) { + warp_sums[warp.meta_group_rank()] = + op(prev_thread_sum, values[N_READS - 1]); + } + block.sync(); + + // Compute exclusive scan of warp sums. + if (warp.meta_group_rank() == 0) { + U prev_warp_sum = + cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op); + if (warp.thread_rank() == 0) { + prev_warp_sum = init; + } + warp_sums[warp.thread_rank()] = prev_warp_sum; + } + block.sync(); + + // Compute the output. + for (int i = 0; i < N_READS; ++i) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], warp_sums[warp.meta_group_rank()]); + values[i] = op(values[i], prev_thread_sum); + } + + // Write the values. + if (inclusive) { + store_values(index, out, values, axis_size); + } else { + store_values(index, out, values, axis_size); + if (reverse) { + if (block.thread_rank() == 0 && index == 0) { + out[axis_size - 1] = init; + } + } else { + if (block.thread_rank() == 0 && index == 0) { + out[0] = init; + } + } + } + block.sync(); + + // Share the prefix. 
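// At this point the last thread of the last warp holds the inclusive total of everything
// scanned so far in this row; broadcast it through shared memory so the next tile of
// block.size() * N_READS elements starts from that prefix.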
+ if ((warp.meta_group_rank() == warp.meta_group_size() - 1) && + (warp.thread_rank() == WARP_SIZE - 1)) { + warp_sums[0] = values[N_READS - 1]; + } + block.sync(); + prefix = warp_sums[0]; + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + int BM, + int BN, + bool inclusive, + bool reverse> +__global__ void strided_scan( + const T* in, + U* out, + int32_t axis_size, + int64_t stride, + int64_t stride_blocks) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U); + constexpr int n_warps = BN / N_READS; + constexpr int n_scans = BN / n_warps; + + __shared__ U read_buffer[BM * BN_pad]; + + Op op; + U init = ReduceInit::value(); + U values[n_scans]; + U prefix[n_scans]; + for (int i = 0; i < n_scans; ++i) { + prefix[i] = init; + } + + // Compute offsets. + int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride; + int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN; + uint read_offset_y = (block.thread_rank() * N_READS) / BN; + uint read_offset_x = (block.thread_rank() * N_READS) % BN; + uint scan_offset_y = warp.thread_rank(); + uint scan_offset_x = warp.meta_group_rank() * n_scans; + + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x; + U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread. + uint index_y = j + read_offset_y; + uint check_index_y = index_y; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read in SM. + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; ++i) { + read_into[i] = in[index_y * stride + i]; + } + } else { + for (int i = 0; i < N_READS; ++i) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = in[index_y * stride + i]; + } else { + read_into[i] = init; + } + } + } + block.sync(); + + // Read strided into registers. + for (int i = 0; i < n_scans; ++i) { + values[i] = read_from[i]; + } + + // Perform the scan. + for (int i = 0; i < n_scans; ++i) { + values[i] = cg::inclusive_scan(warp, values[i], op); + values[i] = op(values[i], prefix[i]); + prefix[i] = warp.shfl(values[i], WARP_SIZE - 1); + } + + // Write to SM. + for (int i = 0; i < n_scans; ++i) { + read_from[i] = values[i]; + } + block.sync(); + + // Write to device memory. 
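// Inclusive scans write the tile back unshifted; exclusive scans store the identity value in
// the row where the scan starts and place the scanned values one position further along the
// scan direction.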
+ if (!inclusive) { + if (check_index_y == 0) { + if ((read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; ++i) { + out[index_y * stride + i] = init; + } + } else { + for (int i = 0; i < N_READS; ++i) { + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = init; + } + } + } + } + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; ++i) { + out[index_y * stride + i] = read_into[i]; + } + } else { + for (int i = 0; i < N_READS; ++i) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; + } + } + } + } +} + +} // namespace cu + +template +void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) { + if (scan_op == Scan::ReduceType::Max) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Min) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Sum) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Prod) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::LogAddExp) { + f(type_identity{}); + } else { + throw std::invalid_argument("Unknown reduce type."); + } +} + +template +const char* op_to_string() { + if (cuda::std::is_same_v) { + return "Max"; + } else if (cuda::std::is_same_v) { + return "Min"; + } else if (cuda::std::is_same_v) { + return "Sum"; + } else if (cuda::std::is_same_v) { + return "Prod"; + } else if (cuda::std::is_same_v) { + return "LogAddExp"; + } else { + throw std::invalid_argument("Unknown op."); + } +} + +template +constexpr bool supports_scan_op() { + if constexpr (cuda::std::is_same_v) { + return is_inexact_v; + } else { + return true; + } +} + +void Scan::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Scan::eval_gpu"); + assert(inputs.size() == 1); + auto in = inputs[0]; + auto& s = stream(); + + if (in.flags().contiguous && in.strides()[axis_] != 0) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + allocator::malloc(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + array arr_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, arr_copy, CopyType::General, s); + in = std::move(arr_copy); + out.copy_shared_buffer(in); + } + + constexpr int N_READS = 4; + int32_t axis_size = in.shape(axis_); + bool contiguous = in.strides()[axis_] == 1; + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using T = cuda_type_t; + dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) { + using Op = MLX_GET_TYPE(scan_op_tag); + if constexpr (supports_scan_op) { + using U = typename cu::ScanResult::type; + dispatch_bool(inclusive_, [&](auto inclusive) { + dispatch_bool(reverse_, [&](auto reverse) { + if (contiguous) { + auto kernel = cu::contiguous_scan< + T, + U, + Op, + N_READS, + inclusive.value, + reverse.value>; + int block_dim = cuda::ceil_div(axis_size, N_READS); + block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE; + block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE); + encoder.add_kernel_node( + kernel, + in.data_size() / axis_size, + block_dim, + in.data(), + out.data(), + axis_size); + } else { + constexpr int BM = WARP_SIZE; + constexpr int BN = WARP_SIZE; + auto kernel = 
cu::strided_scan< + T, + U, + Op, + N_READS, + BM, + BN, + inclusive.value, + reverse.value>; + int64_t stride = in.strides()[axis_]; + int64_t stride_blocks = cuda::ceil_div(stride, BN); + dim3 num_blocks = get_2d_grid_dims( + in.shape(), in.strides(), axis_size * stride); + if (num_blocks.x * stride_blocks <= UINT32_MAX) { + num_blocks.x *= stride_blocks; + } else { + num_blocks.y *= stride_blocks; + } + int block_dim = (BN / N_READS) * WARP_SIZE; + encoder.add_kernel_node( + kernel, + num_blocks, + block_dim, + in.data(), + out.data(), + axis_size, + stride, + stride_blocks); + } + }); + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do scan op {} on inputs of {} with result of {}.", + op_to_string(), + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels/cexpf.h b/mlx/backend/metal/kernels/cexpf.h new file mode 100644 index 000000000..b45fe6a2f --- /dev/null +++ b/mlx/backend/metal/kernels/cexpf.h @@ -0,0 +1,134 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2008-2013 NVIDIA Corporation +// Copyright © 2013 Filipe RNC Maia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h + +// TODO: We should use thrust::exp but the thrust header in old CUDA versions +// can not be used in JIT. 
+ +#pragma once + +#include + +using ieee_float_shape_type = union { + float value; + uint32_t word; +}; + +inline void get_float_word(thread uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void get_float_word(thread int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void set_float_word(thread float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} + +inline float frexp_expf(float x, thread int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + + float exp_x; + uint32_t hx; + + exp_x = metal::exp(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} + +inline complex64_t ldexp_cexpf(complex64_t z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + + x = z.real; + y = z.imag; + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + + return complex64_t{ + metal::cos(y) * exp_x * scale1 * scale2, + metal::sin(y) * exp_x * scale1 * scale2}; +} + +inline complex64_t cexpf(const thread complex64_t& z) { + float x, y, exp_x; + uint32_t hx, hy; + + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + + x = z.real; + y = z.imag; + + get_float_word(hy, y); + hy &= 0x7fffffff; + + /* cexp(x + I 0) = exp(x) + I 0 */ + if (hy == 0) { + return complex64_t{metal::exp(x), y}; + } + get_float_word(hx, x); + /* cexp(0 + I y) = cos(y) + I sin(y) */ + if ((hx & 0x7fffffff) == 0) { + return complex64_t{metal::cos(y), metal::sin(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ + return complex64_t{y - y, y - y}; + } else if (hx & 0x80000000) { + /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ + return complex64_t{0.0, 0.0}; + } else { + /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ + return complex64_t{x, y - y}; + } + } + + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + /* + * x is between 88.7 and 192, so we must scale to avoid + * overflow in expf(x). 
+ */ + return ldexp_cexpf(z, 0); + } else { + /* + * Cases covered here: + * - x < exp_ovfl and exp(x) won't overflow (common case) + * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 + * - x = +-Inf (generated by exp()) + * - x = NaN (spurious inexact exception from y) + */ + exp_x = metal::exp(x); + return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; + } +} diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 09d9f6605..b34bc44ba 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/metal/kernels/cexpf.h" #include "mlx/backend/metal/kernels/erf.h" #include "mlx/backend/metal/kernels/expm1f.h" @@ -178,8 +179,7 @@ struct Exp { return metal::precise::exp(x); }; complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + return cexpf(x); } }; diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index afd48bd03..005c612ff 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -13,11 +13,6 @@ cuda_skip = { "TestBlas.test_gather_mm_sorted", # Segmented matmul NYI "TestBlas.test_segmented_mm", - # Scan NYI - "TestArray.test_api", - "TestAutograd.test_cumprod_grad", - "TestOps.test_scans", - "TestOps.test_logcumsumexp", # Hadamard NYI "TestOps.test_hadamard", "TestOps.test_hadamard_grad_vmap", diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 1a9781c7c..969bc2ba7 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") { x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); CHECK(allclose(exp(x), expected).item()); + + // Complex of -inf + constexpr float inf = std::numeric_limits::infinity(); + x = array(complex64_t{-inf, -inf}); + CHECK_EQ(exp(x).item(), complex64_t{0, 0}); } // Test expm1 @@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") { x = array(-inf); y = array(inf); CHECK_EQ(logaddexp(x, y).item(), inf); + + x = array(complex64_t{1, 1}); + y = array(complex64_t{-inf, -inf}); + CHECK_EQ(logaddexp(x, y).item(), complex64_t{1, 1}); } TEST_CASE("test broadcast") { From 42cc9cfbc7f468d2d052fc16c09b5da01198969f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 11 Jul 2025 10:59:35 -0700 Subject: [PATCH 145/156] fix copy dispatch (#2360) --- mlx/backend/metal/copy.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 8123b793e..915fc69fd 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -86,7 +86,7 @@ void copy_gpu_inplace( } } else { work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); - if (work_per_thread > 1) { + if (!large && work_per_thread > 1) { kernel_name += "n"; } } From 6325f60d52dd0ef68f9a811a69fe51b269c154c7 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 12 Jul 2025 10:45:37 +0900 Subject: [PATCH 146/156] [CUDA] Bundle CCCL for JIT compilation (#2357) * Ship CCCL for JIT compilation * Remove cexpf --- mlx/backend/common/utils.cpp | 13 +++ mlx/backend/common/utils.h | 4 + mlx/backend/cuda/CMakeLists.txt | 4 + mlx/backend/cuda/device/atomic_ops.cuh | 5 - mlx/backend/cuda/device/cexpf.cuh | 138 ------------------------- mlx/backend/cuda/device/unary_ops.cuh | 5 +- mlx/backend/cuda/jit_module.cpp | 26 ++++- 
mlx/backend/metal/device.cpp | 13 +-- mlx/backend/metal/device.h | 16 --- 9 files changed, 48 insertions(+), 176 deletions(-) delete mode 100644 mlx/backend/cuda/device/cexpf.cuh diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 9766e5e0c..942f9576e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -1,5 +1,7 @@ // Copyright © 2023-2024 Apple Inc. +#include + #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" @@ -11,6 +13,17 @@ std::string get_primitive_string(Primitive* primitive) { return op_t.str(); } +std::filesystem::path current_binary_dir() { + static std::filesystem::path binary_dir = []() { + Dl_info info; + if (!dladdr(reinterpret_cast(¤t_binary_dir), &info)) { + throw std::runtime_error("Unable to get current binary dir."); + } + return std::filesystem::path(info.dli_fname).parent_path(); + }(); + return binary_dir; +} + std::tuple> collapse_contiguous_dims( const Shape& shape, const std::vector& strides, diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 114878846..543868e36 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include @@ -11,6 +12,9 @@ namespace mlx::core { std::string get_primitive_string(Primitive* primitive); +// Return the directory that contains current shared library. +std::filesystem::path current_binary_dir(); + inline int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) { int64_t loc = 0; diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 87f4cb4ae..29f2eeab6 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -125,3 +125,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) # Suppress nvcc warnings on MLX headers. target_compile_options(mlx PRIVATE $<$:-Xcudafe --diag_suppress=997>) + +# Install CCCL headers for JIT. +install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh index b6915606e..e0d3c3eac 100644 --- a/mlx/backend/cuda/device/atomic_ops.cuh +++ b/mlx/backend/cuda/device/atomic_ops.cuh @@ -58,12 +58,7 @@ inline __device__ void atomic_add(cuComplex* out, cuComplex val) { inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { #if __CUDA_ARCH__ < 800 -#if CCCL_VERSION >= 2008000 atomic_add_general(out, val); -#else - bool cccl_version_too_old_for_bfloat16_atomic_add = false; - assert(cccl_version_too_old_for_bfloat16_atomic_add); -#endif #else atomicAdd(out, val); #endif diff --git a/mlx/backend/cuda/device/cexpf.cuh b/mlx/backend/cuda/device/cexpf.cuh deleted file mode 100644 index 61c94c00f..000000000 --- a/mlx/backend/cuda/device/cexpf.cuh +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright © 2025 Apple Inc. -// Copyright © 2008-2013 NVIDIA Corporation -// Copyright © 2013 Filipe RNC Maia -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// -// Forked from -// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h - -// TODO: We should use thrust::exp but the thrust header in old CUDA versions -// can not be used in JIT. - -#pragma once - -#include -#include - -namespace mlx::core::cu::detail { - -using ieee_float_shape_type = union { - float value; - uint32_t word; -}; - -inline __device__ void get_float_word(uint32_t& i, float d) { - ieee_float_shape_type gf_u; - gf_u.value = (d); - (i) = gf_u.word; -} - -inline __device__ void get_float_word(int32_t& i, float d) { - ieee_float_shape_type gf_u; - gf_u.value = (d); - (i) = gf_u.word; -} - -inline __device__ void set_float_word(float& d, uint32_t i) { - ieee_float_shape_type sf_u; - sf_u.word = (i); - (d) = sf_u.value; -} - -inline __device__ float frexp_expf(float x, int* expt) { - const uint32_t k = 235; - const float kln2 = 162.88958740F; - - float exp_x; - uint32_t hx; - - exp_x = expf(x - kln2); - get_float_word(hx, exp_x); - *expt = (hx >> 23) - (0x7f + 127) + k; - set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); - return exp_x; -} - -inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) { - float x, y, exp_x, scale1, scale2; - int ex_expt, half_expt; - - x = cuCrealf(z); - y = cuCimagf(z); - exp_x = frexp_expf(x, &ex_expt); - expt += ex_expt; - - half_expt = expt / 2; - set_float_word(scale1, (0x7f + half_expt) << 23); - half_expt = expt - half_expt; - set_float_word(scale2, (0x7f + half_expt) << 23); - - return cuComplex{ - cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2}; -} - -inline __device__ cuComplex cexpf(const cuComplex& z) { - float x, y, exp_x; - uint32_t hx, hy; - - const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; - - x = cuCrealf(z); - y = cuCimagf(z); - - get_float_word(hy, y); - hy &= 0x7fffffff; - - /* cexp(x + I 0) = exp(x) + I 0 */ - if (hy == 0) { - return cuComplex{expf(x), y}; - } - get_float_word(hx, x); - /* cexp(0 + I y) = cos(y) + I sin(y) */ - if ((hx & 0x7fffffff) == 0) { - return cuComplex{cosf(y), sinf(y)}; - } - if (hy >= 0x7f800000) { - if ((hx & 0x7fffffff) != 0x7f800000) { - /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ - return cuComplex{y - y, y - y}; - } else if (hx & 0x80000000) { - /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ - return cuComplex{0.0, 0.0}; - } else { - /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ - return cuComplex{x, y - y}; - } - } - - if (hx >= exp_ovfl && hx <= cexp_ovfl) { - /* - * x is between 88.7 and 192, so we must scale to avoid - * overflow in expf(x). 
- */ - return ldexp_cexpf(z, 0); - } else { - /* - * Cases covered here: - * - x < exp_ovfl and exp(x) won't overflow (common case) - * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 - * - x = +-Inf (generated by exp()) - * - x = NaN (spurious inexact exception from y) - */ - exp_x = expf(x); - return cuComplex{exp_x * cosf(y), exp_x * sinf(y)}; - } -} - -} // namespace mlx::core::cu::detail diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index 8716d3a8c..447569eeb 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -2,12 +2,12 @@ #pragma once -#include "mlx/backend/cuda/device/cexpf.cuh" #include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include +#include namespace mlx::core::cu { @@ -152,7 +152,8 @@ struct Exp { template __device__ T operator()(T x) { if constexpr (cuda::std::is_same_v) { - return detail::cexpf(x); + auto r = exp(cuda::std::complex{cuCrealf(x), cuCimagf(x)}); + return cuComplex{r.real(), r.imag()}; } else { return exp(x); } diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 834e4a3d1..4ce79999e 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -13,6 +13,7 @@ #include #include +#include namespace mlx::core::cu { @@ -50,6 +51,16 @@ const std::string& cuda_home() { return home; } +// Return the location of CCCL headers shipped with the distribution. +bool get_cccl_include(std::string* out) { + auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl"; + if (!std::filesystem::exists(cccl_headers)) { + return false; + } + *out = fmt::format("--include-path={}", cccl_headers.string()); + return true; +} + // Get the cache directory for storing compiled results. const std::filesystem::path& ptx_cache_dir() { static std::filesystem::path cache = []() -> std::filesystem::path { @@ -161,7 +172,6 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "cast_op.cuh", - INCLUDE_PREFIX "cexpf.cuh", INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "cucomplex_math.cuh", INCLUDE_PREFIX "fp16_math.cuh", @@ -178,7 +188,6 @@ constexpr const char* g_headers[] = { jit_source_atomic_ops, jit_source_binary_ops, jit_source_cast_op, - jit_source_cexpf, jit_source_config, jit_source_cucomplex_math, jit_source_fp16_math, @@ -217,16 +226,23 @@ JitModule::JitModule( } // Compile program. + std::vector args; bool use_sass = compiler_supports_device_sass(device); std::string compute = fmt::format( "--gpu-architecture={}_{}{}", use_sass ? 
"sm" : "compute", device.compute_capability_major(), device.compute_capability_minor()); - std::string include = fmt::format("--include-path={}/include", cuda_home()); - const char* args[] = {compute.c_str(), include.c_str()}; + args.push_back(compute.c_str()); + std::string cccl_include; + if (get_cccl_include(&cccl_include)) { + args.push_back(cccl_include.c_str()); + } + std::string cuda_include = + fmt::format("--include-path={}/include", cuda_home()); + args.push_back(cuda_include.c_str()); nvrtcResult compile_result = - nvrtcCompileProgram(prog, std::size(args), args); + nvrtcCompileProgram(prog, args.size(), args.data()); if (compile_result != NVRTC_SUCCESS) { size_t log_size; CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 88835eb75..e22d9da2d 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,20 +1,18 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION +#include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" -namespace fs = std::filesystem; - namespace mlx::core::metal { namespace { @@ -80,12 +78,7 @@ MTL::Library* try_load_bundle( std::pair load_colocated_library( MTL::Device* device, const std::string& relative_path) { - std::string binary_dir = get_binary_directory(); - if (binary_dir.size() == 0) { - return {nullptr, nullptr}; - } - - auto path = fs::path(binary_dir) / relative_path; + auto path = current_binary_dir() / relative_path; if (!path.has_extension()) { path.replace_extension(".metallib"); } @@ -197,7 +190,7 @@ MTL::Library* load_library( std::ostringstream msg; msg << "Failed to load the metallib " << lib_name << ".metallib. " - << "We attempted to load it from <" << get_binary_directory() << "/" + << "We attempted to load it from <" << current_binary_dir() << "/" << lib_name << ".metallib" << ">"; #ifdef SWIFTPM_BUNDLE msg << " and from the Swift PM bundle."; diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index f87a8c48b..52595e6e6 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -3,8 +3,6 @@ #pragma once #include -#include -#include #include #include #include @@ -15,22 +13,8 @@ #include "mlx/array.h" #include "mlx/device.h" -namespace fs = std::filesystem; - namespace mlx::core::metal { -// Note, this function must be left inline in a header so that it is not -// dynamically linked. 
-inline std::string get_binary_directory() { - Dl_info info; - std::string directory; - int success = dladdr((void*)get_binary_directory, &info); - if (success) { - directory = fs::path(info.dli_fname).remove_filename().c_str(); - } - return directory; -} - using MTLFCList = std::vector>; From 2d3c26c56557dce4da57ea7455952596640d74ff Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 13 Jul 2025 06:24:45 +0900 Subject: [PATCH 147/156] [CUDA] Do not put kernels in annoymous namespace (#2362) --- mlx/backend/cuda/event.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index afa032a83..f51d2f2e3 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -90,8 +90,6 @@ bool CudaEvent::completed() const { // SharedEvent implementations /////////////////////////////////////////////////////////////////////////////// -namespace { - __host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { uint64_t current; while ((current = ac->load()) < value) { @@ -112,8 +110,6 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { event_signal(ac, value); } -} // namespace - SharedEvent::SharedEvent() { // Allocate cuda::atomic on managed memory. Atomic* ac; From 5201df50309b9797c3b673c971ae89145e9e808a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 14 Jul 2025 13:11:16 -0700 Subject: [PATCH 148/156] Fix imag() vjp (#2367) --- mlx/primitives.cpp | 14 +++++++++----- python/tests/test_autograd.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index eb5d9d6b3..72affbd34 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2459,7 +2459,7 @@ std::vector Imag::vjp( assert(primals.size() == 1); assert(argnums.size() == 1); return {multiply( - array(complex64_t{0.0f, -1.0f}, primals[0].dtype()), + array(complex64_t{0.0f, 1.0f}, primals[0].dtype()), cotangents[0], stream())}; } @@ -2788,15 +2788,19 @@ std::vector Matmul::vjp( std::vector reorder(cotan.ndim()); std::iota(reorder.begin(), reorder.end(), 0); std::iter_swap(reorder.end() - 1, reorder.end() - 2); + auto& s = stream(); + + auto complex_transpose = [&](const array& x) { + return transpose(conjugate(x, s), reorder, s); + }; + for (auto arg : argnums) { if (arg == 0) { // M X N * (K X N).T -> M X K - vjps.push_back( - matmul(cotan, transpose(primals[1], reorder, stream()), stream())); + vjps.push_back(matmul(cotan, complex_transpose(primals[1]), s)); } else { // (M X K).T * M X N -> K X N - vjps.push_back( - matmul(transpose(primals[0], reorder, stream()), cotan, stream())); + vjps.push_back(matmul(complex_transpose(primals[0]), cotan, s)); } } return vjps; diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 7973d79be..5722071f6 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -606,7 +606,7 @@ class TestAutograd(mlx_tests.MLXTestCase): x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j]) dfdx = mx.grad(fun)(x) - self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x))) + self.assertTrue(mx.allclose(dfdx, 2j * mx.ones_like(x))) def test_flatten_unflatten_vjps(self): def fun(x): From d34f887abc7969fcf001f621504a100c7f98dbe0 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 15 Jul 2025 06:06:35 +0900 Subject: [PATCH 149/156] Add Primitive::name and remove Primitive::print (#2365) --- docs/src/dev/extensions.rst | 8 +- examples/extensions/axpby/axpby.h | 6 +- mlx/backend/common/utils.cpp | 7 - 
mlx/backend/common/utils.h | 2 - mlx/backend/cpu/compiled.cpp | 2 +- mlx/backend/cuda/binary.cu | 30 ++- mlx/backend/cuda/binary_two.cu | 6 +- mlx/backend/cuda/compiled.cpp | 4 +- mlx/backend/cuda/unary.cu | 23 +- mlx/backend/metal/binary.cpp | 28 +-- mlx/backend/metal/binary.h | 8 +- mlx/backend/metal/compiled.cpp | 4 +- mlx/backend/metal/jit_kernels.cpp | 18 +- mlx/backend/metal/kernels.h | 14 +- mlx/backend/metal/nojit_kernels.cpp | 8 +- mlx/backend/metal/ternary.cpp | 8 +- mlx/backend/metal/ternary.h | 4 +- mlx/backend/metal/unary.cpp | 16 +- mlx/backend/metal/unary.h | 4 +- mlx/backend/metal/utils.h | 2 +- mlx/compile.cpp | 15 +- mlx/distributed/primitives.h | 27 +-- mlx/export.cpp | 4 +- mlx/fast_primitives.h | 16 +- mlx/graph_utils.cpp | 4 +- mlx/linalg.cpp | 2 +- mlx/primitives.cpp | 37 +--- mlx/primitives.h | 332 ++++++++++++++-------------- mlx/transforms.cpp | 2 +- python/src/linalg.cpp | 2 +- python/src/metal.cpp | 2 +- python/src/ops.cpp | 2 +- 32 files changed, 307 insertions(+), 340 deletions(-) diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 03f1c2163..5a4de8123 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -138,13 +138,13 @@ more concrete: * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ - virtual std::pair, std::vector> vmap( + std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; - /** Print the primitive. */ - void print(std::ostream& os) override { - os << "Axpby"; + /** The name of primitive. */ + const char* name() const override { + return "Axpby"; } /** Equivalence check **/ diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h index 26f80961c..e6da491f8 100644 --- a/examples/extensions/axpby/axpby.h +++ b/examples/extensions/axpby/axpby.h @@ -74,9 +74,9 @@ class Axpby : public mx::Primitive { const std::vector& inputs, const std::vector& axes) override; - /** Print the primitive. */ - void print(std::ostream& os) override { - os << "Axpby"; + /** The name of primitive. */ + const char* name() const override { + return "Axpby"; } /** Equivalence check **/ diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 942f9576e..ae169e35e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -3,16 +3,9 @@ #include #include "mlx/backend/common/utils.h" -#include "mlx/primitives.h" namespace mlx::core { -std::string get_primitive_string(Primitive* primitive) { - std::ostringstream op_t; - primitive->print(op_t); - return op_t.str(); -} - std::filesystem::path current_binary_dir() { static std::filesystem::path binary_dir = []() { Dl_info info; diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 543868e36..0f9846086 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -10,8 +10,6 @@ namespace mlx::core { -std::string get_primitive_string(Primitive* primitive); - // Return the directory that contains current shared library. 
std::filesystem::path current_binary_dir(); diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index d0bfb4f45..d85114987 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -231,7 +231,7 @@ inline void build_kernel( os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" << namer.get_name(x.inputs()[0]) << ");" << std::endl; } else { - x.primitive().print(os); + os << x.primitive().name(); os << "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index fc5b8c496..c8586e638 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -177,7 +177,7 @@ template void binary_op_gpu_inplace( const std::vector& inputs, array& out, - std::string_view op, + const char* op, const Stream& s) { assert(inputs.size() > 1); const auto& a = inputs[0]; @@ -291,7 +291,7 @@ template void binary_op_gpu( const std::vector& inputs, array& out, - std::string_view op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -300,11 +300,11 @@ void binary_op_gpu( binary_op_gpu_inplace(inputs, out, op, s); } -#define BINARY_GPU(func) \ - void func::eval_gpu(const std::vector& inputs, array& out) { \ - nvtx3::scoped_range r(#func "::eval_gpu"); \ - auto& s = out.primitive().stream(); \ - binary_op_gpu(inputs, out, get_primitive_string(this), s); \ +#define BINARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, name(), s); \ } BINARY_GPU(Add) @@ -328,33 +328,31 @@ BINARY_GPU(Subtract) void Equal::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Equal::eval_gpu"); auto& s = out.primitive().stream(); - auto op = get_primitive_string(this); if (equal_nan_) { - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); } else { - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); } } void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); auto& s = out.primitive().stream(); - auto op = get_primitive_string(this); switch (op_) { case BitwiseBinary::And: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::Or: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, op, s); + binary_op_gpu(inputs, out, name(), s); break; } } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 4b6e24581..0918c579f 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -184,7 +184,7 @@ template void binary_two_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - std::string_view op, + const char* op, const Stream& s) { assert(inputs.size() > 1); const auto& a = inputs[0]; @@ -314,7 +314,7 @@ template void binary_two_op_gpu( const std::vector& inputs, std::vector& outputs, - std::string_view op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -329,7 +329,7 @@ void DivMod::eval_gpu( 
std::vector& outputs) { nvtx3::scoped_range r("DivMod::eval_gpu"); auto& s = outputs[0].primitive().stream(); - binary_two_op_gpu(inputs, outputs, get_primitive_string(this), s); + binary_two_op_gpu(inputs, outputs, name(), s); } } // namespace mlx::core diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 21257e5dd..2f3990b90 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -106,9 +106,7 @@ struct FusedKernelBuilder { value = fmt::format( "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); } else { - std::ostringstream ss; - x.primitive().print(ss); - value = ss.str(); + value = x.primitive().name(); value += "{}("; for (size_t i = 0; i < x.inputs().size() - 1; ++i) { value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 1fe1b557b..0d2754ef0 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -102,7 +102,7 @@ template void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { auto& in = inputs[0]; if (in.size() == 0) { @@ -178,17 +178,17 @@ template void unary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { set_unary_output_data(inputs[0], out); unary_op_gpu_inplace(inputs, out, op, s); } -#define UNARY_GPU(func) \ - void func::eval_gpu(const std::vector& inputs, array& out) { \ - nvtx3::scoped_range r(#func "::eval_gpu"); \ - auto& s = out.primitive().stream(); \ - unary_op_gpu(inputs, out, get_primitive_string(this), s); \ +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, name(), s); \ } UNARY_GPU(Abs) @@ -224,16 +224,15 @@ UNARY_GPU(Tanh) void Log::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Log::eval_gpu"); auto& s = out.primitive().stream(); - auto op = get_primitive_string(this); switch (base_) { case Base::e: - unary_op_gpu(inputs, out, op, s); + unary_op_gpu(inputs, out, name(), s); break; case Base::two: - unary_op_gpu(inputs, out, op, s); + unary_op_gpu(inputs, out, name(), s); break; case Base::ten: - unary_op_gpu(inputs, out, op, s); + unary_op_gpu(inputs, out, name(), s); break; } } @@ -244,7 +243,7 @@ void Round::eval_gpu(const std::vector& inputs, array& out) { const auto& in = inputs[0]; auto& s = out.primitive().stream(); if (issubdtype(in.dtype(), inexact)) { - unary_op_gpu(inputs, out, get_primitive_string(this), s); + unary_op_gpu(inputs, out, name(), s); } else { // No-op integer types out.copy_shared_buffer(in); diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 54aaf153c..8c0e8c333 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -7,20 +7,20 @@ #define BINARY_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ - binary_op_gpu(inputs, out, get_primitive_string(this)); \ + binary_op_gpu(inputs, out, name()); \ } #define BINARY_GPU_MULTI(func) \ void func::eval_gpu( \ const std::vector& inputs, std::vector& outputs) { \ - binary_op_gpu(inputs, outputs, get_primitive_string(this)); \ + binary_op_gpu(inputs, outputs, name()); \ } namespace mlx::core { std::string get_kernel_name( BinaryOpType bopt, - const std::string& op, + const char* op, const array& a, bool large, int ndim, @@ -65,7 +65,7 @@ 
std::string get_kernel_name( void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -165,7 +165,7 @@ void binary_op_gpu_inplace( void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s) { assert(inputs.size() == 2); auto& a = inputs[0]; @@ -179,7 +179,7 @@ void binary_op_gpu( void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op) { + const char* op) { auto& s = outputs[0].primitive().stream(); binary_op_gpu(inputs, outputs, op, s); } @@ -187,7 +187,7 @@ void binary_op_gpu( void binary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { std::vector outputs = {out}; binary_op_gpu_inplace(inputs, outputs, op, s); @@ -196,7 +196,7 @@ void binary_op_gpu_inplace( void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s) { assert(inputs.size() == 2); auto& a = inputs[0]; @@ -209,7 +209,7 @@ void binary_op_gpu( void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op) { + const char* op) { auto& s = out.primitive().stream(); binary_op_gpu(inputs, out, op, s); } @@ -237,19 +237,19 @@ BINARY_GPU(Subtract) void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { switch (op_) { case BitwiseBinary::And: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::Or: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, get_primitive_string(this)); + binary_op_gpu(inputs, out, name()); break; } } diff --git a/mlx/backend/metal/binary.h b/mlx/backend/metal/binary.h index 8552c1e07..0341a2f83 100644 --- a/mlx/backend/metal/binary.h +++ b/mlx/backend/metal/binary.h @@ -9,25 +9,25 @@ namespace mlx::core { void binary_op_gpu( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - const std::string& op, + const char* op, const Stream& s); void binary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string& op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 88edc6baa..eb51ab750 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -212,9 +212,7 @@ inline void build_kernel( get_type_string(x.dtype()), namer.get_name(x.inputs()[0])); } else { - std::ostringstream ss; - x.primitive().print(ss); - os += ss.str(); + os += x.primitive().name(); os += "()("; for (int i = 0; i < x.inputs().size() - 1; i++) { os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i])); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index fd0e0db09..6ae72e0aa 100644 --- 
a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -8,12 +8,6 @@ using namespace fmt::literals; namespace mlx::core { -std::string op_name(const array& arr) { - std::ostringstream op_t; - arr.primitive().print(op_t); - return op_t.str(); -} - MTL::ComputePipelineState* get_arange_kernel( metal::Device& d, const std::string& kernel_name, @@ -33,7 +27,7 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto in_t = get_type_string(in_type); @@ -58,10 +52,10 @@ MTL::ComputePipelineState* get_unary_kernel( } void append_binary_kernels( - const std::string lib_name, + const std::string& lib_name, Dtype in_type, Dtype out_type, - const std::string op, + const char* op, std::string& kernel_source) { const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, @@ -112,7 +106,7 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; @@ -129,7 +123,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); @@ -144,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, - const std::string op) { + const char* op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { auto t_str = get_type_string(type); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 794c67bdc..ca29ca52e 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -19,27 +19,27 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_binary_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_binary_two_kernel( metal::Device& d, const std::string& kernel_name, Dtype in_type, Dtype out_type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype type, - const std::string op); + const char* op); MTL::ComputePipelineState* get_copy_kernel( metal::Device& d, @@ -257,8 +257,10 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( // Create a GPU kernel template definition for JIT compilation template -std::string -get_template_definition(std::string name, std::string func, Args... args) { +std::string get_template_definition( + std::string_view name, + std::string_view func, + Args... 
args) { std::ostringstream s; s << func << "<"; bool first = true; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 32d3e75f7..a689a793e 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -18,7 +18,7 @@ MTL::ComputePipelineState* get_unary_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -27,7 +27,7 @@ MTL::ComputePipelineState* get_binary_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -36,7 +36,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( const std::string& kernel_name, Dtype, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } @@ -44,7 +44,7 @@ MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, Dtype, - const std::string) { + const char*) { return d.get_kernel(kernel_name); } diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 22f2a1985..b2b9e3337 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -11,7 +11,7 @@ namespace mlx::core { void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { assert(inputs.size() == 3); auto& a = inputs[0]; @@ -128,7 +128,7 @@ void ternary_op_gpu_inplace( void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; @@ -141,13 +141,13 @@ void ternary_op_gpu( void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op) { + const char* op) { auto& s = out.primitive().stream(); ternary_op_gpu(inputs, out, op, s); } void Select::eval_gpu(const std::vector& inputs, array& out) { - ternary_op_gpu(inputs, out, get_primitive_string(this)); + ternary_op_gpu(inputs, out, name()); } } // namespace mlx::core diff --git a/mlx/backend/metal/ternary.h b/mlx/backend/metal/ternary.h index 0834140b8..91c6fbbeb 100644 --- a/mlx/backend/metal/ternary.h +++ b/mlx/backend/metal/ternary.h @@ -9,13 +9,13 @@ namespace mlx::core { void ternary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 0b118b72f..48f85635b 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -8,7 +8,7 @@ #define UNARY_GPU(func) \ void func::eval_gpu(const std::vector& inputs, array& out) { \ - unary_op_gpu(inputs, out, get_primitive_string(this)); \ + unary_op_gpu(inputs, out, name()); \ } namespace mlx::core { @@ -16,7 +16,7 @@ namespace mlx::core { void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { auto& in = inputs[0]; bool contig = in.flags().contiguous; @@ -98,7 +98,7 @@ void unary_op_gpu_inplace( void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s) { set_unary_output_data(inputs[0], out); unary_op_gpu_inplace(inputs, out, op, s); @@ -107,7 +107,7 @@ void unary_op_gpu( void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op) { + const char* op) { 
auto& s = out.primitive().stream(); unary_op_gpu(inputs, out, op, s); } @@ -146,13 +146,13 @@ UNARY_GPU(Tanh) void Log::eval_gpu(const std::vector& inputs, array& out) { switch (base_) { case Base::e: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; case Base::two: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; case Base::ten: - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); break; } } @@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (issubdtype(in.dtype(), inexact)) { - unary_op_gpu(inputs, out, get_primitive_string(this)); + unary_op_gpu(inputs, out, name()); } else { // No-op integer types out.copy_shared_buffer(in); diff --git a/mlx/backend/metal/unary.h b/mlx/backend/metal/unary.h index 19057076b..1d6ecf027 100644 --- a/mlx/backend/metal/unary.h +++ b/mlx/backend/metal/unary.h @@ -9,13 +9,13 @@ namespace mlx::core { void unary_op_gpu( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); void unary_op_gpu_inplace( const std::vector& inputs, array& out, - const std::string op, + const char* op, const Stream& s); } // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index a491521a0..e7784e599 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -40,7 +40,7 @@ inline void debug_set_primitive_buffer_label( if (auto cbuf_label = command_buffer->label(); cbuf_label) { label << cbuf_label->utf8String(); } - primitive.print(label); + label << primitive.name(); command_buffer->setLabel(make_string(label)); #endif } diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 0cb3b5a85..91743ec04 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -107,7 +107,7 @@ Compiled::Compiled( // name and type of output os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); // computation performed - a.primitive().print(os); + os << a.primitive().name(); // name of inputs to the function for (auto& inp : a.inputs()) { os << namer.get_name(inp); @@ -170,11 +170,16 @@ bool Compiled::is_equivalent(const Primitive& other) const { }); } -void Compiled::print(std::ostream& os) { - os << "Compiled"; - for (auto& a : tape_) { - a.primitive().print(os); +const char* Compiled::name() const { + if (name_.empty()) { + std::ostringstream os; + os << "Compiled"; + for (auto& a : tape_) { + os << a.primitive().name(); + } + name_ = os.str(); } + return name_.c_str(); } std::vector Compiled::output_shapes(const std::vector& inputs) { diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h index 7320e6cb6..7ad00a0d6 100644 --- a/mlx/distributed/primitives.h +++ b/mlx/distributed/primitives.h @@ -45,27 +45,22 @@ class AllReduce : public DistPrimitive { const std::vector& argnums, const std::vector& outputs) override; - void print(std::ostream& os) override { + const char* name() const override { switch (reduce_type_) { case And: - os << "And"; + return "And AllReduce"; case Or: - os << "And"; - break; + return "Or AllReduce"; case Sum: - os << "Sum"; - break; + return "Sum AllReduce"; case Prod: - os << "Prod"; - break; + return "Prod AllReduce"; case Min: - os << "Min"; - break; + return "Min AllReduce"; case Max: - os << "Max"; - break; + return "Max AllReduce"; } - os << " AllReduce"; + return ""; } private: @@ -94,7 +89,7 @@ class AllGather : public 
DistPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(AllGather); + DEFINE_NAME(AllGather); }; class Send : public DistPrimitive { @@ -110,7 +105,7 @@ class Send : public DistPrimitive { const std::vector& inputs, const std::vector& axes) override; - DEFINE_PRINT(Send); + DEFINE_NAME(Send); private: int dst_; @@ -126,7 +121,7 @@ class Recv : public DistPrimitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(Recv); + DEFINE_NAME(Recv); private: int src_; diff --git a/mlx/export.cpp b/mlx/export.cpp index 552c35cfb..8eb385bb1 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -354,9 +354,7 @@ struct PrimitiveFactory { void save(Writer& os, const std::shared_ptr& p) { serialize(os, p->stream()); - std::ostringstream pout; - p->print(pout); - auto name = pout.str(); + std::string name = p->name(); name = name.substr(0, name.find(' ')); if (auto it = name_remap.find(name); it != name_remap.end()) { name = it->second; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 51050ea50..52135adad 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -58,7 +58,7 @@ class RMSNorm : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(RMSNorm) + DEFINE_NAME(RMSNorm) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() @@ -85,7 +85,7 @@ class RMSNormVJP : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(RMSNormVJP) + DEFINE_NAME(RMSNormVJP) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(nullptr, eps_); @@ -118,7 +118,7 @@ class LayerNorm : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(LayerNorm) + DEFINE_NAME(LayerNorm) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -144,7 +144,7 @@ class LayerNormVJP : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(LayerNormVJP) + DEFINE_NAME(LayerNormVJP) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(nullptr, eps_); @@ -186,7 +186,7 @@ class RoPE : public Custom { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(RoPE) + DEFINE_NAME(RoPE) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -233,7 +233,7 @@ class ScaledDotProductAttention : public Custom { void eval_gpu(const std::vector& inputs, array& out); bool is_equivalent(const Primitive& other) const override; - DEFINE_PRINT(ScaledDotProductAttention); + DEFINE_NAME(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return std::make_tuple(nullptr, scale_, do_causal_); @@ -263,7 +263,7 @@ class AffineQuantize : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(AffineQuantize); + DEFINE_NAME(AffineQuantize); bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; @@ -311,7 +311,7 @@ class CustomKernel : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(CustomKernel); + DEFINE_NAME(CustomKernel); private: std::string source_; diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp index 29373f266..854881bc9 
100644 --- a/mlx/graph_utils.cpp +++ b/mlx/graph_utils.cpp @@ -93,7 +93,7 @@ void print_graph( os << "\n"; for (auto& arr : tape) { - arr.primitive().print(os); + os << arr.primitive().name(); os << " "; print_arrs(arr.inputs()); os << " -> "; @@ -143,7 +143,7 @@ void export_to_dot( os << "{ "; os << x.primitive_id(); os << " [label =\""; - x.primitive().print(os); + os << x.primitive().name(); os << "\", shape=rectangle]"; os << "; }" << std::endl; // Arrows to primitive's inputs diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index ff3208e1e..e8a9e430e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -500,7 +500,7 @@ array cross( void validate_eig( const array& a, const StreamOrDevice& stream, - const std::string fname) { + const std::string& fname) { check_cpu_stream(stream, fname); check_float_or_complex(a.dtype(), fname); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 72affbd34..cf0e6ef0d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -181,7 +181,7 @@ std::vector Primitive::jvp( const std::vector&) { std::ostringstream msg; msg << "[Primitive::jvp] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -193,7 +193,7 @@ std::vector Primitive::vjp( const std::vector&) { std::ostringstream msg; msg << "[Primitive::vjp] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -203,7 +203,7 @@ std::pair, std::vector> Primitive::vmap( const std::vector&) { std::ostringstream msg; msg << "[Primitive::vmap] Not implemented for "; - print(msg); + msg << name(); msg << "."; throw std::invalid_argument(msg.str()); } @@ -211,7 +211,7 @@ std::pair, std::vector> Primitive::vmap( std::vector Primitive::output_shapes(const std::vector&) { std::ostringstream msg; msg << "[Primitive::output_shapes] "; - this->print(msg); + msg << name(); msg << " cannot infer output shapes."; throw std::invalid_argument(msg.str()); } @@ -743,26 +743,6 @@ bool BitwiseBinary::is_equivalent(const Primitive& other) const { return op_ == a_other.op_; } -void BitwiseBinary::print(std::ostream& os) { - switch (op_) { - case BitwiseBinary::And: - os << "BitwiseAnd"; - break; - case BitwiseBinary::Or: - os << "BitwiseOr"; - break; - case BitwiseBinary::Xor: - os << "BitwiseXor"; - break; - case BitwiseBinary::LeftShift: - os << "LeftShift"; - break; - case BitwiseBinary::RightShift: - os << "RightShift"; - break; - } -} - std::pair, std::vector> BitwiseBinary::vmap( const std::vector& inputs, const std::vector& axes) { @@ -5375,8 +5355,13 @@ std::pair, std::vector> View::vmap( return {{view(inputs[0], dtype_, stream())}, axes}; } -void View::print(std::ostream& os) { - os << "View " << dtype_; +const char* View::name() const { + if (name_.empty()) { + std::ostringstream os; + os << "View " << dtype_; + name_ = os.str(); + } + return name_.c_str(); } bool View::is_equivalent(const Primitive& other) const { diff --git a/mlx/primitives.h b/mlx/primitives.h index 3d3202aaa..d482a1bf9 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -26,9 +26,9 @@ const std::vector& argnums, \ const std::vector& outputs) override; -#define DEFINE_PRINT(PRIMITIVE) \ - void print(std::ostream& os) override { \ - os << #PRIMITIVE; \ +#define DEFINE_NAME(PRIMITIVE) \ + const char* name() const override { \ + return #PRIMITIVE; \ } #define DEFINE_DEFAULT_IS_EQUIVALENT() \ @@ -100,8 +100,8 @@ class Primitive { const std::vector& inputs, const std::vector& axes); - /** Print the primitive. 
*/ - virtual void print(std::ostream& os) = 0; + /** Get the name of primitive. */ + virtual const char* name() const = 0; /** Equivalence check defaults to false unless overridden by the primitive */ virtual bool is_equivalent(const Primitive& other) const { @@ -160,7 +160,7 @@ class Abs : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Abs) + DEFINE_NAME(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -174,7 +174,7 @@ class Add : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Add) + DEFINE_NAME(Add) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -189,7 +189,7 @@ class AddMM : public UnaryPrimitive { DEFINE_GRADS() DEFINE_VMAP() - DEFINE_PRINT(AddMM) + DEFINE_NAME(AddMM) bool is_equivalent(const Primitive& other) const override; std::pair state() const { @@ -209,7 +209,7 @@ class Arange : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(Arange) + DEFINE_NAME(Arange) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::tuple state() const { @@ -231,7 +231,7 @@ class ArcCos : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcCos) + DEFINE_NAME(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -245,7 +245,7 @@ class ArcCosh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcCosh) + DEFINE_NAME(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -259,7 +259,7 @@ class ArcSin : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcSin) + DEFINE_NAME(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -273,7 +273,7 @@ class ArcSinh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcSinh) + DEFINE_NAME(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -287,7 +287,7 @@ class ArcTan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTan) + DEFINE_NAME(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -301,7 +301,7 @@ class ArcTan2 : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTan2) + DEFINE_NAME(ArcTan2) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -315,7 +315,7 @@ class ArcTanh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArcTanh) + DEFINE_NAME(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -330,7 +330,7 @@ class ArgPartition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgPartition) + DEFINE_NAME(ArgPartition) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; std::pair state() const { @@ -357,7 +357,7 @@ class ArgReduce : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgReduce) + DEFINE_NAME(ArgReduce) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::pair state() const { @@ -379,7 +379,7 @@ class ArgSort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ArgSort) + DEFINE_NAME(ArgSort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; int state() const { @@ -400,7 +400,7 @@ class AsType : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(AsType) + DEFINE_NAME(AsType) 
DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; Dtype state() const { @@ -423,7 +423,7 @@ class AsStrided : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_GRADS() - DEFINE_PRINT(AsStrided) + DEFINE_NAME(AsStrided) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(shape_, strides_, offset_); @@ -449,8 +449,24 @@ class BitwiseBinary : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() + + const char* name() const override { + switch (op_) { + case BitwiseBinary::And: + return "BitwiseAnd"; + case BitwiseBinary::Or: + return "BitwiseOr"; + case BitwiseBinary::Xor: + return "BitwiseXor"; + case BitwiseBinary::LeftShift: + return "LeftShift"; + case BitwiseBinary::RightShift: + return "RightShift"; + } + return ""; + } + bool is_equivalent(const Primitive& other) const override; - void print(std::ostream& os) override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { return op_; @@ -468,7 +484,7 @@ class BitwiseInvert : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(BitwiseInvert) + DEFINE_NAME(BitwiseInvert) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -487,7 +503,7 @@ class BlockMaskedMM : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(BlockMaskedMM) + DEFINE_NAME(BlockMaskedMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return block_size_; @@ -516,7 +532,7 @@ class GatherMM : public UnaryPrimitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(GatherMM) + DEFINE_NAME(GatherMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_pair(left_sorted_, right_sorted_); @@ -534,7 +550,7 @@ class SegmentedMM : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(SegmentedMM) + DEFINE_NAME(SegmentedMM) }; class BroadcastAxes : public UnaryPrimitive { @@ -547,7 +563,7 @@ class BroadcastAxes : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(BroadcastAxes) + DEFINE_NAME(BroadcastAxes) bool is_equivalent(const Primitive& other) const override; static Shape output_shape( const std::vector& inputs, @@ -572,7 +588,7 @@ class Broadcast : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Broadcast) + DEFINE_NAME(Broadcast) static Shape output_shape(const std::vector& inputs); std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -595,7 +611,7 @@ class Ceil : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Ceil) + DEFINE_NAME(Ceil) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -625,8 +641,8 @@ class Compiled : public Primitive { DEFINE_VMAP() DEFINE_GRADS() + const char* name() const override; std::vector output_shapes(const std::vector& inputs) override; - void print(std::ostream& os) override; bool is_equivalent(const Primitive& other) const override; std::string lib_name() const { @@ -640,6 +656,7 @@ class Compiled : public Primitive { const std::unordered_set constant_ids_; const std::function is_constant_; + mutable std::string name_; std::string kernel_lib_; }; @@ -653,7 +670,7 @@ class Concatenate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() 
- DEFINE_PRINT(Concatenate) + DEFINE_NAME(Concatenate) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -672,7 +689,7 @@ class Conjugate : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(Conjugate) + DEFINE_NAME(Conjugate) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -687,7 +704,7 @@ class Contiguous : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Contiguous) + DEFINE_NAME(Contiguous) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -726,7 +743,7 @@ class Convolution : public UnaryPrimitive { const std::vector& outputs) override; DEFINE_VMAP() - DEFINE_PRINT(Convolution) + DEFINE_NAME(Convolution) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( @@ -758,7 +775,7 @@ class Copy : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Copy) + DEFINE_NAME(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() @@ -775,7 +792,7 @@ class Cos : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Cos) + DEFINE_NAME(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -789,7 +806,7 @@ class Cosh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Cosh) + DEFINE_NAME(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -823,7 +840,7 @@ class CustomTransforms : public Primitive { DEFINE_GRADS(); DEFINE_VMAP(); - DEFINE_PRINT(CustomTransforms); + DEFINE_NAME(CustomTransforms); private: void eval(const std::vector& inputs, std::vector& outputs); @@ -861,7 +878,7 @@ class Depends : public Primitive { const std::vector& argnums, const std::vector& outputs) override; - DEFINE_PRINT(Depends); + DEFINE_NAME(Depends); private: void eval(const std::vector& inputs, std::vector& outputs); @@ -876,7 +893,7 @@ class Divide : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Divide) + DEFINE_NAME(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -892,7 +909,7 @@ class DivMod : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DivMod) + DEFINE_NAME(DivMod) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override { return std::vector{inputs[0].shape(), inputs[0].shape()}; @@ -908,7 +925,7 @@ class Select : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Select) + DEFINE_NAME(Select) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -922,7 +939,7 @@ class Remainder : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Remainder) + DEFINE_NAME(Remainder) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -940,11 +957,11 @@ class Equal : public UnaryPrimitive { DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() - void print(std::ostream& os) override { + const char* name() const override { if (equal_nan_) { - os << "NaNEqual"; + return "NaNEqual"; } else { - os << "Equal"; + return "Equal"; } } auto state() const { @@ -964,7 +981,7 @@ class Erf : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Erf) + DEFINE_NAME(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -978,7 +995,7 @@ class ErfInv : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ErfInv) + DEFINE_NAME(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() 
DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -992,7 +1009,7 @@ class Exp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Exp) + DEFINE_NAME(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1006,7 +1023,7 @@ class Expm1 : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Expm1) + DEFINE_NAME(Expm1) DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1020,7 +1037,7 @@ class ExpandDims : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(ExpandDims) + DEFINE_NAME(ExpandDims) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -1049,7 +1066,7 @@ class FFT : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(FFT) + DEFINE_NAME(FFT) bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1072,7 +1089,7 @@ class Flatten : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Flatten) + DEFINE_NAME(Flatten) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -1096,7 +1113,7 @@ class Floor : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Floor) + DEFINE_NAME(Floor) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1110,7 +1127,7 @@ class Full : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Full) + DEFINE_NAME(Full) DEFINE_DEFAULT_IS_EQUIVALENT() }; @@ -1126,7 +1143,7 @@ class Gather : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Gather) + DEFINE_NAME(Gather) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::pair, std::vector> state() const { @@ -1148,7 +1165,7 @@ class GatherAxis : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GatherAxis) + DEFINE_NAME(GatherAxis) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -1168,7 +1185,7 @@ class Greater : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Greater) + DEFINE_NAME(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1182,7 +1199,7 @@ class GreaterEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GreaterEqual) + DEFINE_NAME(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1197,7 +1214,7 @@ class Hadamard : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Hadamard) + DEFINE_NAME(Hadamard) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -1218,7 +1235,7 @@ class Imag : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Imag) + DEFINE_NAME(Imag) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1232,7 +1249,7 @@ class Less : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Less) + DEFINE_NAME(Less) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1246,7 +1263,7 @@ class LessEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LessEqual) + DEFINE_NAME(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1266,7 +1283,7 @@ class Load : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; - DEFINE_PRINT(Load) + DEFINE_NAME(Load) private: std::shared_ptr reader_; @@ 
-1293,18 +1310,16 @@ class Log : public UnaryPrimitive { return base_; }; - void print(std::ostream& os) override { + const char* name() const override { switch (base_) { case e: - os << "Log"; - break; + return "Log"; case two: - os << "Log2"; - break; + return "Log2"; case ten: - os << "Log10"; - break; + return "Log10"; } + return ""; } private: @@ -1320,7 +1335,7 @@ class Log1p : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Log1p) + DEFINE_NAME(Log1p) DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1333,7 +1348,7 @@ class LogicalNot : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalNot) + DEFINE_NAME(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1347,7 +1362,7 @@ class LogicalAnd : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalAnd) + DEFINE_NAME(LogicalAnd) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1361,7 +1376,7 @@ class LogicalOr : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogicalOr) + DEFINE_NAME(LogicalOr) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1375,7 +1390,7 @@ class LogAddExp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogAddExp) + DEFINE_NAME(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1389,7 +1404,7 @@ class LogSumExp : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(LogSumExp) + DEFINE_NAME(LogSumExp) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override; }; @@ -1403,7 +1418,7 @@ class Matmul : public UnaryPrimitive { DEFINE_GRADS() DEFINE_VMAP() - DEFINE_PRINT(Matmul) + DEFINE_NAME(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() std::vector output_shapes(const std::vector& inputs) override; }; @@ -1417,7 +1432,7 @@ class Maximum : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Maximum) + DEFINE_NAME(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1431,7 +1446,7 @@ class Minimum : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Minimum) + DEFINE_NAME(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1445,7 +1460,7 @@ class Multiply : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Multiply) + DEFINE_NAME(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1459,7 +1474,7 @@ class Negative : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Negative) + DEFINE_NAME(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1473,7 +1488,7 @@ class NotEqual : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(NotEqual) + DEFINE_NAME(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1494,7 +1509,7 @@ class NumberOfElements : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(NumberOfElements) + DEFINE_NAME(NumberOfElements) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override { return {{}}; @@ -1528,7 +1543,7 @@ class Pad : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Pad) + DEFINE_NAME(Pad) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(axes_, low_pad_size_, high_pad_size_); @@ -1550,7 +1565,7 @@ class Partition : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - 
DEFINE_PRINT(Partition) + DEFINE_NAME(Partition) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -1571,7 +1586,7 @@ class Power : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Power) + DEFINE_NAME(Power) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1593,7 +1608,7 @@ class QuantizedMatmul : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(QuantizedMatmul) + DEFINE_NAME(QuantizedMatmul) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -1627,7 +1642,7 @@ class GatherQMM : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(GatherQMM) + DEFINE_NAME(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( @@ -1651,7 +1666,7 @@ class RandomBits : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(RandomBits) + DEFINE_NAME(RandomBits) bool is_equivalent(const Primitive& other) const override; std::pair, int> state() const { return {shape_, width_}; @@ -1671,7 +1686,7 @@ class Real : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Real) + DEFINE_NAME(Real) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1686,7 +1701,7 @@ class Reshape : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Reshape) + DEFINE_NAME(Reshape) bool is_equivalent(const Primitive& other) const override; std::vector state() const { return shape_; @@ -1721,28 +1736,24 @@ class Reduce : public UnaryPrimitive { std::vector output_shapes(const std::vector& inputs) override; - void print(std::ostream& os) override { + const char* name() const override { switch (reduce_type_) { case And: - os << "And"; - break; + return "And"; case Or: - os << "Or"; - break; + return "Or"; case Sum: - os << "Sum"; - break; + return "Sum"; case Prod: - os << "Prod"; - break; + return "Prod"; case Min: - os << "Min"; - break; + return "Min"; case Max: - os << "Max"; - break; + return "Max"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; std::pair> state() const { return {reduce_type_, axes_}; @@ -1762,7 +1773,7 @@ class Round : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Round) + DEFINE_NAME(Round) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1789,26 +1800,22 @@ class Scan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS(); - void print(std::ostream& os) override { - os << "Cum"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << "Sum"; - break; + return "CumSum"; case Prod: - os << "Prod"; - break; + return "CumProd"; case Min: - os << "Min"; - break; + return "CumMin"; case Max: - os << "Max"; - break; + return "CumMax"; case LogAddExp: - os << "Logaddexp"; - break; + return "CumLogAddExp"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_); @@ -1837,25 +1844,22 @@ class Scatter : public UnaryPrimitive { DEFINE_VMAP(); DEFINE_GRADS(); - void print(std::ostream& os) override { - os << "Scatter"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << " Sum"; - break; + return "ScatterSum"; case Prod: - os << " Prod"; - break; + return "ScatterProd"; case Min: - os << " Min"; - break; + return 
"ScatterMin"; case Max: - os << " Max"; - break; + return "ScatterMax"; case None: - break; + return "Scatter"; } + return ""; } + bool is_equivalent(const Primitive& other) const override; std::pair> state() const { return {reduce_type_, axes_}; @@ -1879,15 +1883,14 @@ class ScatterAxis : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - void print(std::ostream& os) override { - os << "ScatterAxis"; + const char* name() const override { switch (reduce_type_) { case Sum: - os << " Sum"; - break; + return "ScatterAxisSum"; case None: - break; + return "ScatterAxis"; } + return ""; } bool is_equivalent(const Primitive& other) const override; @@ -1910,7 +1913,7 @@ class Sigmoid : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sigmoid) + DEFINE_NAME(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1924,7 +1927,7 @@ class Sign : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sign) + DEFINE_NAME(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1938,7 +1941,7 @@ class Sin : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sin) + DEFINE_NAME(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1952,7 +1955,7 @@ class Sinh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sinh) + DEFINE_NAME(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -1974,7 +1977,7 @@ class Slice : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Slice) + DEFINE_NAME(Slice) bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(start_indices_, end_indices_, strides_); @@ -2003,7 +2006,7 @@ class SliceUpdate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(SliceUpdate) + DEFINE_NAME(SliceUpdate) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -2028,7 +2031,7 @@ class DynamicSlice : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DynamicSlice) + DEFINE_NAME(DynamicSlice) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { @@ -2050,7 +2053,7 @@ class DynamicSliceUpdate : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(DynamicSliceUpdate) + DEFINE_NAME(DynamicSliceUpdate) bool is_equivalent(const Primitive& other) const override; DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { @@ -2071,7 +2074,7 @@ class Softmax : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Softmax) + DEFINE_NAME(Softmax) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; @@ -2093,7 +2096,7 @@ class Sort : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Sort) + DEFINE_NAME(Sort) DEFINE_INPUT_OUTPUT_SHAPE() bool is_equivalent(const Primitive& other) const override; auto state() const { @@ -2116,7 +2119,7 @@ class Split : public Primitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Split) + DEFINE_NAME(Split) bool is_equivalent(const Primitive& other) const override; std::pair, int> state() const { return {indices_, axis_}; @@ -2138,7 +2141,7 @@ class Square : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Square) + DEFINE_NAME(Square) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2159,11 +2162,11 @@ class Sqrt : public UnaryPrimitive { return recip_; } - void print(std::ostream& os) override { 
+ const char* name() const override { if (recip_) { - os << "Rsqrt"; + return "Rsqrt"; } else { - os << "Sqrt"; + return "Sqrt"; } } @@ -2179,7 +2182,7 @@ class StopGradient : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - DEFINE_PRINT(StopGradient) + DEFINE_NAME(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() @@ -2196,7 +2199,7 @@ class Subtract : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Subtract) + DEFINE_NAME(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2211,7 +2214,7 @@ class Squeeze : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Squeeze) + DEFINE_NAME(Squeeze) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -2235,7 +2238,7 @@ class Tan : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Tan) + DEFINE_NAME(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2249,7 +2252,7 @@ class Tanh : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Tanh) + DEFINE_NAME(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() DEFINE_INPUT_OUTPUT_SHAPE() }; @@ -2264,7 +2267,7 @@ class Unflatten : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Unflatten) + DEFINE_NAME(Unflatten) std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; @@ -2288,7 +2291,7 @@ class View : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() - void print(std::ostream& os) override; + const char* name() const override; bool is_equivalent(const Primitive& other) const override; auto state() const { return dtype_; @@ -2296,6 +2299,7 @@ class View : public UnaryPrimitive { private: Dtype dtype_; + mutable std::string name_; }; class Transpose : public UnaryPrimitive { @@ -2308,7 +2312,7 @@ class Transpose : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() - DEFINE_PRINT(Transpose) + DEFINE_NAME(Transpose) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; std::vector state() const { @@ -2331,7 +2335,7 @@ class QRF : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(QRF) + DEFINE_NAME(QRF) }; /* SVD primitive. 
*/ @@ -2346,7 +2350,7 @@ class SVD : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(SVD) + DEFINE_NAME(SVD) auto state() const { return compute_uv_; } @@ -2365,7 +2369,7 @@ class Inverse : public UnaryPrimitive { void eval_gpu(const std::vector& inputs, array& output) override; DEFINE_VMAP() - DEFINE_PRINT(Inverse) + DEFINE_NAME(Inverse) auto state() const { return std::make_pair(tri_, upper_); } @@ -2387,7 +2391,7 @@ class Cholesky : public UnaryPrimitive { } DEFINE_VMAP() - DEFINE_PRINT(Cholesky) + DEFINE_NAME(Cholesky) private: bool upper_; @@ -2403,7 +2407,7 @@ class Eig : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(Eig) + DEFINE_NAME(Eig) std::vector output_shapes(const std::vector& inputs) override; @@ -2428,7 +2432,7 @@ class Eigh : public Primitive { override; DEFINE_VMAP() - DEFINE_PRINT(Eigh) + DEFINE_NAME(Eigh) std::vector output_shapes(const std::vector& inputs) override; @@ -2451,7 +2455,7 @@ class LUF : public Primitive { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_PRINT(LUF) + DEFINE_NAME(LUF) }; } // namespace mlx::core diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 2d9942eda..d9e227ea3 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -33,7 +33,7 @@ class Synchronizer : public Primitive { void eval_cpu(const std::vector&, std::vector&) override {} void eval_gpu(const std::vector&, std::vector&) override {} - DEFINE_PRINT(Synchronize); + DEFINE_NAME(Synchronize); }; // Initialize the static tracing members from transforms_impl.h diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index cc8e79db6..634abaef4 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -514,7 +514,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "eigh", - [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) { + [](const mx::array& a, const std::string& UPLO, mx::StreamOrDevice s) { auto result = mx::linalg::eigh(a, UPLO, s); return nb::make_tuple(result.first, result.second); }, diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 54642409c..3b2f4a53a 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -14,7 +14,7 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -bool DEPRECATE(const std::string& old_fn, const std::string new_fn) { +bool DEPRECATE(const char* old_fn, const char* new_fn) { std::cerr << old_fn << " is deprecated and will be removed in a future " << "version. Use " << new_fn << " instead." 
<< std::endl; return true; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index d047f64cb..9703bbd2d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3076,7 +3076,7 @@ void init_ops(nb::module_& m) { std::tuple, std::pair, std::vector>>& pad_width, - const std::string mode, + const std::string& mode, const ScalarOrArray& constant_value, mx::StreamOrDevice s) { if (auto pv = std::get_if(&pad_width); pv) { From e569803d7cb0e0a228bcf4bb02c72bb417955209 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Jul 2025 15:13:56 -0700 Subject: [PATCH 150/156] update linux build (#2370) --- .circleci/config.yml | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index be5f7aac5..01d432bfe 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -73,9 +73,9 @@ jobs: git push -f origin gh-pages linux_build_and_test: - docker: - - image: cimg/python:3.9 - + machine: + image: ubuntu-2204:current + resource_class: large steps: - checkout - run: @@ -87,19 +87,17 @@ jobs: - run: name: Install dependencies command: | - pip install --upgrade cmake - pip install nanobind==2.4.0 - pip install numpy + export DEBIAN_FRONTEND=noninteractive + export NEEDRESTART_MODE=a sudo apt-get update - sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get upgrade -y + pip install --upgrade cmake + sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev - run: name: Install Python package command: | - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ - python3 setup.py build_ext --inplace - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ - python3 setup.py develop + pip install -e ".[dev]" - run: name: Generate package stubs command: | @@ -109,13 +107,13 @@ jobs: - run: name: Run Python tests command: | - python3 -m unittest discover python/tests -v + python -m unittest discover python/tests -v mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py - run: name: Build CPP only command: | - mkdir -p build && cd build + mkdir -p build && cd build cmake .. 
-DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG make -j `nproc` - run: From e7d2ebadd29e7a76379dba048a419f1e17e82b32 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Jul 2025 15:45:44 -0700 Subject: [PATCH 151/156] [CUDA] Affine quantize (#2354) * affine quantize and dequantize kernels * format * fix * format --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/quantized.cu | 383 ++++++++++++++++++++++++++++++++ python/tests/cuda_skip.py | 1 - 4 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/quantized.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 29f2eeab6..9f236b4ea 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -42,6 +42,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 3a3f8ff54..a7f4e8f66 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -91,7 +91,6 @@ NO_GPU_MULTI(Eigh) namespace fast { NO_GPU(ScaledDotProductAttention) -NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) } // namespace fast diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized.cu new file mode 100644 index 000000000..12a1f6fe4 --- /dev/null +++ b/mlx/backend/cuda/quantized.cu @@ -0,0 +1,383 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include + +namespace mlx::core { +namespace cu { + +namespace cg = cooperative_groups; + +template +inline constexpr __device__ short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr __device__ short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + +template +__global__ void +affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim = cg::this_grid().dim_threads(); + constexpr float eps = 1e-7; + constexpr int simd_size = WARP_SIZE; + constexpr float n_bins = (1 << bits) - 1; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + + size_t offset = tidx + grid_dim.x * size_t(tidy); + size_t in_index = offset * values_per_reduce; + if (in_index >= size) { + return; + } + size_t out_index = power_of_2_bits + ? 
offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; + + float w_thread[values_per_reduce]; + float w_min = Limits::max(); + float w_max = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + float val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + + cg::greater max_op; + cg::less min_op; + auto warp = cg::tiled_partition(cg::this_thread_block()); + + w_min = cg::reduce(warp, w_min, min_op); + w_max = cg::reduce(warp, w_max, max_op); + + float scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + float bias = at_zero ? 0 : edge; + + // Write out the scales and biases + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = static_cast(scale); + biases[gindex] = static_cast(bias); + } + + using OutType = std::conditional_t; + OutType output = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output |= val << (bits * (i % pack_factor)); + } + + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = warp.shfl_down(val, j); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); + } + } + } + if constexpr (bits == 3 || bits == 6) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else if constexpr (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } + } else { + if constexpr (writes_per_reduce > 0) { + if (out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } + } + } +} + +template +__global__ void affine_dequantize( + const uint8_t* w, + const T* scales, + const T* biases, + T* out, + size_t size) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim = cg::this_grid().dim_threads(); + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + size_t offset = tidx + grid_dim.x * size_t(tidy); + size_t oindex = offset * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + out += oindex; + + if constexpr (bits == 3) { + w += offset * bytes_per_pack; + out[0] = static_cast(w[0] & 0x7) * scale + bias; + out[1] = static_cast((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (static_cast((w[0] & 0xc0) >> 6) + + static_cast((w[1] & 0x1) << 2)) 
* + scale + + bias; + out[3] = static_cast((w[1] & 0xe) >> 1) * scale + bias; + out[4] = static_cast((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (static_cast((w[1] & 0x80) >> 7) + + static_cast((w[2] & 0x3) << 1)) * + scale + + bias; + out[6] = static_cast((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = static_cast((w[2] & 0xe0) >> 5) * scale + bias; + } else if constexpr (bits == 5) { + w += offset * bytes_per_pack; + out[0] = static_cast(w[0] & 0x1f) * scale + bias; + out[1] = (static_cast((w[0] & 0xe0) >> 5) + + static_cast((w[1] & 0x3) << 3)) * + scale + + bias; + out[2] = static_cast((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (static_cast((w[1] & 0x80) >> 7) + + static_cast((w[2] & 0xf) << 1)) * + scale + + bias; + out[4] = (static_cast((w[2] & 0xf0) >> 4) + + static_cast((w[3] & 0x1) << 4)) * + scale + + bias; + out[5] = static_cast((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (static_cast((w[3] & 0xc0) >> 6) + + static_cast((w[4] & 0x7) << 2)) * + scale + + bias; + out[7] = static_cast((w[4] & 0xf8) >> 3) * scale + bias; + } else if constexpr (bits == 6) { + w += offset * bytes_per_pack; + out[0] = static_cast(w[0] & 0x3f) * scale + bias; + out[1] = (static_cast((w[0] >> 6) & 0x03) + + static_cast((w[1] & 0x0f) << 2)) * + scale + + bias; + out[2] = (static_cast((w[1] >> 4) & 0x0f) + + static_cast((w[2] & 0x03) << 4)) * + scale + + bias; + out[3] = static_cast((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * static_cast(d) + bias; + } + } +} + +} // namespace cu +namespace { + +inline array ensure_row_contiguous( + const array& x, + cu::CommandEncoder& enc, + const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + enc.add_temporary(x_copy); + return x_copy; + } else { + return x; + } +} + +} // namespace + +template +void dispatch_groups(int group_size, F&& f) { + switch (group_size) { + case 32: + f(std::integral_constant{}); + break; + case 64: + f(std::integral_constant{}); + break; + case 128: + f(std::integral_constant{}); + break; + } +} + +template +void dispatch_bits(int bits, F&& f) { + switch (bits) { + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; + case 4: + f(std::integral_constant{}); + break; + case 5: + f(std::integral_constant{}); + break; + case 6: + f(std::integral_constant{}); + break; + case 8: + f(std::integral_constant{}); + break; + } +} + +void fast::AffineQuantize::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& w_pre = inputs[0]; + auto& out = outputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + + auto& s = stream(); + auto& d = cu::device(s.device); + auto& enc = d.get_command_encoder(s); + + auto w = ensure_row_contiguous(w_pre, enc, s); + enc.set_input_array(w); + if (dequantize_) { + auto scales = ensure_row_contiguous(inputs[1], enc, s); + auto biases = ensure_row_contiguous(inputs[2], enc, s); + enc.set_input_array(scales); + enc.set_input_array(biases); + enc.set_output_array(out); + } else { + auto& scales = outputs[1]; + auto& biases = outputs[2]; + scales.set_data(allocator::malloc(scales.nbytes())); + biases.set_data(allocator::malloc(biases.nbytes())); + enc.set_output_array(out); 
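+    // In quantize mode the per-group scales and biases are kernel outputs as well, so they were allocated above and are registered with the encoder alongside the packed weights.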
+ enc.set_output_array(scales); + enc.set_output_array(biases); + } + + auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype(); + + // Treat uint32 as uint8 in kernel + int uint8_per_uint32 = 4; + int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8 + : bits_ == 6 ? 4 + : 8 / bits_; + int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE; + size_t size = + dequantize_ ? out.size() / packs_per_int : w.size() / per_thread; + + bool large = size > UINT_MAX; + auto grid_shape = w.shape(); + + if (dequantize_) { + grid_shape.back() *= uint8_per_uint32; + } else { + grid_shape.back() /= per_thread; + } + + dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) { + dispatch_groups(group_size_, [&](auto group_size) { + dispatch_bits(bits_, [&](auto bits) { + using DataType = cuda_type_t; + if (dequantize_) { + auto kernel = cu::affine_dequantize; + auto [num_blocks, block_dims] = + get_launch_args(kernel, size, grid_shape, w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + w.data(), + inputs[1].data(), + inputs[2].data(), + out.data(), + out.size()); + } else { + auto kernel = cu::affine_quantize; + auto [num_blocks, block_dims] = + get_launch_args(kernel, size, grid_shape, w.strides(), large); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + w.data(), + out.data(), + outputs[1].data(), + outputs[2].data(), + w.size()); + } + }); + }); + }); +} + +} // namespace mlx::core diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 005c612ff..7c9ff84ce 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -83,7 +83,6 @@ cuda_skip = { "TestQuantized.test_qmm_shapes", "TestQuantized.test_qmm_vjp", "TestQuantized.test_qmv", - "TestQuantized.test_quantize_dequantize", "TestQuantized.test_qvm", "TestQuantized.test_qvm_splitk", "TestQuantized.test_small_matrix", From 49114f28aba473af6acf27633f7e8977e292273b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Jul 2025 17:16:18 -0700 Subject: [PATCH 152/156] fix flaky test (#2371) --- python/tests/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 8521d8f80..bbea9ad8e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1478,7 +1478,7 @@ class TestOps(mlx_tests.MLXTestCase): r_mlx = mlxop(y) mx.eval(r_mlx) - self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True)) x = np.random.rand(9, 12, 18) xi = np.random.rand(9, 12, 18) From f0a0b077a042c47099057ae0204c362f570d5650 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Jul 2025 17:17:33 -0700 Subject: [PATCH 153/156] Install linux with mlx[cuda] and mlx[cpu] (#2356) * install linux with mlx[cuda] and mlx[cpu] * temp for testing * cleanup circle, fix cuda repair * update circle * update circle * decouple python bindings from core libraries --- .circleci/config.yml | 225 +++++++++++--------------------- CMakeLists.txt | 2 - docs/src/install.rst | 17 +-- pyproject.toml | 2 +- python/scripts/repair_cuda.sh | 22 ++-- python/scripts/repair_linux.sh | 19 +++ python/scripts/repair_record.py | 33 +++++ setup.py | 156 ++++++++++++++++------ 8 files changed, 264 insertions(+), 212 deletions(-) create mode 100644 python/scripts/repair_linux.sh create mode 100644 python/scripts/repair_record.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 01d432bfe..3d24cb432 100644 --- a/.circleci/config.yml +++ 
b/.circleci/config.yml @@ -7,18 +7,6 @@ parameters: nightly_build: type: boolean default: false - weekly_build: - type: boolean - default: false - test_release: - type: boolean - default: false - linux_release: - type: boolean - default: false - cuda_release: - type: boolean - default: false jobs: build_documentation: @@ -282,7 +270,17 @@ jobs: name: Build Python package command: | source env/bin/activate - << parameters.build_env >> python -m build -w + << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w + - when: + condition: + equal: ["3.9", << parameters.python_version >>] + steps: + - run: + name: Build common package + command: | + source env/bin/activate + python setup.py clean --all + << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w - when: condition: << parameters.build_env >> steps: @@ -299,59 +297,70 @@ jobs: python_version: type: string default: "3.9" - extra_env: + build_env: type: string - default: "DEV_RELEASE=1" - docker: - - image: ubuntu:20.04 + default: "" + machine: + image: ubuntu-2204:current + resource_class: large steps: - checkout - run: name: Build wheel command: | PYTHON=python<< parameters.python_version >> - apt-get update - apt-get upgrade -y - DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata - apt-get install -y apt-utils - apt-get install -y software-properties-common - add-apt-repository -y ppa:deadsnakes/ppa - apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full - apt-get install -y libblas-dev liblapack-dev liblapacke-dev - apt-get install -y build-essential git + export DEBIAN_FRONTEND=noninteractive + export NEEDRESTART_MODE=a + sudo apt-get update + sudo apt-get upgrade -y + TZ=Etc/UTC sudo apt-get -y install tzdata + sudo apt-get install -y apt-utils + sudo apt-get install -y software-properties-common + sudo add-apt-repository -y ppa:deadsnakes/ppa + sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full + sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install -y build-essential git $PYTHON -m venv env source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install nanobind==2.4.0 - pip install --upgrade setuptools - pip install numpy pip install auditwheel pip install patchelf pip install build pip install twine - << parameters.extra_env >> pip install . 
-v + << parameters.build_env >> pip install ".[dev]" -v pip install typing_extensions python setup.py generate_stubs - << parameters.extra_env >> python -m build --wheel - auditwheel show dist/* - auditwheel repair dist/* --plat manylinux_2_31_x86_64 - - run: - name: Upload package - command: | - source env/bin/activate - twine upload wheelhouse/* + MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w + bash python/scripts/repair_linux.sh + - when: + condition: + equal: ["3.9", << parameters.python_version >>] + steps: + - run: + name: Build common package + command: | + source env/bin/activate + python setup.py clean --all + << parameters.build_env >> MLX_BUILD_STAGE=2 \ + python -m build -w + auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64 + - when: + condition: << parameters.build_env >> + steps: + - run: + name: Upload packages + command: | + source env/bin/activate + twine upload wheelhouse/*.whl - store_artifacts: path: wheelhouse/ build_cuda_release: parameters: - python_version: + build_env: type: string - default: "3.9" - extra_env: - type: string - default: "DEV_RELEASE=1" + default: "" machine: image: linux-cuda-12:default resource_class: gpu.nvidia.small.gen2 @@ -362,25 +371,25 @@ jobs: command: | sudo apt-get update sudo apt-get install libblas-dev liblapack-dev liblapacke-dev + sudo apt-get install zip python -m venv env source env/bin/activate pip install auditwheel pip install patchelf pip install build pip install twine - << parameters.extra_env >> \ + << parameters.build_env >> MLX_BUILD_STAGE=2 \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ - pip install ".[dev]" -v - python setup.py generate_stubs - << parameters.extra_env >> \ - CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ - python -m build --wheel + python -m build -w bash python/scripts/repair_cuda.sh - - run: - name: Upload package - command: | - source env/bin/activate - twine upload wheelhouse/*.whl + - when: + condition: << parameters.build_env >> + steps: + - run: + name: Upload package + command: | + source env/bin/activate + twine upload wheelhouse/*.whl - store_artifacts: path: wheelhouse/ @@ -392,8 +401,6 @@ workflows: pattern: "^(?!pull/)[-\\w]+$" value: << pipeline.git.branch >> - not: << pipeline.parameters.nightly_build >> - - not: << pipeline.parameters.weekly_build >> - - not: << pipeline.parameters.test_release >> jobs: - mac_build_and_test: matrix: @@ -407,8 +414,6 @@ workflows: when: and: - not: << pipeline.parameters.nightly_build >> - - not: << pipeline.parameters.weekly_build >> - - not: << pipeline.parameters.test_release >> jobs: - build_release: filters: @@ -499,7 +504,16 @@ workflows: matrix: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - extra_env: ["PYPI_RELEASE=1"] + build_env: ["PYPI_RELEASE=1"] + - build_cuda_release: + filters: + tags: + only: /^v.*/ + branches: + ignore: /.*/ + matrix: + parameters: + build_env: ["PYPI_RELEASE=1"] prb: when: @@ -578,99 +592,8 @@ workflows: - macosx_deployment_target: "15.0" xcode_version: "15.0.0" python_version: "3.13" - weekly_build: - when: - and: - - equal: [ main, << pipeline.git.branch >> ] - - << pipeline.parameters.weekly_build >> - jobs: - - build_release: - matrix: - parameters: - python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - macosx_deployment_target: ["13.5", "14.0", "15.0"] - build_env: ["DEV_RELEASE=1"] - xcode_version: ["16.2.0", "15.0.0"] - exclude: - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - 
python_version: "3.9" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.10" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.11" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.12" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "13.5" - xcode_version: "16.2.0" - python_version: "3.13" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.9" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.10" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.11" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.12" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "14.0" - xcode_version: "15.0.0" - python_version: "3.13" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.9" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.10" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.11" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.12" - build_env: "DEV_RELEASE=1" - - macosx_deployment_target: "15.0" - xcode_version: "15.0.0" - python_version: "3.13" - build_env: "DEV_RELEASE=1" - linux_test_release: - when: - and: - - equal: [ main, << pipeline.git.branch >> ] - - << pipeline.parameters.linux_release >> - jobs: - build_linux_release: matrix: parameters: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - extra_env: ["PYPI_RELEASE=1"] - cuda_test_release: - when: - and: - - equal: [ main, << pipeline.git.branch >> ] - - << pipeline.parameters.cuda_release >> - jobs: - - build_cuda_release: - matrix: - parameters: - python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - extra_env: ["PYPI_RELEASE=1"] + - build_cuda_release diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bf8d2d3e..9e67e4bf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,10 +64,8 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") message(WARNING "Building for x86_64 arch is not officially supported.") endif() endif() - else() set(MLX_BUILD_METAL OFF) - message(WARNING "MLX is prioritised for Apple silicon systems using macOS.") endif() # ----------------------------- Lib ----------------------------- diff --git a/docs/src/install.rst b/docs/src/install.rst index a50b6a71d..70491ac64 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -23,13 +23,6 @@ To install from PyPI you must meet the following requirements: MLX is only available on devices running macOS >= 13.5 It is highly recommended to use macOS 14 (Sonoma) - -MLX is also available on conda-forge. To install MLX with conda do: - -.. code-block:: shell - - conda install conda-forge::mlx - CUDA ^^^^ @@ -38,8 +31,16 @@ and SM 7.0 (Volta) and up. To install MLX with CUDA support, run: .. code-block:: shell - pip install mlx-cuda + pip install "mlx[cuda]" +CPU-only (Linux) +^^^^^^^^^^^^^^^^ + +For a CPU-only version of MLX that runs on Linux use: + +.. 
code-block:: shell + + pip install "mlx[cpu]" Troubleshooting ^^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index ad0d2e328..6fcd5d16c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "setuptools>=42", + "setuptools>=80", "nanobind==2.4.0", "cmake>=3.25", ] diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh index 21e6a977a..ec0a89930 100644 --- a/python/scripts/repair_cuda.sh +++ b/python/scripts/repair_cuda.sh @@ -1,17 +1,23 @@ #!/bin/bash auditwheel repair dist/* \ - --plat manylinux_2_35_x86_64 \ + --plat manylinux_2_39_x86_64 \ --exclude libcublas* \ - --exclude libnvrtc* + --exclude libnvrtc* \ + -w wheel_tmp -cd wheelhouse + +mkdir wheelhouse +cd wheel_tmp repaired_wheel=$(find . -name "*.whl" -print -quit) unzip -q "${repaired_wheel}" -core_so=$(find mlx -name "core*.so" -print -quit) -rpath=$(patchelf --print-rpath "${core_so}") -rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib -patchelf --force-rpath --set-rpath "$rpath" "$core_so" +rm "${repaired_wheel}" +mlx_so="mlx/lib/libmlx.so" +rpath=$(patchelf --print-rpath "${mlx_so}") +base="\$ORIGIN/../../nvidia" +rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib +patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" +python ../python/scripts/repair_record.py ${mlx_so} # Re-zip the repaired wheel -zip -r -q "${repaired_wheel}" . +zip -r -q "../wheelhouse/${repaired_wheel}" . diff --git a/python/scripts/repair_linux.sh b/python/scripts/repair_linux.sh new file mode 100644 index 000000000..82cf49060 --- /dev/null +++ b/python/scripts/repair_linux.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +auditwheel repair dist/* \ + --plat manylinux_2_35_x86_64 \ + --exclude libmlx* \ + -w wheel_tmp + +mkdir wheelhouse +cd wheel_tmp +repaired_wheel=$(find . -name "*.whl" -print -quit) +unzip -q "${repaired_wheel}" +rm "${repaired_wheel}" +core_so=$(find mlx -name "core*.so" -print -quit) +rpath="\$ORIGIN/lib" +patchelf --force-rpath --set-rpath "$rpath" "$core_so" +python ../python/scripts/repair_record.py ${core_so} + +# Re-zip the repaired wheel +zip -r -q "../wheelhouse/${repaired_wheel}" . 
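Note on repair_linux.sh above (the CUDA variant follows the same pattern): patchelf rewrites the rpath of the extension module after auditwheel has already written the wheel's RECORD, so the recorded sha256 and size for that file go stale and installers that validate RECORD can reject the wheel; restoring those entries before re-zipping is what repair_record.py, added in the next diff, takes care of. A minimal local sanity check of the repaired wheel, sketched assuming unzip and patchelf are on PATH; the wheel glob and the /tmp unpack directory are illustrative:

    unzip -q wheelhouse/mlx-*.whl -d /tmp/mlx_wheel
    # rpath of the Python extension should point at the bundled libraries
    patchelf --print-rpath /tmp/mlx_wheel/mlx/core*.so   # expect: $ORIGIN/lib
    # RECORD should list the patched .so with its updated sha256 and size
    grep 'core.*\.so' /tmp/mlx_wheel/*.dist-info/RECORD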
diff --git a/python/scripts/repair_record.py b/python/scripts/repair_record.py new file mode 100644 index 000000000..1738fd5ad --- /dev/null +++ b/python/scripts/repair_record.py @@ -0,0 +1,33 @@ +import base64 +import glob +import hashlib +import sys + +filename = sys.argv[1] + + +# Compute the new hash and size +def urlsafe_b64encode(data: bytes) -> bytes: + return base64.urlsafe_b64encode(data).rstrip(b"=") + + +hasher = hashlib.sha256() +with open(filename, "rb") as f: + data = f.read() + hasher.update(data) +hash_str = urlsafe_b64encode(hasher.digest()).decode("ascii") +size = len(data) + +# Update the record file +record_file = glob.glob("*/RECORD")[0] +with open(record_file, "r") as f: + lines = [l.split(",") for l in f.readlines()] + +for l in lines: + if filename == l[0]: + l[1] = hash_str + l[2] = f"{size}\n" + +with open(record_file, "w") as f: + for l in lines: + f.write(",".join(l)) diff --git a/setup.py b/setup.py index 770718e25..6cc4015c3 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,12 @@ import os import platform import re import subprocess +from functools import partial from pathlib import Path from subprocess import run -from setuptools import Command, Extension, find_namespace_packages, setup +from setuptools import Command, Extension, setup +from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.build_ext import build_ext @@ -41,6 +43,9 @@ def get_version(): return version +build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0)) + + # A CMakeExtension needs a sourcedir instead of a file list. # The name must be the _single_ output extension from the CMake build. # If you need multiple extensions, see scikit-build. @@ -59,13 +64,22 @@ class CMakeBuild(build_ext): debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug cfg = "Debug" if debug else "Release" - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. 
+ build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + build_python = "ON" + install_prefix = f"{extdir}{os.sep}" + if build_stage == 1: + # Don't include MLX libraries in the wheel + install_prefix = f"{build_temp}" + elif build_stage == 2: + # Don't include Python bindings in the wheel + build_python = "OFF" cmake_args = [ - f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}", + f"-DCMAKE_INSTALL_PREFIX={install_prefix}", f"-DCMAKE_BUILD_TYPE={cfg}", - "-DMLX_BUILD_PYTHON_BINDINGS=ON", + f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}", "-DMLX_BUILD_TESTS=OFF", "-DMLX_BUILD_BENCHMARKS=OFF", "-DMLX_BUILD_EXAMPLES=OFF", @@ -99,10 +113,6 @@ class CMakeBuild(build_ext): if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: build_args += [f"-j{os.cpu_count()}"] - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) - subprocess.run( ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True ) @@ -158,26 +168,40 @@ class GenerateStubs(Command): subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) +class MLXBdistWheel(bdist_wheel): + def get_tag(self) -> tuple[str, str, str]: + impl, abi, plat_name = super().get_tag() + if build_stage == 2: + impl = self.python_tag + abi = "none" + return (impl, abi, plat_name) + + # Read the content of README.md with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: long_description = f.read() -# The information here can also be placed in setup.cfg - better separation of -# logic and declaration, and simpler if you include description/version in a file. + if __name__ == "__main__": - packages = find_namespace_packages( - where="python", exclude=["src", "tests", "tests.*"] - ) package_dir = {"": "python"} - package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} - install_requires = [] + packages = [ + "mlx", + "mlx.nn", + "mlx.nn.layers", + "mlx.optimizers", + ] + + build_macos = platform.system() == "Darwin" build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") + + install_requires = [] if build_cuda: install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"] + version = get_version() - setup( - name="mlx-cuda" if build_cuda else "mlx", - version=get_version(), + _setup = partial( + setup, + version=version, author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", @@ -185,29 +209,77 @@ if __name__ == "__main__": long_description_content_type="text/markdown", license="MIT", url="https://github.com/ml-explore/mlx", - packages=packages, - package_dir=package_dir, - package_data=package_data, include_package_data=True, - install_requires=install_requires, - extras_require={ - "dev": [ - "nanobind==2.4.0", - "numpy", - "pre-commit", - "setuptools>=42", - "torch", - "typing_extensions", - ], - }, - entry_points={ - "console_scripts": [ - "mlx.launch = mlx.distributed_run:main", - "mlx.distributed_config = mlx.distributed_run:distributed_config", - ] - }, - ext_modules=[CMakeExtension("mlx.core")], - cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, + package_dir=package_dir, zip_safe=False, python_requires=">=3.9", + ext_modules=[CMakeExtension("mlx.core")], + cmdclass={ + "build_ext": CMakeBuild, + "generate_stubs": GenerateStubs, + "bdist_wheel": MLXBdistWheel, + }, ) + + package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} + + extras = { + "dev": [ + 
"nanobind==2.4.0", + "numpy", + "pre-commit", + "setuptools>=80", + "torch", + "typing_extensions", + ], + } + entry_points = { + "console_scripts": [ + "mlx.launch = mlx.distributed_run:main", + "mlx.distributed_config = mlx.distributed_run:distributed_config", + ] + } + + # Release builds for PyPi are in two stages. + # Each stage should be run from a clean build: + # python setup.py clean --all + # + # Stage 1: + # - Triggered with `MLX_BUILD_STAGE=1` + # - Include everything except backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc) + # - Wheel has Python ABI and platform tags + # - Wheel should be built for the cross-product of python version and platforms + # - Package name is mlx and it depends on subpackage in stage 2 (e.g. mlx-metal) + # Stage 2: + # - Triggered with `MLX_BUILD_STAGE=2` + # - Includes only backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc) + # - Wheel has only platform tags + # - Wheel should be built only for different platforms + # - Package name is back-end specific, e.g mlx-metal + if build_stage != 2: + if build_stage == 1: + if build_macos: + install_requires += [f"mlx-metal=={version}"] + else: + extras["cuda"] = [f"mlx-cuda=={version}"] + extras["cpu"] = [f"mlx-cpu=={version}"] + + _setup( + name="mlx", + packages=packages, + extras_require=extras, + entry_points=entry_points, + install_requires=install_requires, + package_data=package_data, + ) + else: + if build_macos: + name = "mlx-metal" + elif build_cuda: + name = "mlx-cuda" + else: + name = "mlx-cpu" + _setup( + name=name, + packages=["mlx"], + ) From cb349a291c4417ef29545f1f075f558d591de5f7 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 15 Jul 2025 16:36:13 +0900 Subject: [PATCH 154/156] [CUDA] Use cuda::std::complex in place of cuComplex (#2372) --- mlx/backend/cuda/binary.cu | 1 - mlx/backend/cuda/binary_two.cu | 1 - mlx/backend/cuda/device/atomic_ops.cuh | 4 +- mlx/backend/cuda/device/binary_ops.cuh | 55 ++--- mlx/backend/cuda/device/cast_op.cuh | 56 +++-- mlx/backend/cuda/device/complex.cuh | 61 ++++++ mlx/backend/cuda/device/cucomplex_math.cuh | 240 --------------------- mlx/backend/cuda/device/unary_ops.cuh | 168 ++++----------- mlx/backend/cuda/device/utils.cuh | 14 +- mlx/backend/cuda/jit_module.cpp | 4 +- mlx/backend/cuda/kernel_utils.cuh | 11 +- mlx/backend/cuda/reduce/reduce.cuh | 1 - mlx/backend/cuda/reduce/reduce_ops.cuh | 4 +- mlx/backend/cuda/unary.cu | 7 +- mlx/backend/cuda/utils.cpp | 2 +- 15 files changed, 169 insertions(+), 460 deletions(-) create mode 100644 mlx/backend/cuda/device/complex.cuh delete mode 100644 mlx/backend/cuda/device/cucomplex_math.cuh diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index c8586e638..3eade024d 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -3,7 +3,6 @@ #include "mlx/backend/common/binary.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/binary_ops.cuh" -#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 0918c579f..3ac8a9516 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -3,7 +3,6 @@ #include "mlx/backend/common/binary.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/binary_ops.cuh" -#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include 
"mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh index e0d3c3eac..5df246c0e 100644 --- a/mlx/backend/cuda/device/atomic_ops.cuh +++ b/mlx/backend/cuda/device/atomic_ops.cuh @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/complex.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" #include @@ -48,7 +48,7 @@ inline __device__ void atomic_add(__half* out, __half val) { atomicAdd(out, val); } -inline __device__ void atomic_add(cuComplex* out, cuComplex val) { +inline __device__ void atomic_add(complex64_t* out, complex64_t val) { #if __CUDA_ARCH__ < 900 atomic_add_general(out, val); #else diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index 644786a92..575aced14 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -44,7 +44,7 @@ struct Remainder { } else { return x % y; } - } else if constexpr (cuda::std::is_same_v) { + } else if constexpr (is_complex_v) { return x % y; } else { T r = fmod(x, y); @@ -66,14 +66,12 @@ struct Equal { struct NaNEqual { template __device__ bool operator()(T x, T y) { - if constexpr (std::is_same_v) { + if constexpr (is_complex_v) { return x == y || - (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) && - isnan(cuCimagf(y))) || - (cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) && - isnan(cuCimagf(y))) || - (isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && - cuCimagf(x) == cuCimagf(y)); + (isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) && + isnan(y.imag())) || + (x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) || + (isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag()); } else { return x == y || (isnan(x) && isnan(y)); } @@ -111,17 +109,17 @@ struct LessEqual { struct LogAddExp { template __device__ T operator()(T x, T y) { - if constexpr (cuda::std::is_same_v) { - if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) || - isnan(cuCimagf(y))) { + if constexpr (is_complex_v) { + if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) || + isnan(y.imag())) { return { cuda::std::numeric_limits::quiet_NaN(), cuda::std::numeric_limits::quiet_NaN()}; } - auto max = cuCrealf(x) > cuCrealf(y) ? x : y; - auto min = cuCrealf(x) < cuCrealf(y) ? x : y; - auto min_real = cuCrealf(min); - auto max_real = cuCrealf(max); + auto max = x.real() > y.real() ? x : y; + auto min = x.real() < y.real() ? x : y; + auto min_real = min.real(); + auto max_real = max.real(); if (!isfinite(min_real) && (min_real == max_real)) { if (min_real < 0) { return min; @@ -150,8 +148,8 @@ struct Maximum { __device__ T operator()(T x, T y) { if constexpr (cuda::std::is_integral_v) { return max(x, y); - } else if constexpr (cuda::std::is_same_v) { - if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) { + } else if constexpr (is_complex_v) { + if (isnan(x.real()) || isnan(x.imag())) { return x; } return x > y ? x : y; @@ -169,8 +167,8 @@ struct Minimum { __device__ T operator()(T x, T y) { if constexpr (cuda::std::is_integral_v) { return min(x, y); - } else if constexpr (cuda::std::is_same_v) { - if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) { + } else if constexpr (is_complex_v) { + if (isnan(x.real()) || isnan(x.imag())) { return x; } return x < y ? 
x : y; @@ -193,8 +191,8 @@ struct Multiply { struct NotEqual { template __device__ bool operator()(T x, T y) { - if constexpr (std::is_same_v) { - return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y); + if constexpr (is_complex_v) { + return x.real() != y.real() || x.imag() != y.imag(); } else { return x != y; } @@ -214,19 +212,8 @@ struct Power { base *= base; } return res; - } else if constexpr (cuda::std::is_same_v) { - if (base.y == 0 && base.x == 0) { - if (isnan(exp.x) || isnan(exp.y)) { - auto nan = cuda::std::numeric_limits::quiet_NaN(); - return make_cuFloatComplex(nan, nan); - } - return make_cuFloatComplex(0.0, 0.0); - } - auto x_theta = atan2f(base.y, base.x); - auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y); - auto mag = expf(exp.x * x_ln_r - exp.y * x_theta); - auto phase = exp.y * x_ln_r + exp.x * x_theta; - return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase)); + } else if constexpr (is_complex_v) { + return pow(base, exp); } else { return powf(base, exp); } diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh index 8da19ddf8..e10fde6dc 100644 --- a/mlx/backend/cuda/device/cast_op.cuh +++ b/mlx/backend/cuda/device/cast_op.cuh @@ -2,7 +2,8 @@ #pragma once -#include +#include "mlx/backend/cuda/device/complex.cuh" + #include #include #include @@ -20,50 +21,43 @@ struct CastOp { }; // Castings between complex and boolean. -// TODO: Should make a custom complex type. -template <> -struct CastOp { +template +struct CastOp, bool> { static constexpr bool is_castable = true; - __device__ bool operator()(cuComplex x) { - return x.x != 0 && x.y != 0; + __device__ bool operator()(complex_t x) { + return x.real() != 0 && x.imag() != 0; } }; -template <> -struct CastOp { +template +struct CastOp> { static constexpr bool is_castable = true; - __device__ cuComplex operator()(bool x) { - return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0); + __device__ complex_t operator()(bool x) { + return x ? complex_t{1, 1} : complex_t{0, 0}; } }; // Converting a complex number to real number discards the imaginary part. -template -struct CastOp< - cuComplex, - DstT, - cuda::std::enable_if_t>> { - static constexpr bool is_castable = cuda::std::is_convertible_v; +template +struct CastOp, DstT, cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; - __device__ DstT operator()(cuComplex x) { - static_assert(!cuda::std::is_same_v); - return static_cast(cuCrealf(x)); + __device__ DstT operator()(complex_t x) { + static_assert(!is_complex_v); + return static_cast(x.real()); } }; // Allow converting a real number to complex number. 
-template -struct CastOp< - SrcT, - cuComplex, - cuda::std::enable_if_t>> { - static constexpr bool is_castable = cuda::std::is_convertible_v; +template +struct CastOp, cuda::std::enable_if_t>> { + static constexpr bool is_castable = cuda::std::is_convertible_v; - __device__ cuComplex operator()(SrcT x) { - static_assert(!cuda::std::is_same_v); - return cuComplex{static_cast(x), 0}; + __device__ complex_t operator()(SrcT x) { + static_assert(!is_complex_v); + return complex_t{static_cast(x), 0}; } }; @@ -88,8 +82,7 @@ struct CastOp< SrcT, DstT, cuda::std::enable_if_t< - !cuda::std::is_convertible_v && - !cuda::std::is_same_v && + !cuda::std::is_convertible_v && !is_complex_v && (cuda::std::is_same_v || cuda::std::is_same_v)>> { static constexpr bool is_castable = true; @@ -104,8 +97,7 @@ struct CastOp< SrcT, DstT, cuda::std::enable_if_t< - !cuda::std::is_convertible_v && - !cuda::std::is_same_v && + !cuda::std::is_convertible_v && !is_complex_v && !cuda::std::is_same_v && !cuda::std::is_same_v && (cuda::std::is_same_v || diff --git a/mlx/backend/cuda/device/complex.cuh b/mlx/backend/cuda/device/complex.cuh new file mode 100644 index 000000000..8dfd23b46 --- /dev/null +++ b/mlx/backend/cuda/device/complex.cuh @@ -0,0 +1,61 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +// Make multiplication and division faster. +#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS + +#include +#include + +namespace mlx::core::cu { + +// TODO: Consider using a faster implementation as cuda::std::complex has to +// conform to C++ standard. +template +using complex_t = cuda::std::complex; + +using complex64_t = complex_t; +using complex128_t = complex_t; + +template +struct is_complex : cuda::std::false_type {}; + +template +struct is_complex> : cuda::std::true_type {}; + +template +inline constexpr bool is_complex_v = is_complex::value; + +// cuda::std::complex is missing some operators. +template +inline __host__ __device__ complex_t operator%( + complex_t a, + complex_t b) { + T r = a.real() - floor(a.real() / b.real()) * b.real(); + T i = a.imag() - floor(a.imag() / b.imag()) * b.imag(); + return complex_t{r, i}; +} + +template +inline __host__ __device__ bool operator<(complex_t a, complex_t b) { + return (a.real() * a.real() + a.imag() * a.imag()) < + (b.real() * b.real() + b.imag() * b.imag()); +} + +template +inline __host__ __device__ bool operator>(complex_t a, complex_t b) { + return b < a; +} + +template +inline __host__ __device__ bool operator<=(complex_t a, complex_t b) { + return !(a > b); +} + +template +inline __host__ __device__ bool operator>=(complex_t a, complex_t b) { + return !(a < b); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/cucomplex_math.cuh b/mlx/backend/cuda/device/cucomplex_math.cuh deleted file mode 100644 index 612650c06..000000000 --- a/mlx/backend/cuda/device/cucomplex_math.cuh +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright © 2025 Apple Inc. -// Copyright © 2017-2024 The Simons Foundation, Inc. -// -// FINUFFT is licensed under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance with the -// License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// -// Forked from -// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h - -#pragma once - -#include - -// This header provides some helper functions for cuComplex types. -// It mainly wraps existing CUDA implementations to provide operator overloads -// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are -// all provided by CUDA - -__forceinline__ __host__ __device__ cuDoubleComplex -operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) { - return cuCadd(a, b); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) { - return cuCsub(a, b); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) { - return cuCmul(a, b); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) { - return cuCdiv(a, b); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) { - double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b)); - double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b)); - return make_cuDoubleComplex(r, i); -} - -__forceinline__ __host__ __device__ bool operator==( - const cuDoubleComplex& a, - const cuDoubleComplex& b) { - return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b); -} - -__forceinline__ __host__ __device__ bool operator!=( - const cuDoubleComplex& a, - const cuDoubleComplex& b) { - return !(a == b); -} - -__forceinline__ __host__ __device__ bool operator>( - const cuDoubleComplex& a, - const cuDoubleComplex& b) { - double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a)); - double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b)); - return mag_a > mag_b; -} - -__forceinline__ __host__ __device__ bool operator>=( - const cuDoubleComplex& a, - const cuDoubleComplex& b) { - return a > b || a == b; -} - -__forceinline__ __host__ __device__ bool operator<( - const cuDoubleComplex& a, - const cuDoubleComplex& b) { - return b > a; -} - -__forceinline__ __host__ __device__ bool operator<=( - const cuDoubleComplex& a, - const cuDoubleComplex& b) { - return b > a || a == b; -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator+(const cuDoubleComplex& a, double b) { - return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a)); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator+(double a, const cuDoubleComplex& b) { - return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b)); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator-(const cuDoubleComplex& a, double b) { - return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a)); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator-(double a, const cuDoubleComplex& b) { - return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b)); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator*(const cuDoubleComplex& a, double b) { - return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator*(double a, const cuDoubleComplex& b) { - return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b)); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator/(const cuDoubleComplex& a, double b) { - return 
make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b); -} - -__forceinline__ __host__ __device__ cuDoubleComplex -operator/(double a, const cuDoubleComplex& b) { - double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b); - return make_cuDoubleComplex( - (a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator+(const cuFloatComplex& a, const cuFloatComplex& b) { - return cuCaddf(a, b); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator-(const cuFloatComplex& a, const cuFloatComplex& b) { - return cuCsubf(a, b); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator*(const cuFloatComplex& a, const cuFloatComplex& b) { - return cuCmulf(a, b); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator/(const cuFloatComplex& a, const cuFloatComplex& b) { - return cuCdivf(a, b); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator%(const cuFloatComplex& a, const cuFloatComplex& b) { - float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b)); - float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b)); - return make_cuFloatComplex(r, i); -} - -__forceinline__ __host__ __device__ bool operator==( - const cuFloatComplex& a, - const cuFloatComplex& b) { - return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b); -} - -__forceinline__ __host__ __device__ bool operator!=( - const cuFloatComplex& a, - const cuFloatComplex& b) { - return !(a == b); -} - -__forceinline__ __host__ __device__ bool operator>( - const cuFloatComplex& a, - const cuFloatComplex& b) { - float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a)); - float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b)); - return mag_a > mag_b; -} - -__forceinline__ __host__ __device__ bool operator>=( - const cuFloatComplex& a, - const cuFloatComplex& b) { - return a > b || a == b; -} - -__forceinline__ __host__ __device__ bool operator<( - const cuFloatComplex& a, - const cuFloatComplex& b) { - return b > a; -} - -__forceinline__ __host__ __device__ bool operator<=( - const cuFloatComplex& a, - const cuFloatComplex& b) { - return b > a || a == b; -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator+(const cuFloatComplex& a, float b) { - return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a)); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator+(float a, const cuFloatComplex& b) { - return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b)); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator-(const cuFloatComplex& a, float b) { - return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a)); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator-(float a, const cuFloatComplex& b) { - return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b)); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator*(const cuFloatComplex& a, float b) { - return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator*(float a, const cuFloatComplex& b) { - return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b)); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator/(const cuFloatComplex& a, float b) { - return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b); -} - -__forceinline__ __host__ __device__ cuFloatComplex -operator/(float a, const cuFloatComplex& b) { - float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * 
cuCimagf(b); - return make_cuFloatComplex( - (a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom); -} diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh index 447569eeb..aebed1e4d 100644 --- a/mlx/backend/cuda/device/unary_ops.cuh +++ b/mlx/backend/cuda/device/unary_ops.cuh @@ -2,12 +2,10 @@ #pragma once -#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include -#include namespace mlx::core::cu { @@ -16,8 +14,6 @@ struct Abs { __device__ T operator()(T x) { if constexpr (cuda::std::is_unsigned_v) { return x; - } else if constexpr (cuda::std::is_same_v) { - return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0}; } else { return abs(x); } @@ -29,8 +25,6 @@ struct ArcCos { __device__ T operator()(T x) { return acos(x); } - - __device__ cuComplex operator()(cuComplex x); }; struct ArcCosh { @@ -45,8 +39,6 @@ struct ArcSin { __device__ T operator()(T x) { return asin(x); } - - __device__ cuComplex operator()(cuComplex x); }; struct ArcSinh { @@ -61,8 +53,6 @@ struct ArcTan { __device__ T operator()(T x) { return atan(x); } - - __device__ cuComplex operator()(cuComplex x); }; struct ArcTanh { @@ -84,6 +74,8 @@ struct Ceil { __device__ T operator()(T x) { if constexpr (cuda::std::is_integral_v) { return x; + } else if constexpr (is_complex_v) { + return T{ceil(x.real()), ceil(x.imag())}; } else { return ceil(x); } @@ -91,34 +83,23 @@ struct Ceil { }; struct Conjugate { - __device__ cuComplex operator()(cuComplex x) { - return {cuCrealf(x), -cuCimagf(x)}; + template + __device__ complex_t operator()(complex_t x) { + return conj(x); } }; struct Cos { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - return { - cos(cuCrealf(x)) * cosh(cuCimagf(x)), - -sin(cuCrealf(x)) * sinh(cuCimagf(x))}; - } else { - return cos(x); - } + return cos(x); } }; struct Cosh { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - return { - cosh(cuCrealf(x)) * cos(cuCimagf(x)), - sinh(cuCrealf(x)) * sin(cuCimagf(x))}; - } else { - return cosh(x); - } + return cosh(x); } }; @@ -151,12 +132,7 @@ struct ErfInv { struct Exp { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - auto r = exp(cuda::std::complex{cuCrealf(x), cuCimagf(x)}); - return cuComplex{r.real(), r.imag()}; - } else { - return exp(x); - } + return exp(x); } }; @@ -178,6 +154,8 @@ struct Floor { __device__ T operator()(T x) { if constexpr (cuda::std::is_integral_v) { return x; + } else if constexpr (is_complex_v) { + return T{floor(x.real()), floor(x.imag())}; } else { return floor(x); } @@ -185,30 +163,25 @@ struct Floor { }; struct Imag { - __device__ float operator()(cuComplex x) { - return cuCimagf(x); + template + __device__ auto operator()(complex_t x) { + return x.imag(); } }; struct Log { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - auto r = log(cuCrealf(Abs{}(x))); - auto i = atan2f(cuCimagf(x), cuCrealf(x)); - return {r, i}; - } else { - return log(x); - } + return log(x); } }; struct Log2 { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { + if constexpr (is_complex_v) { auto y = Log{}(x); - return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F}; + return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F}; } else { return log2(x); } @@ -218,23 +191,17 @@ struct Log2 { struct Log10 { template __device__ T operator()(T x) { - if 
constexpr (cuda::std::is_same_v) { - auto y = Log{}(x); - return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F}; - return y; - } else { - return log10(x); - } + return log10(x); } }; struct Log1p { template __device__ T operator()(T z) { - if constexpr (cuda::std::is_same_v) { - float x = cuCrealf(z); - float y = cuCimagf(z); - float zabs = cuCrealf(Abs{}(z)); + if constexpr (is_complex_v) { + float x = z.real(); + float y = z.imag(); + float zabs = Abs{}(z).real(); float theta = atan2f(y, x + 1); if (zabs < 0.5f) { float r = x * (2 + x) + y * y; @@ -261,8 +228,8 @@ struct LogicalNot { struct Negative { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - return 0 - x; + if constexpr (is_complex_v) { + return T{0, 0} - x; } else { return -x; } @@ -270,16 +237,17 @@ struct Negative { }; struct Real { - __device__ float operator()(cuComplex x) { - return cuCrealf(x); + template + __device__ auto operator()(complex_t x) { + return x.real(); } }; struct Round { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - return {rint(cuCrealf(x)), rint(cuCimagf(x))}; + if constexpr (is_complex_v) { + return {rint(x.real()), rint(x.imag())}; } else { return rint(x); } @@ -299,8 +267,8 @@ struct Sign { __device__ T operator()(T x) { if constexpr (cuda::std::is_unsigned_v) { return x != 0; - } else if constexpr (cuda::std::is_same_v) { - if (cuCrealf(x) == 0 && cuCimagf(x) == 0) { + } else if constexpr (is_complex_v) { + if (x.real() == 0 && x.imag() == 0) { return x; } else { return x / Abs()(x); @@ -316,26 +284,14 @@ struct Sign { struct Sin { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - return { - sin(cuCrealf(x)) * cosh(cuCimagf(x)), - cos(cuCrealf(x)) * sinh(cuCimagf(x))}; - } else { - return sin(x); - } + return sin(x); } }; struct Sinh { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - return { - sinh(cuCrealf(x)) * cos(cuCimagf(x)), - cosh(cuCrealf(x)) * sin(cuCimagf(x))}; - } else { - return sinh(x); - } + return sinh(x); } }; @@ -351,77 +307,31 @@ struct Sqrt { __device__ T operator()(T x) { return sqrt(x); } - - __device__ cuComplex operator()(cuComplex x) { - auto xr = cuCrealf(x); - auto xi = cuCimagf(x); - if (xr == 0.0f && xi == 0.0f) { - return {0.0f, 0.0f}; - } - auto r = cuCrealf(Abs{}(x)); - auto a = sqrt((r + xr) / 2.0f); - auto b_abs = sqrt((r - xr) / 2.0f); - auto b = copysign(b_abs, xi); - return {a, b}; - } }; struct Rsqrt { template __device__ T operator()(T x) { - return rsqrt(x); - } - __device__ cuComplex operator()(cuComplex x) { - return 1.0f / Sqrt{}(x); + if constexpr (is_complex_v) { + return 1.0f / Sqrt{}(x); + } else { + return rsqrt(x); + } } }; struct Tan { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - float tan_a = tan(cuCrealf(x)); - float tanh_b = tanh(cuCimagf(x)); - float t1 = tan_a * tanh_b; - float denom = 1. + t1 * t1; - return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; - } else { - return tan(x); - } + return tan(x); } }; struct Tanh { template __device__ T operator()(T x) { - if constexpr (cuda::std::is_same_v) { - float tanh_a = tanh(cuCrealf(x)); - float tan_b = tan(cuCimagf(x)); - float t1 = tanh_a * tan_b; - float denom = 1. 
+ t1 * t1; - return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; - } else { - return tanh(x); - } + return tanh(x); } }; -inline __device__ cuComplex ArcCos::operator()(cuComplex x) { - auto i = cuComplex{0.0, 1.0}; - auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); - return {cuCimagf(y), -cuCrealf(y)}; -}; - -inline __device__ cuComplex ArcSin::operator()(cuComplex x) { - auto i = cuComplex{0.0f, 1.0f}; - auto y = Log{}(i * x + Sqrt{}(1.0f - x * x)); - return {cuCimagf(y), -cuCrealf(y)}; -}; - -inline __device__ cuComplex ArcTan::operator()(cuComplex x) { - auto i = cuComplex{0.0f, 1.0f}; - auto ix = i * x; - return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix)); -}; - } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index af022c141..73bc7ff63 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -8,9 +8,9 @@ #pragma once +#include "mlx/backend/cuda/device/complex.cuh" #include "mlx/backend/cuda/device/config.h" -#include #include #include #include @@ -127,13 +127,13 @@ struct Limits { } }; -template <> -struct Limits { - static constexpr __host__ __device__ cuComplex max() { - return {Limits::max(), Limits::max()}; +template +struct Limits> { + static constexpr __host__ __device__ complex_t max() { + return {Limits::max(), Limits::max()}; } - static constexpr __host__ __device__ cuComplex min() { - return {Limits::min(), Limits::min()}; + static constexpr __host__ __device__ complex_t min() { + return {Limits::min(), Limits::min()}; } }; diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 4ce79999e..343db902e 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -173,7 +173,7 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "config.h", - INCLUDE_PREFIX "cucomplex_math.cuh", + INCLUDE_PREFIX "complex.cuh", INCLUDE_PREFIX "fp16_math.cuh", INCLUDE_PREFIX "indexing.cuh", INCLUDE_PREFIX "scatter_ops.cuh", @@ -189,7 +189,7 @@ constexpr const char* g_headers[] = { jit_source_binary_ops, jit_source_cast_op, jit_source_config, - jit_source_cucomplex_math, + jit_source_complex, jit_source_fp16_math, jit_source_indexing, jit_source_scatter_ops, diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index eeaf916c1..24c81f2fb 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -11,7 +11,6 @@ #include "mlx/array.h" #include "mlx/backend/cuda/device/utils.cuh" -#include #include #include #include @@ -79,7 +78,7 @@ struct CTypeToCudaType { template <> struct CTypeToCudaType { - using type = cuComplex; + using type = cu::complex64_t; }; template @@ -91,10 +90,14 @@ inline constexpr bool is_floating_v = cuda::std::is_same_v || cuda::std::is_same_v || cuda::std::is_same_v || cuda::std::is_same_v; +// Type traits for detecting complex numbers. +template +inline constexpr bool is_complex_v = cuda::std::is_same_v || + cuda::std::is_same_v; + // Type traits for detecting complex or real floating point numbers. template -inline constexpr bool is_inexact_v = - is_floating_v || cuda::std::is_same_v; +inline constexpr bool is_inexact_v = is_floating_v || is_complex_v; // Utility to copy data from vector to array in host. 
template diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index d0eb3f5c5..02e495594 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -3,7 +3,6 @@ #include #include "mlx/backend/common/reduce.h" -#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index bc4dce33e..31ba90433 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -151,7 +151,7 @@ struct ReduceInit { template struct ReduceInit { static constexpr __host__ __device__ auto value() { - if constexpr (cuda::std::is_same_v) { + if constexpr (is_complex_v) { return T{0, 0}; } else { return cast_to::type>(0); @@ -162,7 +162,7 @@ struct ReduceInit { template struct ReduceInit { static constexpr __host__ __device__ auto value() { - if constexpr (cuda::std::is_same_v) { + if constexpr (is_complex_v) { return T{1, 0}; } else { return cast_to::type>(1); diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 0d2754ef0..ddb32d05e 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -2,7 +2,6 @@ #include "mlx/backend/common/unary.h" #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/iterators/general_iterator.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" @@ -71,10 +70,10 @@ constexpr bool supports_unary_op() { !std::is_same_v; } if (std::is_same_v || std::is_same_v) { - return std::is_same_v && !std::is_same_v; + return std::is_same_v && !mlx::core::is_complex_v; } if (std::is_same_v) { - return std::is_same_v && std::is_same_v; + return std::is_same_v && mlx::core::is_complex_v; } if (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -88,7 +87,7 @@ constexpr bool supports_unary_op() { return std::is_same_v && is_inexact_v; } if (std::is_same_v || std::is_same_v) { - return std::is_same_v && std::is_same_v; + return mlx::core::is_complex_v && std::is_same_v; } if (std::is_same_v) { return std::is_same_v && std::is_same_v; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index cc05428a8..1c12fa4df 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -61,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { case float64: return "double"; case complex64: - return "cuComplex"; + return "complex64_t"; default: return "unknown"; } From 2ba69bc8fa61b93f802b166ca6caac3daa9d0536 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 15 Jul 2025 14:22:07 -0700 Subject: [PATCH 155/156] lower memory uniform sampling (#2361) * lower memory uniform * use fp32 * fix --- mlx/random.cpp | 77 +++++++++++++++++------------------------- tests/random_tests.cpp | 4 +-- 2 files changed, 33 insertions(+), 48 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 6c6d1eb95..def3169cb 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -92,29 +92,6 @@ T below_one() { return f; } -// Get the next representable value above -1.0 for half precision -// floating point types (fp16, bf16) -template -T above_minus_one() { - T f = T(-1.0); - uint16_t* m = (uint16_t*)&f; - *m -= 1; - return f; -} - -// Get the next representable value above -1.0 for half precision -// use std::nextafter as default 
case. -array above_minus_one_with_default(Dtype dtype) { - switch (dtype) { - case float16: - return array(above_minus_one(), dtype); - case bfloat16: - return array(above_minus_one(), dtype); - default: - return array(std::nextafter(-1.0f, 0.0f), dtype); - } -} - array uniform( const array& low, const array& high, @@ -139,31 +116,27 @@ array uniform( << " from broadcasted shape " << out_shape << "."; throw std::invalid_argument(msg.str()); } - // Get random values between [0, nextafter(maxval, 0.0f)] since samples must + + // Get random values between [0, nextafter(1.0, 0.0)] since samples must // be in [low, high) - auto get_limits = [&dtype]() { + auto get_upper = [&dtype]() { switch (dtype) { case float32: - return std::make_pair( - array(std::nextafter(1.0f, 0.0f), float32), - array(std::numeric_limits::max(), float32)); + return array(std::nextafter(1.0f, 0.0f), float32); case float16: - return std::make_pair( - array(below_one(), float16), - array(std::numeric_limits::max(), float32)); + return array(below_one(), float32); case bfloat16: - return std::make_pair( - array(below_one(), bfloat16), - array(std::numeric_limits::max(), float32)); + return array(below_one(), float32); default: throw std::runtime_error("[uniform] Unsupported type."); } }; - auto [upper, maxval] = get_limits(); - auto out = bits(shape, size_of(dtype), key, stream); - out = astype(divide(out, maxval, stream), dtype, stream); - out = minimum(out, upper, stream); + auto upper = get_upper(); + auto maxval = array(std::numeric_limits::max(), float32); + auto out = bits(shape, size_of(float32), key, stream); + out = divide(out, maxval, stream); + out = astype(minimum(out, upper, stream), dtype, stream); return add(multiply(range, out, stream), lo, stream); } @@ -183,7 +156,7 @@ inline array complex_normal( const std::optional& key, StreamOrDevice s) { auto stream = to_stream(s); - auto low = above_minus_one_with_default(float32); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); auto high = array(1.0f, float32); shape.push_back(2); auto samples = @@ -207,18 +180,23 @@ array normal( StreamOrDevice s /* = {} */) { if (dtype == complex64) { return complex_normal(shape, loc, scale, key, s); + } else if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[normal] Can only generate uniform numbers with " + "floating point type."); } auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); auto applied_scale = array(std::sqrt(2.0), dtype); if (scale.has_value()) { applied_scale = multiply(applied_scale, astype(*scale, dtype, stream), stream); } - samples = multiply(applied_scale, erfinv(samples, stream), stream); + samples = astype(erfinv(samples, stream), dtype, stream); + samples = multiply(applied_scale, samples, stream); if (loc.has_value()) { samples = add(astype(*loc, dtype, stream), samples, stream); } @@ -469,16 +447,23 @@ array laplace( const float scale /* = 1.0 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { + if (!issubdtype(dtype, floating)) { + throw std::invalid_argument( + "[laplace] Can only generate uniform numbers with real" + "floating point type."); + } + auto stream = to_stream(s); - auto low = above_minus_one_with_default(dtype); - auto high = array(1.0f, dtype); - auto 
samples = uniform(low, high, shape, dtype, key, stream); + auto low = array(std::nextafter(-1.0f, 0.0f), float32); + auto high = array(1.0f, float32); + auto samples = uniform(low, high, shape, float32, key, stream); // Use inverse CDF to generate Laplacian noise samples = multiply( sign(samples, stream), log1p( multiply(array(-1.0f, dtype), abs(samples, stream), stream), stream), stream); + samples = astype(samples, dtype, stream); if (scale != 1.0) { samples = multiply(array(scale, dtype), samples, stream); diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 49f1f300b..6ddd37104 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -350,7 +350,7 @@ TEST_CASE("test random uniform") { // Check float16 { auto key = random::key(0); - auto out = random::uniform({100}, float16, key); + auto out = random::uniform({1000}, float16, key); CHECK_EQ(out.dtype(), float16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item()); @@ -360,7 +360,7 @@ TEST_CASE("test random uniform") { { auto key = random::key(0); - auto out = random::uniform({100}, bfloat16, key); + auto out = random::uniform({1000}, bfloat16, key); CHECK_EQ(out.dtype(), bfloat16); CHECK(all(less(out, array(1.0f))).item()); CHECK(all(greater_equal(out, array(0.0f))).item()); From d7734edd9ff74c174bd36c3b373047d81d268a9a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 15 Jul 2025 18:19:47 -0700 Subject: [PATCH 156/156] fix complex reduce + nan propagation in min and max (#2377) --- mlx/backend/cuda/device/complex.cuh | 9 ++++----- mlx/backend/cuda/reduce/reduce_ops.cuh | 24 ++++++++++++++++++++++++ python/tests/cuda_skip.py | 3 --- python/tests/test_reduce.py | 4 ++-- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/device/complex.cuh b/mlx/backend/cuda/device/complex.cuh index 8dfd23b46..03a7bff83 100644 --- a/mlx/backend/cuda/device/complex.cuh +++ b/mlx/backend/cuda/device/complex.cuh @@ -38,14 +38,13 @@ inline __host__ __device__ complex_t operator%( } template -inline __host__ __device__ bool operator<(complex_t a, complex_t b) { - return (a.real() * a.real() + a.imag() * a.imag()) < - (b.real() * b.real() + b.imag() * b.imag()); +inline __host__ __device__ bool operator>(complex_t a, complex_t b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); } template -inline __host__ __device__ bool operator>(complex_t a, complex_t b) { - return b < a; +inline __host__ __device__ bool operator<(complex_t a, complex_t b) { + return operator>(b, a); } template diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 31ba90433..7f8cad0c4 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -69,6 +69,18 @@ struct Prod { struct Min { template __device__ __forceinline__ T operator()(T a, T b) { + if constexpr (is_complex_v) { + if (isnan(a.real()) || isnan(a.imag())) { + return a; + } + if (isnan(b.real()) || isnan(b.imag())) { + return b; + } + } else if constexpr (!cuda::std::is_integral_v) { + if (isnan(a) || isnan(b)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + } return a < b ? 
a : b; } @@ -81,6 +93,18 @@ struct Min { struct Max { template __device__ __forceinline__ T operator()(T a, T b) { + if constexpr (is_complex_v) { + if (isnan(a.real()) || isnan(a.imag())) { + return a; + } + if (isnan(b.real()) || isnan(b.imag())) { + return b; + } + } else if constexpr (!cuda::std::is_integral_v) { + if (isnan(a) || isnan(b)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + } return a > b ? a : b; } diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 7c9ff84ce..50cb8dcbe 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -2,9 +2,6 @@ cuda_skip = { "TestLoad.test_load_f8_e4m3", "TestLayers.test_quantized_embedding", "TestOps.test_dynamic_slicing", - "TestReduce.test_dtypes", - "TestReduce.test_nanpropagation", - "TestReduce.test_nanpropagation_complex64", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 9efd6c5c7..d6ddf353b 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -153,7 +153,7 @@ class TestReduce(mlx_tests.MLXTestCase): x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9) check(x, (1, 3, 5, 7, 9)) - def test_nanpropagation(self): + def test_nan_propagation(self): dtypes = [ "uint8", "uint16", @@ -179,7 +179,7 @@ class TestReduce(mlx_tests.MLXTestCase): ref = getattr(np, op)(x_np, axis=axis) self.assertTrue(np.array_equal(out, ref, equal_nan=True)) - def test_nanpropagation_complex64(self): + def test_nan_propagation_complex64(self): complex_array_1 = mx.array( [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64 ).reshape(2, 2)
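
For reference, a minimal sketch of the behavior the renamed tests above exercise (illustrative only, not part of the diff; it assumes the usual import mlx.core as mx and an array built the same way as in test_nan_propagation_complex64):

    import mlx.core as mx

    x = mx.array([1.0, mx.nan, 3.0])
    mx.max(x)  # expected: nan, since min/max reductions now propagate NaN (matching NumPy)
    mx.min(x)  # expected: nan

    # For complex64, comparison is now lexicographic on (real, imag) per the
    # updated operator> in complex.cuh, and a NaN-containing element is
    # returned as-is by the reduction.
    z = mx.array([1 + 1j, mx.nan + 4j], dtype=mx.complex64)
    mx.max(z)  # expected: a complex value containing nan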