From a3a632d567912e369eaccd9690231deff40973a9 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 1 May 2025 12:56:09 -0700 Subject: [PATCH 01/37] Fix the launcher when ran locally (#2147) --- python/mlx/distributed_run.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 9c946005b..404ecc349 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): # Repeat the stdout and stderr to the local machine to_read = [p.stdout.fileno(), p.stderr.fileno()] - to_write = [p.stdin.fileno()] + to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] pidfile = "" stdin_buffer = b"" + stdout_buffer = b"" + stderr_buffer = b"" while p.poll() is None: try: stdin_buffer += input_queue.get_nowait() @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): pass rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: - is_stdout = fd == p.stdout.fileno() - outfile = sys.stdout if is_stdout else sys.stderr msg = os.read(fd, 8192).decode(errors="ignore") # Fetch the PID file first if we haven't already @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): pidfile, *msg = msg.split("\n", maxsplit=1) msg = msg[0] if msg else "" - outfile.write(msg) - outfile.flush() + is_stdout = fd == p.stdout.fileno() + if is_stdout: + stdout_buffer += msg.encode() + else: + stderr_buffer += msg.encode() for fd in wlist: - if len(stdin_buffer) > 0: + if fd == p.stdin.fileno() and len(stdin_buffer) > 0: n = os.write(fd, stdin_buffer) stdin_buffer = stdin_buffer[n:] + elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: + n = os.write(fd, stdout_buffer) + stdout_buffer = stdout_buffer[n:] + elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: + n = os.write(fd, stderr_buffer) + stderr_buffer = stderr_buffer[n:] if stop: p.terminate() break From 9daa6b003f548569e9186502d5acb9b74b91bcbe Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 1 May 2025 15:02:02 -0700 Subject: [PATCH 02/37] fix shapeless export (#2148) --- mlx/export.cpp | 3 +++ python/tests/test_export_import.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mlx/export.cpp b/mlx/export.cpp index effc7a0c1..c9139e156 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -470,6 +470,9 @@ bool FunctionTable::match( if (x.dtype() != y.dtype()) { return false; } + if (x.ndim() != y.ndim()) { + return false; + } if (!shapeless && x.shape() != y.shape()) { return false; } diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 2b4b425ca..0190827bd 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -242,6 +242,7 @@ class TestExportImport(mlx_tests.MLXTestCase): def test_leaks(self): path = os.path.join(self.test_dir, "fn.mlxfn") + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: @@ -267,6 +268,24 @@ class TestExportImport(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_export_import_shapeless(self): + path = os.path.join(self.test_dir, "fn.mlxfn") + + def fun(*args): + return sum(args) + + with mx.exporter(path, fun, shapeless=True) as exporter: + exporter(mx.array(1)) + exporter(mx.array(1), mx.array(2)) + exporter(mx.array(1), mx.array(2), mx.array(3)) + + f2 = mx.import_function(path) + self.assertEqual(f2(mx.array(1))[0].item(), 1) + 
self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2) + self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3) + with self.assertRaises(ValueError): + f2(mx.array(10), mx.array([5, 10, 20])) + if __name__ == "__main__": unittest.main() From 481349495b8c3d094eb699e678077bbe1406392d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 18 Feb 2025 13:43:09 -0800 Subject: [PATCH 03/37] GPU Hadamard for large N (#1879) --- mlx/backend/common/hadamard.h | 6 +- mlx/backend/metal/hadamard.cpp | 240 ++++++++++++++------------- mlx/backend/metal/kernels/hadamard.h | 41 +++-- mlx/ops.cpp | 13 +- python/tests/test_ops.py | 26 ++- 5 files changed, 198 insertions(+), 128 deletions(-) diff --git a/mlx/backend/common/hadamard.h b/mlx/backend/common/hadamard.h index a8fed76b0..ba5c4e41e 100644 --- a/mlx/backend/common/hadamard.h +++ b/mlx/backend/common/hadamard.h @@ -99,7 +99,11 @@ inline std::pair decompose_hadamard(int n) { "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); } } + if (n > (1 << 26)) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where k <= 26"); + } return {n, m}; } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index a7dfc5f17..89b970fce 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -1,9 +1,7 @@ // Copyright © 2024 Apple Inc. -#include - -#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/hadamard.h" +#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -15,7 +13,6 @@ namespace mlx::core { constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256; -constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB std::string gen_hadamard_codelet(int m) { // Generate a O(m^2) hadamard codelet for a given M @@ -60,121 +57,142 @@ std::string gen_hadamard_codelet(int m) { return source.str(); } -void Hadamard::eval_gpu(const std::vector& inputs, array& out) { - auto& s = stream(); +void hadamard_mn_contiguous( + const array& x, + array& y, + int m, + int n1, + int n2, + float scale, + metal::Device& d, + const Stream& s) { + int n = n1 * n2; + int read_width_n1 = n1 == 2 ? 2 : 4; + int read_width_n2 = n2 == 2 ? 2 : 4; + int read_width_m = (n == 2 || m == 28) ? 2 : 4; + int max_radix_1 = std::min(n1, 16); + int max_radix_2 = std::min(n2, 16); + float scale_n1 = 1.0; + float scale_n2 = (m == 1) ? 
scale : 1.0; + float scale_m = scale; - auto& in = inputs[0]; + // n2 is a row contiguous power of 2 hadamard transform + MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1); + MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1); - std::vector copies; - // Only support the last axis for now - int axis = in.ndim() - 1; - auto check_input = [&copies, &s](const array& x) { - // TODO(alexbarron) pass strides to kernel to relax this constraint - bool no_copy = x.flags().row_contiguous; - if (no_copy) { - return x; - } else { - copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); - copy_gpu(x, copies.back(), CopyType::General, s); - return copies.back(); + // n1 is a strided power of 2 hadamard transform with stride n2 + MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1); + MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2); + + // m is a strided hadamard transform with stride n = n1 * n2 + MTL::Size group_dims_m( + std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1); + MTL::Size grid_dims_m( + group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1); + + // Make the kernel + std::string kname; + kname.reserve(32); + concatenate(kname, "hadamard_", n * m, "_", type_to_name(x)); + auto lib = d.get_library(kname, [&]() { + std::string kernel; + concatenate( + kernel, + metal::utils(), + gen_hadamard_codelet(m), + metal::hadamard(), + get_template_definition( + "n2" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n2, + max_radix_2, + read_width_n2)); + if (n1 > 1) { + kernel += get_template_definition( + "n1" + kname, + "hadamard_n", + get_type_string(x.dtype()), + n1, + max_radix_1, + read_width_n1, + n2); } - }; - const array& in_contiguous = check_input(in); - - if (in_contiguous.is_donatable()) { - out.copy_shared_buffer(in_contiguous); - } else { - out.set_data(allocator::malloc(out.nbytes())); - } - - int n, m; - std::tie(n, m) = decompose_hadamard(in.shape(axis)); - - if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) { - throw std::invalid_argument( - "[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI"); - } - - int max_radix = std::min(n, 16); - // Use read_width 2 for m = 28 to avoid register spilling - int read_width = (n == 2 || m == 28) ? 
2 : 4; - - std::ostringstream kname; - kname << "hadamard_" << n * m << "_" << type_to_name(out); - auto kernel_name = kname.str(); - auto& d = metal::device(s.device); - const auto& lib_name = kernel_name; - auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - auto codelet = gen_hadamard_codelet(m); - kernel_source << metal::utils() << codelet << metal::hadamard(); - kernel_source << get_template_definition( - "n" + kernel_name, - "hadamard_n", - get_type_string(in.dtype()), - n, - max_radix, - read_width); - kernel_source << get_template_definition( - "m" + kernel_name, - "hadamard_m", - get_type_string(in.dtype()), - n, - m, - read_width); - return kernel_source.str(); + if (m > 1) { + kernel += get_template_definition( + "m" + kname, + "hadamard_m", + get_type_string(x.dtype()), + n, + m, + read_width_m); + } + return kernel; }); - int batch_size = in.size() / n; - int threads_per = n / max_radix; - - auto& compute_encoder = d.get_command_encoder(s.index); - - auto launch_hadamard = [&](const array& in, - array& out, - const std::string& kernel_name, - float scale) { - auto kernel = d.get_kernel(kernel_name, lib); - assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); - + // Launch the strided transform for n1 + if (n1 > 1) { + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n1" + kname, lib); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder.set_bytes(scale, 2); - - MTL::Size group_dims = MTL::Size(1, threads_per, 1); - MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder.dispatch_threads(grid_dims, group_dims); - }; - - if (m > 1) { - // When m is greater than 1, we decompose the - // computation into two uploads to the GPU: - // - // e.g. len(x) = 12*4 = 48, m = 12, n = 4 - // - // y = h48 @ x - // - // Upload 1: - // tmp = a.reshape(12, 4) @ h4 - // - // Upload 2: - // y = h12 @ tmp - array temp(in.shape(), in.dtype(), nullptr, {}); - temp.set_data(allocator::malloc(temp.nbytes())); - copies.push_back(temp); - - launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0); - - // Metal sometimes reports 256 max threads per group for hadamard_m kernel - threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP); - batch_size = in.size() / m / read_width / threads_per; - launch_hadamard(temp, out, "m" + kernel_name, scale_); - } else { - launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n1, 2); + compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1); } - d.add_temporaries(std::move(copies), s.index); + // Launch the transform for n2 + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel("n2" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(n1 > 1 ? 
y : x, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_n2, 2); + compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2); + + // Launch the strided transform for m + if (m > 1) { + auto kernel = d.get_kernel("m" + kname, lib); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(y, 0); + compute_encoder.set_output_array(y, 1); + compute_encoder.set_bytes(scale_m, 2); + compute_encoder.dispatch_threads(grid_dims_m, group_dims_m); + } +} + +void Hadamard::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + auto& in = inputs[0]; + + // Split the hadamard transform so that all of them work on vectors smaller + // than 8192 elements. + // + // We decompose it in the following way: + // + // n = m * n1 * n2 = m * 2^k1 * 2^k2 + // + // where m is in (1, 12, 20, 28) and n1 and n2 <= 8192 + auto [n, m] = decompose_hadamard(in.shape().back()); + int n1 = 1, n2 = n; + if (n > 8192) { + for (n2 = 2; n2 * n2 < n; n2 *= 2) { + } + n1 = n / n2; + } + + if (in.flags().row_contiguous) { + if (in.is_donatable()) { + out.copy_shared_buffer(in); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s); + } else { + copy_gpu(in, out, CopyType::General, s); + hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/hadamard.h b/mlx/backend/metal/kernels/hadamard.h index 93e2fb8a8..9f2311c10 100644 --- a/mlx/backend/metal/kernels/hadamard.h +++ b/mlx/backend/metal/kernels/hadamard.h @@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) { } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -46,18 +46,25 @@ template constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; // Read values from device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } @@ -113,12 +120,20 @@ template } // Write values to device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/mlx/ops.cpp 
b/mlx/ops.cpp index e7abe12db..4aa5e88b7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -473,8 +473,19 @@ array hadamard_transform( std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) - float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1)); + int n = a.ndim() > 0 ? a.shape(-1) : 1; + float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n); auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32; + + // Nothing to do for a scalar + if (n == 1) { + if (scale == 1) { + return a; + } + + return multiply(a, array(scale, dtype), s); + } + return array( a.shape(), dtype, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d840eac7d..d9e143d82 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2868,11 +2868,33 @@ class TestOps(mlx_tests.MLXTestCase): h28 = parse_h_string(h28_str) + x = mx.array(5) + y = mx.hadamard_transform(x) + self.assertEqual(y.item(), 5) + + x = mx.array(5) + y = mx.hadamard_transform(x, scale=0.2) + self.assertEqual(y.item(), 1) + + x = mx.random.normal((8, 8, 1)) + y = mx.hadamard_transform(x) + self.assertTrue(mx.all(y == x).item()) + + # Too slow to compare to numpy so let's compare CPU to GPU + if mx.default_device() == mx.gpu: + rk = mx.random.key(42) + for k in range(14, 17): + for m in [1, 3, 5, 7]: + x = mx.random.normal((4, m * 2**k), key=rk) + y1 = mx.hadamard_transform(x, stream=mx.cpu) + y2 = mx.hadamard_transform(x, stream=mx.gpu) + self.assertLess(mx.abs(y1 - y2).max().item(), 5e-6) + np.random.seed(7) - tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 15)) + tests = product([np.float32, np.float16, np.int32], [1, 28], range(1, 14)) for dtype, m, k in tests: # skip large m=28 cases because they're very slow in NumPy - if (m > 1 and k > 8) or (dtype != np.float16 and k == 14): + if m > 1 and k > 8: continue with self.subTest(dtype=dtype, m=m, k=k): n = m * 2**k From 9c5e7da5079cf98f48df150c8bed5c3c0043d22c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 2 May 2025 15:08:50 -0700 Subject: [PATCH 04/37] fix compile merging (#2150) --- mlx/compile.cpp | 9 +++++++++ tests/compile_tests.cpp | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7ff5c8f9e..2baeb6fcf 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) { parent.first.inputs()[parent.second] = dst; pairs.push_back(parent); } + + // If src is a parent of dst, remove it from dst's parents + for (auto it = pairs.begin(); it != pairs.end();) { + if (it->first.id() == src.id()) { + it = pairs.erase(it); + } else { + it++; + } + } // Remove the source from the map to avoid fusing with it again parents_map.erase(src_parents); } diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 66511682d..96552ef9d 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") { out = cfun2({array(0)}); CHECK_EQ(out[0].item(), 3); } + +TEST_CASE("test compile with no-ops") { + auto fun = [](const std::vector& inputs) { + return std::vector{abs(stop_gradient(abs(inputs[0])))}; + }; + auto in = array(1.0); + auto out = compile(fun)({in})[0]; + CHECK_EQ(out.inputs()[0].id(), in.id()); +} From 825124af8ffd32d0f2f7d8f8eca83c8c3eb510a7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 5 May 2025 06:15:04 -0700 Subject: [PATCH 05/37] fix 
bw for elementwise ops (#2151) * fix bw for elementwise ops * add compile * fix * fix * fix * fix --- mlx/backend/metal/binary.cpp | 15 ++++-- mlx/backend/metal/compiled.cpp | 38 +++++++++---- mlx/backend/metal/copy.cpp | 27 +++++++--- mlx/backend/metal/kernels/binary.h | 51 ++++++++++++------ mlx/backend/metal/kernels/binary_two.h | 75 ++++++++++++++++---------- mlx/backend/metal/kernels/copy.h | 34 ++++++++---- mlx/backend/metal/kernels/ternary.h | 17 ++++-- mlx/backend/metal/kernels/unary.h | 17 ++++-- mlx/backend/metal/kernels/utils.h | 8 +++ mlx/backend/metal/ternary.cpp | 14 +++-- mlx/backend/metal/unary.cpp | 17 ++++-- mlx/backend/metal/utils.h | 8 +++ 12 files changed, 232 insertions(+), 89 deletions(-) diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index f80f8c3e4..c3c67e4d5 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -90,7 +90,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(a.dtype()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); @@ -137,13 +137,20 @@ void binary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), arg_idx++); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 154273233..db20f938c 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -64,6 +64,7 @@ inline void build_kernel( cnt++); } + std::string idx_type = use_big_index ? "int64_t" : "uint"; if (add_indices) { os += fmt::format( " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); @@ -83,6 +84,9 @@ inline void build_kernel( " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); + } else { + os += fmt::format( + " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++); } if (dynamic_dims) { os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); @@ -92,13 +96,14 @@ inline void build_kernel( os += " uint3 pos [[thread_position_in_grid]],\n"; os += " uint3 grid [[threads_per_grid]]) {\n"; - std::string idx_type = use_big_index ? 
"int64_t" : "uint"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; + os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; + } else if (contiguous) { + os += " uint index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); @@ -110,6 +115,9 @@ inline void build_kernel( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } + if (work_per_thread > 1 && contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } // Read constant / contiguous inputs in tmps std::vector nc_inputs; @@ -193,7 +201,7 @@ inline void build_kernel( } // Open per-thread loop - if (work_per_thread > 1) { + if (work_per_thread > 1 && !contiguous) { os += " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } @@ -272,6 +280,7 @@ void Compiled::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); std::string kernel = metal::utils(); concatenate( kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); @@ -284,7 +293,9 @@ void Compiled::eval_gpu( constant_ids_, /* contiguous = */ true, /* ndim = */ 0, - /* dynamic_dims = */ false); + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -295,7 +306,8 @@ void Compiled::eval_gpu( /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, - /* use_big_index = */ true); + /* use_big_index = */ true, + /* work_per_thread = */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -468,6 +480,13 @@ void Compiled::eval_gpu( if (!contiguous) { compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); + } else { + auto size = outputs[0].data_size(); + if (large) { + compute_encoder.set_bytes(size, cnt++); + } else { + compute_encoder.set_bytes(size, cnt++); + } } // Put the number of dims in if it is dynamic @@ -477,12 +496,13 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - size_t nthreads = outputs[0].data_size(); + int work_per_thread = get_work_per_thread(outputs[0].dtype()); + size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = large - ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) + ? get_2d_grid_dims( + outputs[0].shape(), outputs[0].strides(), work_per_thread) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 3399201de..ee004359f 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -104,6 +104,8 @@ void copy_gpu_inplace( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } + } else { + work_per_thread = get_work_per_thread(in.dtype()); } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? 
get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -165,13 +167,19 @@ void copy_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -214,14 +222,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); + int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 91a02c818..ffc33ad82 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -9,64 +9,85 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } } -template 
+template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 8f6b3392d..e261d33c4 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -12,82 +12,103 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < 
size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } template diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index b1367cf4f..2469d1f3d 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,39 +1,53 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 4b3adcc80..5251dc7e9 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -1,25 +1,32 @@ // Copyright © 2024 Apple Inc. 
-template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 69828599f..b5eaab2e9 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,21 +1,28 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + out[index + i] = Op()(in[index + i]); + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } template < diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 1170d5576..c30d186b8 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -15,6 +15,14 @@ typedef half float16_t; +// Work per thread values for different types. The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 36bfd3e2b..0b821151e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(b.dtype()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -106,13 +106,19 @@ void ternary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? 
get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 4); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index be43c41c2..368e693a9 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -34,18 +34,19 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - size_t nthreads = contig ? in.data_size() : in.size(); bool large; if (!contig) { large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; } else { large = in.data_size() > UINT32_MAX; } - int work_per_thread = !contig && large ? 4 : 1; + int work_per_thread; std::string kernel_name; if (contig) { + work_per_thread = get_work_per_thread(in.dtype()); kernel_name = (large ? "v2" : "v"); } else { + work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); if (large) { kernel_name += "large"; @@ -75,12 +76,20 @@ void unary_op_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { + size_t nthreads = ceildiv(in.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 079d15f17..f9245a6d6 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) { concatenate(acc, args...); } +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} + +inline size_t ceildiv(size_t n, size_t m) { + return (n + m - 1) / m; +} + } // namespace mlx::core From af705590ac9335105a5a026de4fc68ee6e747a9d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 5 May 2025 13:13:03 -0700 Subject: [PATCH 06/37] fix batched vector sdpa (#2152) --- mlx/backend/metal/kernels/sdpa_vector.h | 12 +- .../metal/scaled_dot_product_attention.cpp | 103 ++++++++++-------- python/tests/test_fast_sdpa.py | 40 +++++++ 3 files changed, 105 insertions(+), 50 deletions(-) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index c4c0f6456..8258e9c14 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -56,9 +56,9 @@ template const int head_idx = tid.x; const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? 
tpg.x * q_seq_idx + head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; @@ -213,9 +213,9 @@ template const int block_idx = tid.z; const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; const int kv_head_idx = head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; @@ -358,8 +358,8 @@ template // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int n_heads = tpg.x; - const int q_offset = n_heads * q_seq_idx + head_idx; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 845962d01..d75e6d87d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -154,9 +154,9 @@ void sdpa_vector( int gqa_factor = q.shape(1) / k.shape(1); int N = k.shape(2); int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(1024, 1, 1); @@ -199,11 +199,10 @@ void sdpa_vector( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 11 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 13); compute_encoder.set_bytes(q_seq_stride, 14); compute_encoder.set_bytes(head_stride, 15); @@ -238,9 +237,10 @@ void sdpa_vector_2pass( int N = k.shape(2); int blocks = 32; int B = q.shape(0) * q.shape(1); - size_t k_head_stride = k.strides()[1]; + + size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; - size_t v_head_stride = v.strides()[1]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; MTL::Size group_dims(8 * 32, 1, 1); MTL::Size grid_dims(B, q.shape(2), blocks); @@ -302,11 +302,10 @@ void sdpa_vector_2pass( if (has_mask) { auto& m = *mask; compute_encoder.set_input_array(m, 13 + float_mask); - auto nd = m.ndim(); - int32_t kv_seq_stride = - nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0; - int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0; - int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0; + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? 
m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); compute_encoder.set_bytes(kv_seq_stride, 15); compute_encoder.set_bytes(q_seq_stride, 16); compute_encoder.set_bytes(head_stride, 17); @@ -368,18 +367,6 @@ void ScaledDotProductAttention::eval_gpu( } }; - // Checks if arr is row contiguous or the sequence and head dimension are - // transposed - auto is_contiguous_or_head_seq_transposed = [](const array& arr) { - if (arr.flags().row_contiguous) { - return true; - } - auto& strides = arr.strides(); - auto& shape = arr.shape(); - return (strides[3] == 1) && (strides[2] == shape[3] * shape[1]) && - (strides[1] == shape[3]) && (strides[0] == strides[2] * shape[2]); - }; - // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { return arr.strides(-1) == 1; @@ -387,30 +374,58 @@ void ScaledDotProductAttention::eval_gpu( // We are in vector mode ie single query if (q_pre.shape(2) <= 8) { - const auto& q = copy_unless(is_contiguous_or_head_seq_transposed, q_pre); - const auto& k = copy_unless(is_matrix_contiguous, k_pre); - const auto& v = copy_unless(is_matrix_contiguous, v_pre); + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + // If either the batch or head dimension is a singleton, the other can + // be transposed with the sequence dimension + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + // keys and values should be copied if: + // - the last dimension is not contiguous + // - the batch and head dim are not contiguous + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& v = copy_unless(kv_copy_unless, v_pre); // Donate the query if possible - if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) && - q.size() == o.size()) { + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { o.copy_shared_buffer(q); } else { - if (o.shape(2) == 1) { - o.set_data(allocator::malloc(o.nbytes())); - } else { - auto strides = o.strides(); - strides[2] = o.shape(1) * o.shape(3); - strides[1] = o.shape(3); - auto flags = q.flags(); - flags.row_contiguous = q.shape(1) == 1; - o.set_data( - allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags); - } + o.set_data(allocator::malloc(o.nbytes())); } - auto mask = - inputs.size() > 3 ? std::optional{inputs[3]} : std::nullopt; + auto mask_copy_unless = [&q](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return arr.flags().row_contiguous || q.shape(0) == 1 || q.shape(1) == 1 || + (strides[0] == strides[1] * shape[1]); + }; + + auto mask = inputs.size() > 3 + ? 
std::optional{copy_unless(mask_copy_unless, inputs[3])} + : std::nullopt; // We route to the 2 pass fused attention if // - The device is large and the sequence length long diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index d35a2b1da..8f55d41e3 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -473,6 +473,46 @@ class TestFastSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + def test_sdpa_vector_batched(self): + D = 64 + q = mx.random.normal(shape=(2, 1, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2) + v = mx.random.normal(shape=(2, 2, 3, D)) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0) + ref = mlx_ref_attn(q, k, v) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + + q = mx.random.normal(shape=(2, 4, 3, D)) + k = mx.random.normal(shape=(2, 1, 3, D)) + v = mx.random.normal(shape=(2, 1, 3, D)) + mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1) + out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0) + ref = mlx_ref_attn(q, k, v, mask=mask) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + class TestSDPA(mlx_tests.MLXTestCase): @property From 1683975acf2f007ba94a0a53241149474f0c070b Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 6 May 2025 05:45:29 +0900 Subject: [PATCH 07/37] Move common gpu primitives to backend/gpu (#2145) --- mlx/CMakeLists.txt | 1 + mlx/backend/gpu/CMakeLists.txt | 5 + mlx/backend/gpu/copy.cpp | 49 ++++ mlx/backend/{metal => gpu}/copy.h | 2 + mlx/backend/gpu/primitives.cpp | 217 ++++++++++++++++++ mlx/backend/gpu/slicing.cpp | 44 ++++ mlx/backend/{metal => gpu}/slicing.h | 0 mlx/backend/metal/conv.cpp | 2 +- mlx/backend/metal/copy.cpp | 44 +--- mlx/backend/metal/custom_kernel.cpp | 2 +- mlx/backend/metal/distributed.cpp | 2 +- mlx/backend/metal/fft.cpp | 4 +- mlx/backend/metal/hadamard.cpp | 2 +- mlx/backend/metal/indexing.cpp | 2 +- mlx/backend/metal/logsumexp.cpp | 2 +- mlx/backend/metal/matmul.cpp | 2 +- mlx/backend/metal/normalization.cpp | 2 +- mlx/backend/metal/primitives.cpp | 182 +-------------- mlx/backend/metal/quantized.cpp | 2 +- mlx/backend/metal/reduce.cpp | 2 +- mlx/backend/metal/rope.cpp | 2 +- .../metal/scaled_dot_product_attention.cpp | 2 +- mlx/backend/metal/scan.cpp | 2 +- mlx/backend/metal/slicing.cpp | 39 +--- mlx/backend/metal/softmax.cpp | 2 +- mlx/backend/metal/sort.cpp | 2 +- 26 files changed, 340 
insertions(+), 277 deletions(-) create mode 100644 mlx/backend/gpu/CMakeLists.txt create mode 100644 mlx/backend/gpu/copy.cpp rename mlx/backend/{metal => gpu}/copy.h (98%) create mode 100644 mlx/backend/gpu/primitives.cpp create mode 100644 mlx/backend/gpu/slicing.cpp rename mlx/backend/{metal => gpu}/slicing.h (100%) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 465954d6f..00898e73e 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,6 +47,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt new file mode 100644 index 000000000..0396ae03a --- /dev/null +++ b/mlx/backend/gpu/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp new file mode 100644 index 000000000..6127ac921 --- /dev/null +++ b/mlx/backend/gpu/copy.cpp @@ -0,0 +1,49 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu(const array& in, array& out, CopyType ctype) { + copy_gpu(in, out, ctype, out.primitive().stream()); +} + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/gpu/copy.h similarity index 98% rename from mlx/backend/metal/copy.h rename to mlx/backend/gpu/copy.h index 37c60df42..020f579e4 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/gpu/copy.h @@ -5,6 +5,8 @@ #include "mlx/backend/common/copy.h" #include "mlx/stream.h" +#include + namespace mlx::core { // Generic copy inplace diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp new file mode 100644 index 000000000..cd9296075 --- /dev/null +++ b/mlx/backend/gpu/primitives.cpp @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/primitives.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +#include + +#define MLX_PROFILER_RANGE(message) + +namespace mlx::core { + +namespace { + +void reshape(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +} // namespace + +void AsStrided::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsStrided::eval_gpu"); + eval(inputs, out); +} + +void AsType::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsType::eval_gpu"); + CopyType ctype = + inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(inputs[0], out, ctype); +} + +void Broadcast::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Broadcast::eval_gpu"); + eval(inputs, out); +} + +void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu"); + eval(inputs, out); +} + +void Concatenate::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Concatenate::eval_gpu"); + concatenate_gpu(inputs, out, axis_, stream()); +} + +void Contiguous::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Contiguous::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { + out.copy_shared_buffer(in); + } else { + copy_gpu(in, out, CopyType::General); + } +} + +void Copy::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Copy::eval_gpu"); + eval(inputs, out); +} + +void CustomTransforms::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("CustomTransforms::eval_gpu"); + eval(inputs, outputs); +} + +void Depends::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Depends::eval_gpu"); + eval(inputs, outputs); +} + +void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); + eval(inputs, out); +} + +void Full::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Full::eval_gpu"); + auto in = inputs[0]; + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy_gpu(in, out, ctype); +} + +void Flatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Flatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("NumberOfElements::eval_gpu"); + eval(inputs, out); +} + +void Pad::eval_gpu(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + 
pad_gpu(in, val, out, axes_, low_pad_size_, stream()); +} + +void Reshape::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Reshape::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void Split::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Split::eval_gpu"); + eval(inputs, outputs); +} + +void Slice::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Slice::eval_gpu"); + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + slice_gpu(in, out, start_indices_, strides_, stream()); +} + +void Squeeze::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Squeeze::eval_gpu"); + eval(inputs, out); +} + +void StopGradient::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("StopGradient::eval_gpu"); + eval(inputs, out); +} + +void Transpose::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Transpose::eval_gpu"); + eval(inputs, out); +} + +void Unflatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Unflatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void View::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("View::eval_gpu"); + auto& in = inputs[0]; + auto ibytes = size_of(in.dtype()); + auto obytes = size_of(out.dtype()); + // Conditions for buffer copying (disjunction): + // - type size is the same + // - type size is smaller and the last axis is contiguous + // - the entire array is row contiguous + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || + in.flags().row_contiguous) { + auto strides = in.strides(); + for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { + strides[i] *= ibytes; + strides[i] /= obytes; + } + out.copy_shared_buffer( + in, strides, in.flags(), in.data_size() * ibytes / obytes); + } else { + auto tmp = array(in.shape(), in.dtype(), nullptr, {}); + tmp.set_data(allocator::malloc(tmp.nbytes())); + copy_gpu_inplace(in, tmp, CopyType::General, stream()); + + auto flags = out.flags(); + flags.contiguous = true; + flags.row_contiguous = true; + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp new file mode 100644 index 000000000..fde2a01cd --- /dev/null +++ b/mlx/backend/gpu/slicing.cpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s) { + slice(in, out, start_indices, strides); +} + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s) { + // Fill output with val + fill_gpu(val, out, s); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes.size(); i++) { + auto ax = axes[i] < 0 ? 
out.ndim() + axes[i] : axes[i]; + data_offset += out.strides()[ax] * low_pad_size[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/gpu/slicing.h similarity index 100% rename from mlx/backend/metal/slicing.h rename to mlx/backend/gpu/slicing.h diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9075ea4c5..ae31a6cff 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -5,7 +5,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index ee004359f..8dfe15c11 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -1,35 +1,15 @@ // Copyright © 2023-2024 Apple Inc. -#include - +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} - -void copy_gpu(const array& in, array& out, CopyType ctype) { - copy_gpu(in, out, ctype, out.primitive().stream()); -} - void copy_gpu_inplace( const array& in, array& out, @@ -184,28 +164,6 @@ void copy_gpu_inplace( } } -void copy_gpu_inplace( - const array& in, - array& out, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); -} - -void copy_gpu_inplace( - const array& in, - array& out, - const Strides& i_strides, - int64_t i_offset, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); -} - void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 8a672289a..ea4f258cc 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,6 +1,6 @@ // Copyright © 2024 Apple Inc. 
-#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 82e8fff7d..a800d2e0f 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -4,7 +4,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 153c62c02..011eb7ebb 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -7,10 +7,10 @@ #include "mlx/3rdparty/pocketfft.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/binary.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 89b970fce..65a877151 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -3,7 +3,7 @@ #include "mlx/backend/common/hadamard.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a263051..cccfd908a 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp index 4901190e1..e53bc58d9 100644 --- a/mlx/backend/metal/logsumexp.cpp +++ b/mlx/backend/metal/logsumexp.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index f55d20c9f..71221f8d9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -7,7 +7,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c1d993d2a..21142183e 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. 
#include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6946ffb9e..860e9ddd7 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -7,10 +7,10 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - static array compute_dynamic_offset( const array& indices, const Strides& strides, @@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } -void AsType::eval_gpu(const std::vector& inputs, array& out) { - CopyType ctype = - inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; - copy_gpu(inputs[0], out, ctype); -} - -void AsStrided::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Broadcast::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Concatenate::eval_gpu(const std::vector& inputs, array& out) { - concatenate_gpu(inputs, out, axis_, stream()); -} - -void Contiguous::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; - if (in.buffer_size() <= out.nbytes() + extra_bytes && - (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous))) { - out.copy_shared_buffer(in); - } else { - copy_gpu(in, out, CopyType::General); - } -} - -void Copy::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void CustomTransforms::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Depends::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Full::eval_gpu(const std::vector& inputs, array& out) { - auto in = inputs[0]; - CopyType ctype; - if (in.data_size() == 1) { - ctype = CopyType::Scalar; - } else if (in.flags().contiguous) { - ctype = CopyType::Vector; - } else { - ctype = CopyType::General; - } - copy_gpu(in, out, ctype); -} - -void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Flatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Unflatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - void Load::eval_gpu(const std::vector& inputs, array& out) { throw 
std::runtime_error("[Load::eval_gpu] Not implemented."); } -void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Pad::eval_gpu(const std::vector& inputs, array& out) { - // Inputs must be base input array and scalar val array - assert(inputs.size() == 2); - auto& in = inputs[0]; - auto& val = inputs[1]; - - // Padding value must be a scalar - assert(val.size() == 1); - - // Padding value, input and output must be of the same type - assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); - - pad_gpu(in, val, out, axes_, low_pad_size_, stream()); -} - void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void Reshape::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Split::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Slice::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - slice_gpu(in, out, start_indices_, strides_, stream()); -} - void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { out.set_data(nullptr); @@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { /* const Stream& s = */ stream()); } -void Squeeze::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void StopGradient::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Transpose::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -537,35 +390,4 @@ void LUF::eval_gpu( throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } -void View::eval_gpu(const std::vector& inputs, array& out) { - auto& in = inputs[0]; - auto ibytes = size_of(in.dtype()); - auto obytes = size_of(out.dtype()); - // Conditions for buffer copying (disjunction): - // - type size is the same - // - type size is smaller and the last axis is contiguous - // - the entire array is row contiguous - if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || - in.flags().row_contiguous) { - auto strides = in.strides(); - for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { - strides[i] *= ibytes; - strides[i] /= obytes; - } - out.copy_shared_buffer( - in, strides, in.flags(), in.data_size() * ibytes / obytes); - } else { - auto tmp = array(in.shape(), in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc(tmp.nbytes())); - copy_gpu_inplace(in, tmp, CopyType::General, stream()); - - auto flags = out.flags(); - flags.contiguous = true; - flags.row_contiguous = true; - auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); - flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); - } -} - } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 6f5807543..11a2355cc 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -4,7 +4,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include 
"mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index c5650bdd7..8cb55ba58 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 060758333..d8201afe6 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d75e6d87d..3c7b7ff19 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/steel/attn/params.h" diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index b1800fea9..3c4051105 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 6ab08a108..3e1a8b541 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,21 +2,12 @@ #include -#include "mlx/backend/common/slicing.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" namespace mlx::core { -void slice_gpu( - const array& in, - array& out, - const Shape& start_indices, - const Shape& strides, - const Stream& s) { - slice(in, out, start_indices, strides); -} - void concatenate_gpu( const std::vector& inputs, array& out, @@ -48,30 +39,4 @@ void concatenate_gpu( } } -void pad_gpu( - const array& in, - const array& val, - array& out, - const std::vector& axes, - const Shape& low_pad_size, - const Stream& s) { - // Fill output with val - fill_gpu(val, out, s); - - // Find offset for start of input values - size_t data_offset = 0; - for (int i = 0; i < axes.size(); i++) { - auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; - data_offset += out.strides()[ax] * low_pad_size[i]; - } - - // Extract slice from output where input will be pasted - array out_slice(in.shape(), out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out, out.strides(), out.flags(), out_slice.size(), data_offset); - - // Copy input values into the slice - copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); -} - } // namespace mlx::core diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 224721a50..59662b05d 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 543dfd180..3c84022f2 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -2,7 +2,7 @@ #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" From 5a1a5d5ed16f69af7c3ce56dd94e4502661e1565 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 5 May 2025 17:30:50 -0700 Subject: [PATCH 08/37] fix input coherent kernel launch (#2153) --- mlx/backend/metal/fence.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index d4a88d983..5abdf7309 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -138,7 +138,7 @@ void Fence::update(Stream stream, const array& x) { compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x, 0); compute_encoder.set_bytes(nthreads, 1); - compute_encoder.dispatch_threadgroups(group_dims, grid_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Barrier on previous kernels compute_encoder.barrier(); From 0cae0bdac83bbf5b3d1da3ca53f1f7eb95981d30 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 7 May 2025 13:26:46 +0900 Subject: [PATCH 09/37] CUDA backend: backbone (#2075) --- CMakeLists.txt | 5 + mlx/CMakeLists.txt | 10 +- mlx/backend/cuda/CMakeLists.txt | 57 ++++++ mlx/backend/cuda/allocator.cpp | 154 ++++++++++++++ mlx/backend/cuda/allocator.h | 58 ++++++ mlx/backend/cuda/copy.cpp | 26 +++ mlx/backend/cuda/device.cpp | 117 +++++++++++ mlx/backend/cuda/device.h | 131 ++++++++++++ mlx/backend/cuda/dtype_utils.cuh | 35 ++++ mlx/backend/cuda/eval.cpp | 68 +++++++ mlx/backend/cuda/event.cu | 265 +++++++++++++++++++++++++ mlx/backend/cuda/event.h | 66 ++++++ mlx/backend/cuda/fence.cu | 70 +++++++ mlx/backend/cuda/kernels/arange.cuh | 15 ++ mlx/backend/cuda/kernels/fp16_math.cuh | 107 ++++++++++ mlx/backend/cuda/primitives.cu | 163 +++++++++++++++ mlx/backend/cuda/slicing.cpp | 15 ++ mlx/backend/cuda/utils.cpp | 26 +++ mlx/backend/cuda/utils.h | 36 ++++ mlx/backend/cuda/worker.cpp | 90 +++++++++ mlx/backend/cuda/worker.h | 68 +++++++ tests/CMakeLists.txt | 2 +- 22 files changed, 1582 insertions(+), 2 deletions(-) create mode 100644 mlx/backend/cuda/CMakeLists.txt create mode 100644 mlx/backend/cuda/allocator.cpp create mode 100644 mlx/backend/cuda/allocator.h create mode 100644 mlx/backend/cuda/copy.cpp create mode 100644 mlx/backend/cuda/device.cpp create mode 100644 mlx/backend/cuda/device.h create mode 100644 mlx/backend/cuda/dtype_utils.cuh create mode 100644 mlx/backend/cuda/eval.cpp create mode 100644 mlx/backend/cuda/event.cu create mode 100644 mlx/backend/cuda/event.h create mode 100644 mlx/backend/cuda/fence.cu create mode 100644 mlx/backend/cuda/kernels/arange.cuh create mode 100644 mlx/backend/cuda/kernels/fp16_math.cuh create mode 100644 mlx/backend/cuda/primitives.cu create mode 100644 mlx/backend/cuda/slicing.cpp create mode 100644 mlx/backend/cuda/utils.cpp create mode 100644 mlx/backend/cuda/utils.h create mode 100644 mlx/backend/cuda/worker.cpp create mode 100644 mlx/backend/cuda/worker.h diff --git a/CMakeLists.txt b/CMakeLists.txt index e2002fc94..ab8aea443 100644 --- 
a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) +option(MLX_BUILD_CUDA "Build cuda backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -83,6 +84,10 @@ if(MLX_BUILD_METAL) set(QUARTZ_LIB "-framework QuartzCore") endif() +if(MLX_BUILD_CUDA) + enable_language(CUDA) +endif() + if(MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 00898e73e..4ba9b33dd 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,10 +47,18 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp) +endif() + +if(MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda) +endif() + +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) +else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) endif() diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt new file mode 100644 index 000000000..54d651005 --- /dev/null +++ b/mlx/backend/cuda/CMakeLists.txt @@ -0,0 +1,57 @@ +# Filename rules in cuda backend: +# +# * Use .cu/.cuh if code contains device code, and .cpp/.h if not. +# * Device-only kernel code should be put in kernels/ subdir. +# * Files in kernels/ subdir should not include files outside. +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/event.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) + +target_compile_definitions(mlx PUBLIC MLX_USE_CUDA) + +# Enable defining device lambda functions. +target_compile_options(mlx + PRIVATE "$<$:--extended-lambda>") + +# Compute capability 7 is required for synchronization between CPU/GPU with +# managed memory. TODO: Add more architectures for potential performance gain. +set(MLX_CUDA_ARCHITECTURES + "75;80" + CACHE STRING "CUDA architectures") +message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") +set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES + "${MLX_CUDA_ARCHITECTURES}") + +# Use fixed version of CCCL. +FetchContent_Declare( + cccl + URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") +FetchContent_MakeAvailable(cccl) +target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include") + +# Use fixed version of NVTX. +FetchContent_Declare( + nvtx3 + GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git + GIT_TAG v3.1.1 + GIT_SHALLOW TRUE + SOURCE_SUBDIR c EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(nvtx3) +target_link_libraries(mlx PUBLIC $) + +# Make cuda runtime APIs available in non-cuda files. 
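+# For example, allocator.cpp and utils.cpp are compiled as plain C++ but call
+# runtime APIs such as cudaMallocManaged and cudaStreamCreateWithFlags, so they
+# need the toolkit headers on their include path.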
+find_package(CUDAToolkit REQUIRED) +target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + +# Suppress nvcc warnings on MLX headers. +target_compile_options(mlx PRIVATE $<$:-Xcudafe + --diag_suppress=997>) diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp new file mode 100644 index 000000000..203534e21 --- /dev/null +++ b/mlx/backend/cuda/allocator.cpp @@ -0,0 +1,154 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/worker.h" + +#include +#include + +#include + +namespace mlx::core { + +namespace cu { + +CudaAllocator::CudaAllocator() { + // TODO: Set memory limit for multi-device. + size_t free, total; + CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); + memory_limit_ = total * 0.8; +} + +Buffer CudaAllocator::malloc(size_t size) { + // TODO: Check memory limit. + auto* buf = new CudaBuffer{nullptr, size}; + cudaError_t err = cudaMallocManaged(&buf->data, size); + if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { + throw std::runtime_error( + fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err))); + } + std::lock_guard lock(mutex_); + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; +} + +void CudaAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + // If free() is called from a unregistered thread, reschedule the call to + // worker. + { + std::lock_guard lock(worker_mutex_); + if (allowed_threads_.count(std::this_thread::get_id()) == 0) { + if (!worker_) { + worker_.reset(new Worker); + } + worker_->add_task([buffer]() { allocator().free(buffer); }); + worker_->end_batch(); + worker_->commit(); + return; + } + } + + size_t size = buf->size; + cudaFree(buf->data); + delete buf; + std::lock_guard lock(mutex_); + active_memory_ -= size; +} + +size_t CudaAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void CudaAllocator::register_this_thread() { + std::lock_guard lock(worker_mutex_); + allowed_threads_.insert(std::this_thread::get_id()); +} + +size_t CudaAllocator::get_active_memory() const { + return active_memory_; +} + +size_t CudaAllocator::get_peak_memory() const { + return peak_memory_; +} + +void CudaAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t CudaAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t CudaAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +CudaAllocator& allocator() { + // By creating the |allocator_| on heap, the destructor of CudaAllocator + // will not be called on exit and buffers in the cache will be leaked. This + // can save some time at program exit. 
+ static CudaAllocator* allocator_ = new CudaAllocator; + return *allocator_; +} + +} // namespace cu + +namespace allocator { + +Allocator& allocator() { + return cu::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + return static_cast(ptr_)->data; +} + +} // namespace allocator + +size_t get_active_memory() { + return cu::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return cu::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return cu::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return cu::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return cu::allocator().get_memory_limit(); +} + +// TODO: Implement buffer cache. +size_t get_cache_memory() { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; +} +size_t set_wired_limit(size_t) { + return 0; +} +void clear_cache() {} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h new file mode 100644 index 000000000..6c418ee7e --- /dev/null +++ b/mlx/backend/cuda/allocator.h @@ -0,0 +1,58 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class Worker; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // Register current thread as safe to free buffers. + // In cuda freeing a buffer implicitly synchronizes stream, and for threads + // that may be waited by gpu stream (for example cpu stream threads), freeing + // buffers there would result in dead lock. + void register_this_thread(); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + + private: + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex worker_mutex_; + std::unique_ptr worker_; + std::set allowed_threads_; + + std::mutex mutex_; + size_t memory_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; +}; + +CudaAllocator& allocator(); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/copy.cpp b/mlx/backend/cuda/copy.cpp new file mode 100644 index 000000000..d0413d989 --- /dev/null +++ b/mlx/backend/cuda/copy.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& data_shape, + const Strides& strides_in_pre, + const Strides& strides_out_pre, + int64_t inp_offset, + int64_t out_offset, + CopyType ctype, + const Stream& s, + const std::optional& dynamic_i_offset /* = std::nullopt */, + const std::optional& dynamic_o_offset /* = std::nullopt */) { + throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); +} + +void fill_gpu(const array& val, array& out, const Stream& s) { + throw std::runtime_error("fill_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp new file mode 100644 index 000000000..a28ffa35e --- /dev/null +++ b/mlx/backend/cuda/device.cpp @@ -0,0 +1,117 @@ +// Copyright © 2025 Apple Inc. 
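+// Per-device state for the CUDA backend: Device, DeviceStream and
+// CommandEncoder, plus the lookup helpers device()/get_stream().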
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/metal/metal.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} + +void DeviceStream::synchronize() { + cudaStreamSynchronize(stream_); +} + +cudaStream_t DeviceStream::schedule_cuda_stream() { + // TODO: Return a stream that maximizes parallelism. + return stream_; +} + +cudaStream_t DeviceStream::last_cuda_stream() { + return stream_; +} + +CommandEncoder& DeviceStream::get_encoder() { + if (!encoder_) { + encoder_ = std::make_unique(*this); + } + return *encoder_; +} + +Device::Device(int device) : device_(device) { + // Validate the requirements of device. + int attr = 0; + cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_); + if (attr != 1) { + throw std::runtime_error(fmt::format( + "Device {} does not support synchronization in managed memory.", + device_)); + } +} + +void Device::make_current() { + // We need to set/get current CUDA device very frequently, cache it to reduce + // actual calls of CUDA APIs. This function assumes single-thread in host. + static int current = 0; + if (current != device_) { + CHECK_CUDA_ERROR(cudaSetDevice(device_)); + current = device_; + } +} + +DeviceStream& Device::get_stream(Stream s) { + auto it = streams_.find(s.index); + if (it == streams_.end()) { + it = streams_.try_emplace(s.index, *this).first; + } + return it->second; +} + +CommandEncoder::CommandEncoder(DeviceStream& s) + : device_(s.device()), stream_(s) {} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_.add_task(std::move(task)); +} + +void CommandEncoder::end_encoding() { + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + + // There is no kernel running, run completion handlers immediately. + if (!has_gpu_work_) { + worker_.consume_in_this_thread(); + return; + } + has_gpu_work_ = false; + + // Put completion handlers in a batch. + worker_.end_batch(); + + // Signaling kernel completion is expensive, delay until enough batches. + // TODO: This number is arbitrarily picked, profile for a better stragety. + if (worker_.uncommited_batches() > 8) { + commit(); + } +} + +void CommandEncoder::commit() { + worker_.commit(stream_.last_cuda_stream()); +} + +Device& device(mlx::core::Device device) { + static std::unordered_map devices; + auto it = devices.find(device.index); + if (it == devices.end()) { + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +DeviceStream& get_stream(Stream s) { + return device(s.device).get_stream(s); +} + +CommandEncoder& get_command_encoder(Stream s) { + return get_stream(s).get_encoder(); +} + +} // namespace cu + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h new file mode 100644 index 000000000..a65a87d54 --- /dev/null +++ b/mlx/backend/cuda/device.h @@ -0,0 +1,131 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include + +#include + +namespace mlx::core::cu { + +class Device; +class CommandEncoder; + +class DeviceStream { + public: + explicit DeviceStream(Device& device); + + DeviceStream(const DeviceStream&) = delete; + DeviceStream& operator=(const DeviceStream&) = delete; + + // Wait until kernels in the stream complete. 
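+ // Blocks the calling host thread (implemented with cudaStreamSynchronize).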
+ void synchronize(); + + // Return a cuda stream for launching kernels. + cudaStream_t schedule_cuda_stream(); + + // Return the last cuda stream used. + cudaStream_t last_cuda_stream(); + + CommandEncoder& get_encoder(); + + Device& device() { + return device_; + } + + private: + Device& device_; + CudaStream stream_; + std::unique_ptr encoder_; +}; + +class Device { + public: + explicit Device(int device); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, required by some cuda calls. + void make_current(); + + DeviceStream& get_stream(Stream s); + + int cuda_device() const { + return device_; + } + + private: + int device_; + std::unordered_map streams_; +}; + +class CommandEncoder { + public: + explicit CommandEncoder(DeviceStream& stream); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr) {} + void set_output_array(const array& arr) {} + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void end_encoding(); + void commit(); + + // Schedule a cuda stream for |fun| to launch kernels, and check error + // afterwards. + template + void launch_kernel(F&& fun) { + launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); + } + + template + void launch_kernel(cudaStream_t stream, F&& fun) { + device_.make_current(); + fun(stream); + check_cuda_error("kernel launch", cudaGetLastError()); + has_gpu_work_ = true; + } + + Device& device() { + return device_; + } + + DeviceStream& stream() { + return stream_; + } + + bool has_gpu_work() const { + return has_gpu_work_; + } + + private: + Device& device_; + DeviceStream& stream_; + Worker worker_; + bool has_gpu_work_{false}; + std::vector> temporaries_; +}; + +Device& device(mlx::core::Device device); +DeviceStream& get_stream(Stream s); +CommandEncoder& get_command_encoder(Stream s); + +// Return an execution policy that does not sync for result. +// Note that not all thrust APIs support async policy, confirm before using. +inline auto thrust_policy(cudaStream_t stream) { + // TODO: Connect thrust's custom allocator with mlx's allocator. + return thrust::cuda::par_nosync.on(stream); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/dtype_utils.cuh b/mlx/backend/cuda/dtype_utils.cuh new file mode 100644 index 000000000..9b7f8ba65 --- /dev/null +++ b/mlx/backend/cuda/dtype_utils.cuh @@ -0,0 +1,35 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +// Maps CPU types to CUDA types. +template +struct CTypeToCudaType { + using type = T; +}; + +template <> +struct CTypeToCudaType { + using type = __half; +}; + +template <> +struct CTypeToCudaType { + using type = __nv_bfloat16; +}; + +template <> +struct CTypeToCudaType { + using type = cuComplex; +}; + +template +using cuda_type_t = typename CTypeToCudaType::type; + +} // namespace mlx::core diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp new file mode 100644 index 000000000..b309ad60e --- /dev/null +++ b/mlx/backend/cuda/eval.cpp @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. 
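+// GPU evaluation hooks for the CUDA backend: new_stream, eval, finalize and
+// synchronize.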
+ +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core::gpu { + +bool is_available() { + return true; +} + +void new_stream(Stream s) { + // Force initalization of cuda, so cuda runtime get destroyed at last. + cudaFree(nullptr); + // Ensure the static stream objects get created. + cu::get_command_encoder(s); + // The main thread is safe to free buffers. + cu::allocator().register_this_thread(); +} + +void eval(array& arr) { + nvtx3::scoped_range r("gpu::eval"); + auto outputs = arr.outputs(); + { + // If the array is a tracer hold a reference + // to its inputs so they don't get donated + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + + auto& encoder = cu::get_command_encoder(arr.primitive().stream()); + if (encoder.has_gpu_work()) { + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); + } + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + } + encoder.end_encoding(); +} + +void finalize(Stream s) { + nvtx3::scoped_range r("gpu::finalize"); + cu::get_command_encoder(s).commit(); +} + +void synchronize(Stream s) { + nvtx3::scoped_range r("gpu::synchronize"); + cu::get_stream(s).synchronize(); +} + +} // namespace mlx::core::gpu diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu new file mode 100644 index 000000000..a487f45b4 --- /dev/null +++ b/mlx/backend/cuda/event.cu @@ -0,0 +1,265 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/event.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace cu { + +/////////////////////////////////////////////////////////////////////////////// +// CudaEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +// Cuda event managed with RAII. 
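+// The underlying cudaEvent_t is created in the constructor and destroyed in
+// the destructor; copying is disabled.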
+class CudaEventHandle { + public: + CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventCreateWithFlags( + &event_, cudaEventDisableTiming | cudaEventBlockingSync)); + } + + ~CudaEventHandle() { + CHECK_CUDA_ERROR(cudaEventDestroy(event_)); + } + + CudaEventHandle(const CudaEventHandle&) = delete; + CudaEventHandle& operator=(const CudaEventHandle&) = delete; + + operator cudaEvent_t() const { + return event_; + } + + private: + cudaEvent_t event_; +}; + +CudaEvent::CudaEvent() : event_(std::make_shared()) {} + +void CudaEvent::wait() { + nvtx3::scoped_range r("cu::CudaEvent::wait"); + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaEventSynchronize(*event_); +} + +void CudaEvent::wait(cudaStream_t stream) { + if (!recorded_) { + throw std::runtime_error("Should not wait on a CudaEvent before record."); + } + cudaStreamWaitEvent(stream, *event_); +} + +void CudaEvent::wait(Stream s) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this]() mutable { wait(); }); + } else { + wait(cu::get_stream(s).last_cuda_stream()); + } +} + +void CudaEvent::record(cudaStream_t stream) { + cudaEventRecord(*event_, stream); + recorded_ = true; +} + +void CudaEvent::record(Stream s) { + if (s.device == mlx::core::Device::cpu) { + throw std::runtime_error("CudaEvent can not wait on cpu stream."); + } else { + record(cu::get_stream(s).last_cuda_stream()); + } +} + +bool CudaEvent::completed() const { + return cudaEventQuery(*event_) == cudaSuccess; +} + +/////////////////////////////////////////////////////////////////////////////// +// SharedEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { + uint64_t current; + while ((current = ac->load()) < value) { + ac->wait(current); + } +} + +__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) { + ac->store(value); + ac->notify_all(); +} + +__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_wait(ac, value); +} + +__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { + event_signal(ac, value); +} + +} // namespace + +SharedEvent::SharedEvent() { + // Allocate cuda::atomic on managed memory. 
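+ // Managed memory is visible to both host and device, so the same atomic can
+ // be waited on and signaled from either side.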
+ allocator::Buffer buffer = allocator::malloc(sizeof(Atomic)); + Atomic* ac = static_cast(buffer.raw_ptr()); + new (ac) Atomic(0); + ac_ = std::shared_ptr(ac, [buffer](Atomic* ptr) { + ptr->~Atomic(); + allocator::free(buffer); + }); +} + +void SharedEvent::wait(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait"); + event_wait(ac_.get(), value); +} + +void SharedEvent::wait(cudaStream_t stream, uint64_t value) { + event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::wait(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::wait(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +void SharedEvent::signal(uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal"); + event_signal(ac_.get(), value); +} + +void SharedEvent::signal(cudaStream_t stream, uint64_t value) { + event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); +} + +void SharedEvent::signal(Stream s, uint64_t value) { + nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), + [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.add_completed_handler([ac = ac_]() {}); + encoder.end_encoding(); + } +} + +bool SharedEvent::is_signaled(uint64_t value) const { + nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); + return ac_->load() >= value; +} + +uint64_t SharedEvent::value() const { + nvtx3::scoped_range r("cu::SharedEvent::value"); + return ac_->load(); +} + +} // namespace cu + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct EventImpl { + // CudaEvent is preferred when possible because it is fast, however we have + // to fallback to SharedEvent in following cases: + // 1. the event is used to wait/signal a cpu stream; + // 2. signal value other than 1 has been specified. 
+ std::unique_ptr cuda; + std::unique_ptr shared; + + bool is_created() const { + return cuda || shared; + } + + void ensure_created(Stream s, uint64_t signal_value) { + if (is_created()) { + return; + } + if (s.device == mlx::core::Device::cpu || signal_value > 1) { + nvtx3::mark("Using slow SharedEvent"); + shared = std::make_unique(); + } else { + cuda = std::make_unique(); + } + } +}; + +} // namespace + +Event::Event(Stream s) : stream_(s) { + event_ = std::shared_ptr( + new EventImpl(), [](void* ptr) { delete static_cast(ptr); }); +} + +void Event::wait() { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(); + } else { + event->shared->wait(value()); + } +} + +void Event::wait(Stream s) { + auto* event = static_cast(event_.get()); + assert(event->is_created()); + if (event->cuda) { + assert(value() == 1); + event->cuda->wait(s); + } else { + event->shared->wait(s, value()); + } +} + +void Event::signal(Stream s) { + auto* event = static_cast(event_.get()); + event->ensure_created(s, value()); + if (event->cuda) { + assert(value() == 1); + event->cuda->record(s); + } else { + event->shared->signal(s, value()); + } +} + +bool Event::is_signaled() const { + auto* event = static_cast(event_.get()); + if (!event->is_created()) { + return false; + } + if (event->cuda) { + assert(value() == 1); + return event->cuda->recorded() && event->cuda->completed(); + } else { + return event->shared->is_signaled(value()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h new file mode 100644 index 000000000..4b56e2e3b --- /dev/null +++ b/mlx/backend/cuda/event.h @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/stream.h" + +#include +#include + +#include + +namespace mlx::core::cu { + +class CudaEventHandle; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(); + + void wait(); + void wait(cudaStream_t stream); + void wait(Stream s); + void record(cudaStream_t stream); + void record(Stream s); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + bool recorded() const { + return recorded_; + } + + private: + bool recorded_{false}; + std::shared_ptr event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. +class SharedEvent { + public: + using Atomic = cuda::atomic; + + SharedEvent(); + + void wait(uint64_t value); + void wait(cudaStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(cudaStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + const std::shared_ptr& atomic() const { + return ac_; + } + + private: + std::shared_ptr ac_; +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/fence.cu b/mlx/backend/cuda/fence.cu new file mode 100644 index 000000000..091b252c1 --- /dev/null +++ b/mlx/backend/cuda/fence.cu @@ -0,0 +1,70 @@ +// Copyright © 2025 Apple Inc. 
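+// Fence built on SharedEvent: update() increments a counter and signals it;
+// wait() blocks until the counter reaches that value.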
+ +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/event.h" +#include "mlx/fence.h" +#include "mlx/scheduler.h" + +#include + +namespace mlx::core { + +namespace { + +__host__ __device__ void busy_wait(cuda::atomic* ac, uint64_t value) { + while (true) { + // In theory the atomic_thread_fence is not needed, but for CUDA 11 without + // it the load() may never return new value. + cuda::atomic_thread_fence(cuda::memory_order_seq_cst); + uint64_t current = ac->load(); + if (current >= value) { + break; + } + } +} + +__global__ void busy_wait_kernel(cuda::atomic* ac, uint64_t value) { + busy_wait(ac, value); +} + +} // namespace + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + // We can't use SharedEvent::wait because it could hang in CUDA 11, see also: + // https://github.com/ml-explore/mlx/issues/2137 + const auto& ac = fence->event.atomic(); + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [ac, count = fence->count]() { + nvtx3::scoped_range r("Fence::wait()"); + busy_wait(ac.get(), count); + }); + } else { + nvtx3::scoped_range r("Fence::wait(s)"); + auto& encoder = cu::get_command_encoder(s); + encoder.launch_kernel( + encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) { + busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count); + }); + encoder.add_completed_handler([ac]() {}); + encoder.end_encoding(); + } +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/kernels/arange.cuh b/mlx/backend/cuda/kernels/arange.cuh new file mode 100644 index 000000000..53c261e34 --- /dev/null +++ b/mlx/backend/cuda/kernels/arange.cuh @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +namespace mlx::core::cu { + +template +struct Arange { + const T start; + const T step; + + __device__ T operator()(uint32_t i) const { + return start + i * step; + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh new file mode 100644 index 000000000..931c55ff7 --- /dev/null +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -0,0 +1,107 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Missing C++ operator overrides for CUDA 7. 
+/////////////////////////////////////////////////////////////////////////////// + +#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +#define MLX_DEFINE_BF16_OP(OP) \ + __forceinline__ __device__ __nv_bfloat16 operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +#define MLX_DEFINE_BF16_CMP(OP) \ + __forceinline__ __device__ bool operator OP( \ + __nv_bfloat16 x, __nv_bfloat16 y) { \ + return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ + } + +MLX_DEFINE_BF16_OP(+) +MLX_DEFINE_BF16_OP(-) +MLX_DEFINE_BF16_OP(*) +MLX_DEFINE_BF16_OP(/) +MLX_DEFINE_BF16_CMP(>) +MLX_DEFINE_BF16_CMP(<) +MLX_DEFINE_BF16_CMP(>=) +MLX_DEFINE_BF16_CMP(<=) + +#undef MLX_DEFINE_BF16_OP +#undef MLX_DEFINE_BF16_CMP + +#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 + +/////////////////////////////////////////////////////////////////////////////// +// Additional C++ operator overrides between half types and native types. +/////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool is_integral_except = + cuda::std::is_integral_v && !cuda::std::is_same_v; + +template +constexpr bool is_arithmetic_except = + cuda::std::is_arithmetic_v && !cuda::std::is_same_v; + +#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(HALF x, T y) { \ + return FLOAT2HALF(HALF2FLOAT(x) OP static_cast(y)); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ HALF operator OP(T x, HALF y) { \ + return FLOAT2HALF(static_cast(x) OP HALF2FLOAT(y)); \ + } + +#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(HALF x, T y) { \ + return HALF2FLOAT(x) OP static_cast(y); \ + } \ + template < \ + typename T, \ + typename = cuda::std::enable_if_t>> \ + __forceinline__ __device__ bool operator OP(T x, HALF y) { \ + return static_cast(y) OP HALF2FLOAT(x); \ + } + +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *) +MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *) +MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /) +MLX_DEFINE_HALF_CMP(__half, __half2float, <) +MLX_DEFINE_HALF_CMP(__half, __half2float, >) +MLX_DEFINE_HALF_CMP(__half, __half2float, <=) +MLX_DEFINE_HALF_CMP(__half, __half2float, >=) +MLX_DEFINE_HALF_CMP(__half, __half2float, ==) +MLX_DEFINE_HALF_CMP(__half, __half2float, !=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==) +MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=) + +#undef MLX_DEFINE_HALF_OP +#undef MLX_DEFINE_HALF_CMP + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu new file mode 100644 index 000000000..dc6edf606 
--- /dev/null +++ b/mlx/backend/cuda/primitives.cu @@ -0,0 +1,163 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/dtype_utils.cuh" +#include "mlx/backend/cuda/kernels/arange.cuh" +#include "mlx/backend/cuda/kernels/fp16_math.cuh" +#include "mlx/distributed/primitives.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +#include +#include +#include + +#include + +namespace mlx::core { + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Arange::eval_gpu"); + assert(inputs.size() == 0); + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + encoder.set_output_array(out); + encoder.launch_kernel([&, this](cudaStream_t stream) { + MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); + }); + }); +} + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no CUDA implementation."); \ + } + +NO_GPU(Abs) +NO_GPU(Add) +NO_GPU(AddMM) +NO_GPU(ArcCos) +NO_GPU(ArcCosh) +NO_GPU(ArcSin) +NO_GPU(ArcSinh) +NO_GPU(ArcTan) +NO_GPU(ArcTan2) +NO_GPU(ArcTanh) +NO_GPU(ArgPartition) +NO_GPU(ArgReduce) +NO_GPU(ArgSort) +NO_GPU(BitwiseBinary) +NO_GPU(BitwiseInvert) +NO_GPU(BlockMaskedMM) +NO_GPU(Ceil) +NO_GPU_MULTI(Compiled) +NO_GPU(Conjugate) +NO_GPU(Convolution) +NO_GPU(Cos) +NO_GPU(Cosh) +NO_GPU(Divide) +NO_GPU_MULTI(DivMod) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) +NO_GPU(Remainder) +NO_GPU(Equal) +NO_GPU(Erf) +NO_GPU(ErfInv) +NO_GPU(Exp) +NO_GPU(Expm1) +NO_GPU(FFT) +NO_GPU(Floor) +NO_GPU(Gather) +NO_GPU(GatherAxis) +NO_GPU(GatherMM) +NO_GPU(GatherQMM) +NO_GPU(Greater) +NO_GPU(GreaterEqual) +NO_GPU(Hadamard) +NO_GPU(Imag) +NO_GPU(Less) +NO_GPU(LessEqual) +NO_GPU(Load) +NO_GPU(Log) +NO_GPU(Log1p) +NO_GPU(LogicalNot) +NO_GPU(LogicalAnd) +NO_GPU(LogicalOr) +NO_GPU(LogAddExp) +NO_GPU(LogSumExp) +NO_GPU_MULTI(LUF) +NO_GPU(Matmul) +NO_GPU(Maximum) +NO_GPU(Minimum) +NO_GPU(Multiply) +NO_GPU(Negative) +NO_GPU(NotEqual) +NO_GPU(Partition) +NO_GPU(Power) +NO_GPU_MULTI(QRF) +NO_GPU(QuantizedMatmul) +NO_GPU(RandomBits) +NO_GPU(Real) +NO_GPU(Reduce) +NO_GPU(Round) +NO_GPU(Scan) +NO_GPU(Scatter) +NO_GPU(ScatterAxis) +NO_GPU(Select) +NO_GPU(Sigmoid) +NO_GPU(Sign) +NO_GPU(Sin) +NO_GPU(Sinh) +NO_GPU(SliceUpdate) +NO_GPU(Softmax) +NO_GPU(Sort) +NO_GPU(Square) +NO_GPU(Sqrt) +NO_GPU(Subtract) +NO_GPU_MULTI(SVD) +NO_GPU(Tan) +NO_GPU(Tanh) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eigh) + +namespace fast { +NO_GPU_MULTI(LayerNorm) +NO_GPU_MULTI(LayerNormVJP) +NO_GPU_MULTI(RMSNorm) +NO_GPU_MULTI(RMSNormVJP) +NO_GPU_MULTI(RoPE) +NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(CustomKernel) +} // namespace fast + +namespace distributed { +NO_GPU_MULTI(AllReduce) +NO_GPU_MULTI(AllGather) +NO_GPU_MULTI(Send) +NO_GPU_MULTI(Recv) +} // namespace distributed + +} // 
namespace mlx::core diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp new file mode 100644 index 000000000..bfa742c74 --- /dev/null +++ b/mlx/backend/cuda/slicing.cpp @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s) { + throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp new file mode 100644 index 000000000..2a11a518e --- /dev/null +++ b/mlx/backend/cuda/utils.cpp @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); +} + +CudaStream::~CudaStream() { + CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); +} + +void check_cuda_error(const char* name, cudaError_t err) { + if (err != cudaSuccess) { + throw std::runtime_error( + fmt::format("{} failed: {}", name, cudaGetErrorString(err))); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h new file mode 100644 index 000000000..58d508765 --- /dev/null +++ b/mlx/backend/cuda/utils.h @@ -0,0 +1,36 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +namespace cu { +class Device; +} + +// Cuda stream managed with RAII. +class CudaStream { + public: + explicit CudaStream(cu::Device& device); + ~CudaStream(); + + CudaStream(const CudaStream&) = delete; + CudaStream& operator=(const CudaStream&) = delete; + + operator cudaStream_t() const { + return stream_; + } + + private: + cudaStream_t stream_; +}; + +// Throw exception if the cuda API does not succeed. +void check_cuda_error(const char* name, cudaError_t err); + +// The macro version that prints the command that failed. +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) + +} // namespace mlx::core diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp new file mode 100644 index 000000000..64b5c7679 --- /dev/null +++ b/mlx/backend/cuda/worker.cpp @@ -0,0 +1,90 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/worker.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +Worker::Worker() + : signal_stream_(device(mlx::core::Device::gpu)), + worker_(&Worker::thread_fn, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(worker_mutex_); + stop_ = true; + } + worker_event_.signal(batch_ + 1); + worker_.join(); +} + +void Worker::add_task(std::function task) { + pending_tasks_.push_back(std::move(task)); +} + +void Worker::consume_in_this_thread() { + for (auto& task : pending_tasks_) { + task(); + } + pending_tasks_.clear(); +} + +void Worker::end_batch() { + batch_++; + { + std::lock_guard lock(worker_mutex_); + worker_tasks_[batch_] = std::move(pending_tasks_); + } + uncommited_batches_++; +} + +void Worker::commit() { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + worker_event_.signal(batch_); +} + +void Worker::commit(cudaStream_t stream) { + if (uncommited_batches_ == 0) { + return; + } + uncommited_batches_ = 0; + // Signal the |worker_event_| in |signal_stream_| after the kernels in + // |stream_| finish running. 
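  // In other words: signal_event_ is recorded on the caller's compute stream,
  // signal_stream_ is made to wait on that event, and worker_event_ is then
  // signaled from signal_stream_; the worker thread therefore cannot observe
  // this batch number until the kernels already queued on |stream| have
  // completed.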
+ signal_event_.record(stream); + signal_event_.wait(signal_stream_); + worker_event_.signal(signal_stream_, batch_); +} + +void Worker::thread_fn() { + // The worker thread is safe to free buffers. + allocator().register_this_thread(); + + while (!stop_) { + uint64_t batch = worker_event_.value(); + Tasks tasks; + { + std::lock_guard lock(worker_mutex_); + // Move tasks in signaled batches. + auto end = worker_tasks_.upper_bound(batch); + for (auto it = worker_tasks_.begin(); it != end; ++it) { + if (tasks.empty()) { + tasks = std::move(it->second); + } else { + std::move( + it->second.begin(), it->second.end(), std::back_inserter(tasks)); + } + } + worker_tasks_.erase(worker_tasks_.begin(), end); + } + for (auto& task : tasks) { + task(); + } + worker_event_.wait(batch + 1); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h new file mode 100644 index 000000000..d28e22e95 --- /dev/null +++ b/mlx/backend/cuda/worker.h @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" +#include "mlx/backend/cuda/utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in worker thread, synchronized with cuda stream. +class Worker { + public: + Worker(); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or commited. + void add_task(std::function task); + + // Run pending tasks immediately in current thread. + void consume_in_this_thread(); + + // Put pending tasks in a batch. + void end_batch(); + + // Inform worker thread to run current batches now. + void commit(); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + // Return how many batches have been added but not committed yet. + size_t uncommited_batches() const { + return uncommited_batches_; + } + + private: + void thread_fn(); + + uint64_t batch_{0}; + size_t uncommited_batches_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + // Worker thread. + SharedEvent worker_event_; + std::thread worker_; + std::mutex worker_mutex_; + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. 
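  // A minimal usage sketch from the scheduler's side (hypothetical caller,
  // not part of this patch):
  //
  //   Worker w;
  //   w.add_task([] { /* e.g. free temporaries or run host callbacks */ });
  //   w.end_batch();     // group the pending tasks into a batch
  //   w.commit(stream);  // run that batch once the kernels on stream finish
  //
  // consume_in_this_thread() instead runs any still-pending tasks
  // synchronously on the calling thread.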
+ using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; +}; + +} // namespace mlx::core::cu diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cf0ba3d5d..cb174865d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,7 +9,7 @@ FetchContent_MakeAvailable(doctest) add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) -if(MLX_BUILD_METAL) +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) set(METAL_TEST_SOURCES gpu_tests.cpp) endif() From a7fae8a176fad114c89ca66ed0e0be8f3064e3e8 Mon Sep 17 00:00:00 2001 From: ATurker <53705368+aturker1@users.noreply.github.com> Date: Fri, 9 May 2025 20:26:52 +0300 Subject: [PATCH 10/37] fix: conv_general differences between gpu, cpu (#2070) * fix general_conv padding * fix bugs * add test --------- Co-authored-by: Awni Hannun --- mlx/backend/cpu/conv.cpp | 574 +++++++++++++++++++++---------------- mlx/backend/metal/conv.cpp | 6 +- mlx/ops.cpp | 1 + mlx/primitives.cpp | 48 ++-- mlx/primitives.h | 12 +- python/tests/test_conv.py | 42 +++ 6 files changed, 413 insertions(+), 270 deletions(-) diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index d52f92f8b..e5636b3b8 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -22,7 +22,8 @@ void slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -60,7 +61,8 @@ void slow_conv_1D( out_stride_O = out.strides()[2], flip, - padding = padding[0], + padding_lo = padding_lo[0], + padding_hi = padding_hi[0], wt_stride = wt_strides[0], wt_dilation = wt_dilation[0], in_dilation = in_dilation[0]]() mutable { @@ -77,7 +79,7 @@ void slow_conv_1D( const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = flip ? 
(wH - wh - 1) : wh; - int ih = oh * wt_stride - padding + wh_flip * wt_dilation; + int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation; auto ih_div = std::div(ih, in_dilation); @@ -109,7 +111,8 @@ void slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -120,230 +123,235 @@ void slow_conv_2D( encoder.set_input_array(wt); encoder.set_output_array(out); - encoder.dispatch([st_wt_ptr = wt.data(), - st_in_ptr = in.data(), - st_out_ptr = out.data(), + encoder.dispatch( + [st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - N = in.shape( - 0), // Batch size, should be the same as out.shape(0) - iH = 1 + - in_dilation[0] * (in.shape(1) - 1), // Input spatial dim - iW = 1 + - in_dilation[1] * (in.shape(2) - 1), // Input spatial dim - C = in.shape(3), // In channels - oH = out.shape(1), // Output spatial dim - oW = out.shape(2), // Output spatial dim - O = wt.shape(0), // Out channels - wH = wt.shape(1), // Weight spatial dim - wW = wt.shape(2), // Weight spatial dim + N = in.shape(0), // Batch size, should be the same as out.shape(0) + iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + C = in.shape(3), // In channels + oH = out.shape(1), // Output spatial dim + oW = out.shape(2), // Output spatial dim + O = wt.shape(0), // Out channels + wH = wt.shape(1), // Weight spatial dim + wW = wt.shape(2), // Weight spatial dim - groups = in.shape(3) / wt.shape(3), - C_per_group = wt.shape(3), + groups = in.shape(3) / wt.shape(3), + C_per_group = wt.shape(3), - in_stride_N = in.strides()[0], - in_stride_H = in.strides()[1], - in_stride_W = in.strides()[2], - in_stride_C = in.strides()[3], + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_W = in.strides()[2], + in_stride_C = in.strides()[3], - wt_stride_O = wt.strides()[0], - wt_stride_H = wt.strides()[1], - wt_stride_W = wt.strides()[2], - wt_stride_C = wt.strides()[3], + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_W = wt.strides()[2], + wt_stride_C = wt.strides()[3], - out_stride_N = out.strides()[0], - out_stride_H = out.strides()[1], - out_stride_W = out.strides()[2], - out_stride_O = out.strides()[3], + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_W = out.strides()[2], + out_stride_O = out.strides()[3], - padding, - wt_strides, - wt_dilation, - in_dilation, - flip]() mutable { - bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; - const int O_per_group = O / groups; - auto pt_conv_no_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int oh, - int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + const int O_per_group = O / groups; + auto pt_conv_no_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) 
* O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c - } // ww - } // wh + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; - int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); - int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); - int f_wgt_jump_h = - std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_w = - std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_h = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_h = + std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = + std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[0] - padding[0] + init_h; + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h; - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[0] != 0) { - wh_base++; - ih_loop += jump_h; - } + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } - base_h[i] = wh_base; - } + base_h[i] = wh_base; + } - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[1] - padding[1] + init_w; + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w; - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[1] != 0) { - ww_base++; - iw_loop += jump_w; - } + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } - base_w[j] = ww_base; - } + base_w[j] = ww_base; + } - auto pt_conv_all_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - int ih_dil = !is_idil_one ? 
(ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + + iw_dil * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; - ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - } // ih, iw check - } // ww - } // wh + } // ih, iw check + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); - int oH_border_3 = oH; + int oH_border_0 = 0; + int oH_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oH; + int oH_border_2 = std::max( + oH_border_1, + (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]); + int oH_border_3 = oH; - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); - int oW_border_3 = oW; + int oW_border_0 = 0; + int oW_border_1 = is_idil_one + ? 
((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oW; + int oW_border_2 = std::max( + oW_border_1, + (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]); + int oW_border_3 = oW; - for (int n = 0; n < N; ++n) { - // Case 1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - // Case 2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case a: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case b: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case c: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - } // oh + } // oh - // Case 3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + // Case 3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; - } // n - }); + } // n + }); } template @@ -351,7 +359,8 @@ void slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -400,7 +409,8 @@ void slow_conv_3D( out_stride_H = out.strides()[2], out_stride_W = out.strides()[3], out_stride_O = out.strides()[4], - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -415,9 +425,9 @@ void slow_conv_3D( int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; for (int o = 0; o < O; ++o) { float r = 0.; @@ -478,7 +488,7 @@ void slow_conv_3D( std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { - int 
id_loop = i * wt_strides[0] - padding[0] + init_d; + int id_loop = i * wt_strides[0] - padding_lo[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { @@ -490,7 +500,7 @@ void slow_conv_3D( } for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[1] - padding[1] + init_h; + int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { @@ -502,7 +512,7 @@ void slow_conv_3D( } for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[2] - padding[2] + init_w; + int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { @@ -521,9 +531,9 @@ void slow_conv_3D( int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; @@ -573,24 +583,30 @@ void slow_conv_3D( }; int oD_border_0 = 0; - int oD_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oD; int oD_border_2 = std::max( - oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + oD_border_1, + (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_1 = is_idil_one + ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oH; int oH_border_2 = std::max( - oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + oH_border_1, + (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_1 = is_idil_one + ? 
((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2]) + : oW; int oW_border_2 = std::max( - oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + oW_border_1, + (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { @@ -658,7 +674,8 @@ void dispatch_slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -669,7 +686,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -680,7 +698,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -691,7 +710,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -707,7 +727,8 @@ void dispatch_slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -718,7 +739,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -729,7 +751,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -740,7 +763,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -756,7 +780,8 @@ void dispatch_slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -767,7 +792,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -778,7 +804,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -789,7 +816,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], C}; + Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = padding[0] * in_padded.strides()[1]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const 
std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; + Shape padded_shape = { + N, + iH + padding_lo[0] + padding_hi[0], + iW + padding_lo[1] + padding_hi[1], + C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = - padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1] + + padding_lo[1] * in_padded.strides()[2]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip, @@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu( Shape padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { - padded_shape[i + 1] = iDim[i] + 2 * padding[i]; + padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); @@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu( // Pick input slice from padded size_t data_offset = 0; - for (size_t i = 0; i < padding.size(); i++) { - data_offset += padding[i] * in_padded.strides()[i + 1]; + for (size_t i = 0; i < padding_lo.size(); i++) { + data_offset += padding_lo[i] * in_padded.strides()[i + 1]; } + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1261,7 +1297,8 @@ void conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1270,22 +1307,40 @@ void conv_1D_cpu( const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( - in, wt, out, padding, wt_strides, wt_dilation, stream); + in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1295,18 +1350,35 @@ void conv_2D_cpu( if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, 
wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } - return dispatch_slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_3D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1317,11 +1389,28 @@ void conv_3D_cpu( in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } } // namespace @@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index ae31a6cff..35ed3d44e 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4aa5e88b7..e8c260425 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3974,6 +3974,7 @@ array conv_general( to_stream(s), stride, padding_lo, + padding_hi, kernel_dilation, input_dilation, groups, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 7288a4885..03ca06bdd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1055,7 +1055,8 @@ array conv_weight_backward_patches( const array& wt, const array& cotan, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, StreamOrDevice s) { // Resolve Padded input shapes and strides Shape padding_starts(in.ndim(), 0); @@ -1064,9 +1065,9 @@ array conv_weight_backward_patches( // padded shape for (int i = 1; i < in.ndim() - 1; i++) { - in_padded_shape[i] += 2 * padding[i - 1]; - padding_ends[i] += padding[i - 1]; - padding_starts[i] += padding[i - 1]; + in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1]; + padding_ends[i] += padding_lo[i - 1]; + padding_starts[i] += padding_lo[i - 1]; } // padded 
strides (contiguous) @@ -1078,9 +1079,16 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_(padding.begin(), padding.end()); - auto in_padded = pad( - in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); + Shape padding_lo_(padding_lo.begin(), padding_lo.end()); + Shape padding_hi_(padding_hi.begin(), padding_hi.end()); + auto in_padded = + pad(in, + padded_axes, + padding_lo_, + padding_hi_, + array(0, in.dtype()), + "constant", + s); // Resolve strided patches @@ -1147,16 +1155,16 @@ std::vector Convolution::vjp( for (int a : argnums) { // Grads for input if (a == 0) { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; for (int i = 0; i < padding_lo.size(); ++i) { int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_lo[i] = wt_size - padding_[i] - 1; + padding_lo[i] = wt_size - padding_lo_[i] - 1; int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding_[i]; + padding_hi[i] = in_size - out_size + padding_hi_[i]; } // Check for negative padding @@ -1226,18 +1234,12 @@ std::vector Convolution::vjp( if (no_dilation && !flip_ && groups_ == 1) { auto grad = conv_weight_backward_patches( - in, wt, cotan, kernel_strides_, padding_, stream()); + in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; - } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1283,7 +1285,8 @@ std::pair, std::vector> Convolution::vmap( in, w, kernel_strides_, - padding_, + padding_lo_, + padding_hi_, kernel_dilation_, input_dilation_, groups, @@ -1332,7 +1335,8 @@ std::pair, std::vector> Convolution::vmap( bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); - return padding_ == c_other.padding_ && + return padding_lo_ == c_other.padding_lo_ && + padding_hi_ == c_other.padding_hi_ && kernel_strides_ == c_other.kernel_strides_ && kernel_dilation_ == c_other.kernel_dilation_ && input_dilation_ == c_other.input_dilation_ && diff --git a/mlx/primitives.h b/mlx/primitives.h index 3753e43c5..2caed8477 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive { explicit Convolution( Stream stream, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& kernel_dilation, const std::vector& input_dilation, const int groups = 1, const bool flip = false) : UnaryPrimitive(stream), - padding_(padding), + padding_lo_(padding_lo), + padding_hi_(padding_hi), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), input_dilation_(input_dilation), @@ -716,7 +718,8 @@ class Convolution : 
public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive { } private: - std::vector padding_; + std::vector padding_lo_; + std::vector padding_hi_; std::vector kernel_strides_; std::vector kernel_dilation_; std::vector input_dilation_; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 671c86a32..35dcf42ac 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase): atol=2e-5 if dtype == np.float32 else 5e-4, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_asymmetric_padding(self): + inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) + strides = (2, 2, 2) + + pt_out = torch.conv3d( + torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), + torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), + stride=strides, + padding=2, + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=strides, + padding=([0, 0, 0], [1, 1, 1]), + ) + + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) + + pt_out = torch.conv2d( + torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), + torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), + stride=1, + padding=(1, 0), + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=1, + padding=([0, 0], [1, 0]), + ) + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + if __name__ == "__main__": unittest.main() From 6661387066b38ef7221d29d7dad6c25d07d6e96a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 May 2025 14:25:12 -0700 Subject: [PATCH 11/37] Fix fft for integer overflow (#2161) --- mlx/backend/metal/fft.cpp | 4 +--- mlx/backend/metal/kernels/fft/readwrite.h | 28 ++++++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 011eb7ebb..1e23160a6 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -632,7 +632,7 @@ void fft_op( func_consts.push_back(make_int(&rader_m, 3)); // The overall number of FFTs we're going to compute for this input - int size = out.dtype() == float32 ? out.size() : in.size(); + size_t size = out.dtype() == float32 ? out.size() : in.size(); if (real && inverse && four_step_params.required) { size = out.size(); } @@ -659,8 +659,6 @@ void fft_op( // We can perform 2 RFFTs at once so the batch size is halved. batch_size = (batch_size + 2 - 1) / 2; } - int out_buffer_size = out.size(); - auto& compute_encoder = d.get_command_encoder(s.index); auto in_type_str = in.dtype() == float32 ? "float" : "float2"; auto out_type_str = out.dtype() == float32 ? 
"float" : "float2"; diff --git a/mlx/backend/metal/kernels/fft/readwrite.h b/mlx/backend/metal/kernels/fft/readwrite.h index f6724820d..0dc62992e 100644 --- a/mlx/backend/metal/kernels/fft/readwrite.h +++ b/mlx/backend/metal/kernels/fft/readwrite.h @@ -98,7 +98,7 @@ struct ReadWriter { } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -121,7 +121,7 @@ struct ReadWriter { } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -144,7 +144,7 @@ struct ReadWriter { // Padded IO for Bluestein's algorithm METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; @@ -161,7 +161,7 @@ struct ReadWriter { } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -283,7 +283,8 @@ template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; @@ -317,7 +318,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; @@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter::load_padded( int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + 
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -503,7 +505,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; From 659a51919fd3d70798e91e9e112075680b95556e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 May 2025 14:35:14 -0700 Subject: [PATCH 12/37] patch bump (#2162) --- mlx/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/version.h b/mlx/version.h index 8340e1e8c..c573c45c9 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -4,7 +4,7 @@ #define MLX_VERSION_MAJOR 0 #define MLX_VERSION_MINOR 25 -#define MLX_VERSION_PATCH 1 +#define MLX_VERSION_PATCH 2 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) From caaa3f1f8ceac3faee5068c04ea0e574af24f829 Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Sun, 11 May 2025 15:03:47 +0200 Subject: [PATCH 13/37] Small typos in mx.metal deprecations (#2176) --- python/src/metal.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/src/metal.cpp b/python/src/metal.cpp index a13dd2a03..54642409c 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -49,21 +49,21 @@ void init_metal(nb::module_& m) { metal.def( "set_memory_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_memory_limt", "mx.set_memory_limit"); + DEPRECATE("mx.metal.set_memory_limit", "mx.set_memory_limit"); return mx::set_memory_limit(limit); }, "limit"_a); metal.def( "set_cache_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_cache_limt", "mx.set_cache_limit"); + DEPRECATE("mx.metal.set_cache_limit", "mx.set_cache_limit"); return mx::set_cache_limit(limit); }, "limit"_a); metal.def( "set_wired_limit", [](size_t limit) { - DEPRECATE("mx.metal.set_wired_limt", "mx.set_wired_limit"); + DEPRECATE("mx.metal.set_wired_limit", "mx.set_wired_limit"); return mx::set_wired_limit(limit); }, "limit"_a); From 8f3d208dcef00c5085dd3acfde2a6abb18585f07 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 12 May 2025 10:48:57 -0700 Subject: [PATCH 14/37] Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177) * handle hadamard and addmm on empty inputs * fix --- mlx/backend/cpu/matmul.cpp | 8 +++++++- mlx/backend/metal/matmul.cpp | 17 +++++++++++++++++ mlx/ops.cpp | 8 ++++++++ python/tests/test_blas.py | 17 +++++++++++++++++ python/tests/test_ops.py | 3 +++ 5 files changed, 52 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 8ae99ab2d..b944aacc0 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[AddMM::eval_cpu] Currently only supports float32."); } + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } // Fill output with C auto& c = inputs[2]; @@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector& inputs, array& out) { ? CopyType::Scalar : (c.flags().row_contiguous ? 
CopyType::Vector : CopyType::General); copy(c, out, ctype, stream()); - + if (inputs[0].shape(-1) == 0) { + return; + } matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 71221f8d9..e0ff44200 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -716,6 +716,23 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } + + // Return 0s if either input is empty + if (out.size() == 0) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + + // Copy c into out and return + if (inputs[0].shape(-1) == 0) { + copy_gpu( + inputs[2], + out, + inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream()); + return; + } + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e8c260425..922680110 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -472,6 +472,10 @@ array hadamard_transform( const array& a, std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { + if (a.size() == 0) { + throw std::invalid_argument( + "[hadamard_transform] Does not support empty arrays."); + } // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) int n = a.ndim() > 0 ? a.shape(-1) : 1; float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(n); @@ -4326,6 +4330,10 @@ array addmm( c = reshape(c, c_reshape, s); } + if (c.shape() != out_shape) { + throw std::invalid_argument( + "[addmm] input c must broadcast to the output shape"); + } auto out = array( std::move(out_shape), diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 6fca4885b..df459eadc 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -589,6 +589,10 @@ class TestBlas(mlx_tests.MLXTestCase): alpha = 0.5 beta = 2.0 + # c must broadcast to the output shape + with self.assertRaises(ValueError): + mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2))) + # Regular batched case a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32) @@ -745,6 +749,19 @@ class TestBlas(mlx_tests.MLXTestCase): mx.eval(c) self.assertEqual(c.shape, (0, 0)) + c = mx.array(1.0, dtype=mx.float32) + a = mx.array([], dtype=mx.float32) + b = mx.array([], dtype=mx.float32) + out = mx.addmm(c, a, b) + self.assertEqual(out.item(), 1.0) + self.assertEqual(out.shape, ()) + + a = mx.zeros(shape=(5, 0)) + b = mx.zeros(shape=(0, 5)) + c = mx.random.uniform(shape=(5, 5)) + out = mx.addmm(c, a, b) + self.assertTrue(mx.allclose(out, c)) + def test_block_masked_matmul(self): def ref_block_masked_mm( a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d9e143d82..0921de788 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2830,6 +2830,9 @@ class TestOps(mlx_tests.MLXTestCase): return H def test_hadamard(self): + with self.assertRaises(ValueError): + mx.hadamard_transform(mx.array([])) + h28_str = """ +------++----++-+--+-+--++-- -+-----+++-----+-+--+-+--++- From 3aa9cf3f9ed7e1dd508b0d98b07834f5ac5c43cf Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 13 May 2025 14:27:53 -0700 Subject: [PATCH 15/37] Fix put_along_axis for empty arrays (#2181) --- mlx/ops.cpp | 4 ++++ 
python/tests/test_ops.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 922680110..0c18cccfe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3175,6 +3175,10 @@ array scatter_axis( throw std::invalid_argument(msg.str()); } + if (a.size() == 0) { + return a; + } + auto upd = astype(values, a.dtype(), s); // Squeeze leading singletons out of update diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 0921de788..f3d48dda3 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1255,6 +1255,12 @@ class TestOps(mlx_tests.MLXTestCase): np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2) self.assertTrue(np.array_equal(out_np, np.array(out_mlx))) + a = mx.array([], mx.float32) + b = mx.put_along_axis(a, a, a, axis=None) + mx.eval(b) + self.assertEqual(b.size, 0) + self.assertEqual(b.shape, a.shape) + def test_split(self): a = mx.array([1, 2, 3]) splits = mx.split(a, 3) From eca2f3eb974b86d37da170023040c5ac9a148c18 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 14 May 2025 09:09:56 +0900 Subject: [PATCH 16/37] Add remove_index utility (#2173) --- mlx/backend/common/utils.h | 7 +++++++ mlx/backend/cpu/arg_reduce.cpp | 6 ++---- mlx/backend/cpu/indexing.cpp | 28 ++++++++++------------------ mlx/backend/metal/indexing.cpp | 29 +++++++---------------------- 4 files changed, 26 insertions(+), 44 deletions(-) diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 20a65d7b1..a4bdaa5ca 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -165,4 +165,11 @@ void shared_buffer_reshape( const array& in, const Strides& out_strides, array& out); + +template +inline std::vector remove_index(std::vector vec, size_t index) { + vec.erase(std::next(vec.begin(), index)); + return vec; +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index a8ba3efe2..66468912d 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -14,10 +14,8 @@ template void arg_reduce(const array& in, array& out, const OpT& op, int axis) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; - Strides strides = in.strides(); - Shape shape = in.shape(); - strides.erase(strides.begin() + axis); - shape.erase(shape.begin() + axis); + Strides strides = remove_index(in.strides(), axis); + Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); auto out_ptr = out.data(); diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 70d6b3eb7..5f99093e5 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -257,15 +257,11 @@ void gather_axis( const array& ind, array& out, const int axis) { - auto strides = ind.strides(); - strides.erase(strides.begin() + axis); - auto shape = ind.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator ind_it(shape, strides, src.ndim() - 1); - - strides = src.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator src_it(shape, strides, src.ndim() - 1); + auto shape = remove_index(ind.shape(), axis); + ContiguousIterator ind_it( + shape, remove_index(ind.strides(), axis), src.ndim() - 1); + ContiguousIterator src_it( + shape, remove_index(src.strides(), axis), src.ndim() - 1); auto ind_ptr = ind.data(); auto src_ptr = src.data(); @@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { template void scatter_axis(array& out, const array idx, const array& upd, int axis) 
{ - auto strides = idx.strides(); - strides.erase(strides.begin() + axis); - auto shape = idx.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); - - strides = upd.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator upd_it(shape, strides, upd.ndim() - 1); + auto shape = remove_index(idx.shape(), axis); + ContiguousIterator idx_it( + shape, remove_index(idx.strides(), axis), upd.ndim() - 1); + ContiguousIterator upd_it( + shape, remove_index(upd.strides(), axis), upd.ndim() - 1); auto idx_ptr = idx.data(); auto upd_ptr = upd.data(); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index cccfd908a..d2a601b1e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,6 +2,7 @@ #include #include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" @@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = src.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(src.shape(axis_), 8); @@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = upd.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8); From 0751263dec5a210eb2ba097c108e8d78aa58124c Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 14 May 2025 12:19:54 +0900 Subject: [PATCH 17/37] Fix typo in row_reduce_small (#2179) --- mlx/backend/metal/kernels/reduction/reduce_row.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index c8973429f..936d75bb5 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -224,7 +224,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. 
- IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { From 130df35e1b520061a053c052fba07122dc390c6a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 13 May 2025 22:43:45 -0700 Subject: [PATCH 18/37] Add random normal distribution for complex numbers (#2182) --- mlx/random.cpp | 45 +++++++++++++++++++++++++++++-------- mlx/random.h | 18 ++++++++++++--- python/src/random.cpp | 35 +++++++++++++++++++---------- python/tests/test_random.py | 35 +++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 24 deletions(-) diff --git a/mlx/random.cpp b/mlx/random.cpp index 89a027b17..6c6d1eb95 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -176,24 +176,51 @@ array uniform( array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s)); } +inline array complex_normal( + Shape shape, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s) { + auto stream = to_stream(s); + auto low = above_minus_one_with_default(float32); + auto high = array(1.0f, float32); + shape.push_back(2); + auto samples = + erfinv(uniform(low, high, shape, float32, key, stream), stream); + samples = squeeze(view(samples, complex64, stream), -1, stream); + if (scale.has_value()) { + samples = multiply(*scale, samples, stream); + } + if (loc.has_value()) { + samples = add(*loc, samples, stream); + } + return samples; +} + array normal( const Shape& shape, Dtype dtype, - const float loc /* = 0.0 */, - const float scale /* = 1.0 */, - const std::optional& key /*= nullopt */, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, StreamOrDevice s /* = {} */) { + if (dtype == complex64) { + return complex_normal(shape, loc, scale, key, s); + } + auto stream = to_stream(s); auto low = above_minus_one_with_default(dtype); auto high = array(1.0f, dtype); auto samples = uniform(low, high, shape, dtype, key, stream); - samples = - multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); - if (scale != 1.0) { - samples = multiply(array(scale, dtype), samples, stream); + auto applied_scale = array(std::sqrt(2.0), dtype); + if (scale.has_value()) { + applied_scale = + multiply(applied_scale, astype(*scale, dtype, stream), stream); } - if (loc != 0.0) { - samples = add(array(loc, dtype), samples, stream); + samples = multiply(applied_scale, erfinv(samples, stream), stream); + if (loc.has_value()) { + samples = add(astype(*loc, dtype, stream), samples, stream); } return samples; } diff --git a/mlx/random.h b/mlx/random.h index b2c821736..0dfdab7a1 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -94,12 +94,24 @@ inline array uniform( /** Generate samples from the standard normal distribution. */ array normal( + const Shape& shape, + Dtype dtype, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s = {}); +inline array normal( const Shape& shape, Dtype dtype, const float loc, const float scale, const std::optional& key = std::nullopt, - StreamOrDevice s = {}); + StreamOrDevice s = {}) { + auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype)); + auto scale_ = + scale == 1 ? 
std::nullopt : std::make_optional(array(scale, dtype)); + return normal(shape, dtype, loc_, scale_, key, s); +} inline array normal( const Shape& shape, const float loc, @@ -113,13 +125,13 @@ inline array normal( const Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, dtype, 0.0, 1.0, key, s); + return normal(shape, dtype, std::nullopt, std::nullopt, key, s); } inline array normal( const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { - return normal(shape, float32, 0.0, 1.0, key, s); + return normal(shape, float32, std::nullopt, std::nullopt, key, s); } /** Generate samples from a multivariate normal distribution. **/ diff --git a/python/src/random.cpp b/python/src/random.cpp index 22b706174..837f91616 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -152,31 +152,42 @@ void init_random(nb::module_& parent_module) { "normal", [](const mx::Shape& shape, std::optional type, - float loc, - float scale, + const std::optional& loc_, + const std::optional& scale_, const std::optional& key_, mx::StreamOrDevice s) { + auto dtype = type.value_or(mx::float32); auto key = key_ ? key_.value() : default_key().next(); - return mx::random::normal( - shape, type.value_or(mx::float32), loc, scale, key, s); + auto loc = + loc_ ? std::make_optional(to_array(*loc_, dtype)) : std::nullopt; + auto scale = scale_ ? std::make_optional(to_array(*scale_, dtype)) + : std::nullopt; + return mx::random::normal(shape, dtype, loc, scale, key, s); }, "shape"_a = mx::Shape{}, "dtype"_a.none() = mx::float32, - "loc"_a = 0.0, - "scale"_a = 1.0, + "loc"_a = nb::none(), + "scale"_a = nb::none(), "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Generate normally distributed random numbers. + If ``loc`` and ``scale`` are not provided the "standard" normal + distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for + real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0, + \frac{1}{2})$ for complex numbers. + Args: - shape (list(int), optional): Shape of the output. Default is ``()``. - dtype (Dtype, optional): Type of the output. Default is ``float32``. - loc (float, optional): Mean of the distribution. Default is ``0.0``. - scale (float, optional): Standard deviation of the distribution. Default is ``1.0``. - key (array, optional): A PRNG key. Default: None. + shape (list(int), optional): Shape of the output. Default: ``()``. + dtype (Dtype, optional): Type of the output. Default: ``float32``. + loc (scalar or array, optional): Mean of the distribution. + Default: ``None``. + scale (scalar or array, optional): Standard deviation of the + distribution. Default: ``None``. + key (array, optional): A PRNG key. Default: ``None``. Returns: array: The output array of random values. 
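A minimal usage sketch of the generalized signature introduced by this patch — array-valued ``loc``/``scale`` broadcast against ``shape``, and ``complex64`` draws from the standard complex normal. The concrete shapes and values below are purely illustrative, not part of the original commit:

    >>> import mlx.core as mx
    >>> loc = mx.zeros((10, 2))
    >>> scale = mx.ones((10, 2))
    >>> x = mx.random.normal((2, 10, 2), loc=loc, scale=scale)  # loc/scale broadcast over shape
    >>> z = mx.random.normal((4,), dtype=mx.complex64)          # standard complex normal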
diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 9efbfb5f6..2fc768651 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -352,6 +352,41 @@ class TestRandom(mlx_tests.MLXTestCase): x = mx.random.permutation(mx.array([[1]])) self.assertEqual(x.shape, (1, 1)) + def test_complex_normal(self): + sample = mx.random.normal(tuple(), dtype=mx.complex64) + self.assertEqual(sample.shape, tuple()) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal((1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + sample = mx.random.normal( + (1, 2, 3, 4), dtype=mx.complex64, scale=2.0, loc=3.0 + 1j + ) + self.assertEqual(sample.shape, (1, 2, 3, 4)) + self.assertEqual(sample.dtype, mx.complex64) + + def test_broadcastable_scale_loc(self): + b = mx.random.normal((10, 2)) + sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (2, 10, 2)) + + with self.assertRaises(ValueError): + b = mx.random.normal((10,)) + sample = mx.random.normal((2, 10, 2), loc=b, scale=b) + + b = mx.random.normal((3, 1, 2)) + sample = mx.random.normal((3, 4, 2), dtype=mx.float16, loc=b, scale=b) + mx.eval(sample) + self.assertEqual(sample.shape, (3, 4, 2)) + self.assertEqual(sample.dtype, mx.float16) + if __name__ == "__main__": unittest.main() From cf6c939e868f6db3421396fda3fde31708e6f1eb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 14 May 2025 23:37:12 -0700 Subject: [PATCH 19/37] Fix some complex vjps (#2178) --- mlx/primitives.cpp | 89 ++++++++++++++++++++++++++++++---------- python/tests/test_fft.py | 57 +++++++++++++++++++++++++ tests/autograd_tests.cpp | 46 +++++++++++++++------ tests/fft_tests.cpp | 16 ++++---- 4 files changed, 166 insertions(+), 42 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 03ca06bdd..e1924e66c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1488,14 +1488,16 @@ std::vector Divide::vjp( const std::vector& argnums, const std::vector&) { std::vector vjps; + array denominator_bar = conjugate(primals[1], stream()); for (auto arg : argnums) { if (arg == 0) { - vjps.push_back(divide(cotangents[0], primals[1], stream())); + vjps.push_back(divide(cotangents[0], denominator_bar, stream())); } else { vjps.push_back(negative( divide( - multiply(cotangents[0], primals[0], stream()), - square(primals[1], stream()), + multiply( + cotangents[0], conjugate(primals[0], stream()), stream()), + square(denominator_bar, stream()), stream()), stream())); } @@ -1950,30 +1952,74 @@ std::vector FFT::vjp( assert(argnums.size() == 1); auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); + + // TODO: Add it as an option to do an unnormalized or scaled fft so that this + // isn't part of the graph. + double n_elements = 1; + for (auto ax : axes) { + n_elements *= inverse_ ? 
cotangents[0].shape(ax) : primals[0].shape(ax); + } + if (real_ && inverse_) { - auto out = fft::fftn(cotangents[0], axes, stream()); - auto start = Shape(out.ndim(), 0); - auto stop = in.shape(); - out = slice(out, start, stop, stream()); - auto mask_shape = out.shape(); - mask_shape[axes_.back()] -= 2; - auto mask = full(mask_shape, 2.0f, stream()); - auto pad_shape = out.shape(); - pad_shape[axes_.back()] = 1; - auto pad = full(pad_shape, 1.0f, stream()); - mask = concatenate({pad, mask, pad}, axes_.back(), stream()); - return {multiply(mask, out, stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets doubled. + int N = in.shape(axes_.back()); + bool odd = cotangents[0].shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1 / n_elements, in.dtype()); + array two(2 / n_elements, in.dtype()); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + two, + one, + stream()); + return { + multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; } else if (real_) { Shape n; for (auto ax : axes_) { - n.push_back(in.shape()[ax]); + n.push_back(in.shape(ax)); } - return {astype( - fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets halved. + int N = cotangents[0].shape(axes_.back()); + bool odd = in.shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1, complex64); + array half(0.5, complex64); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + half, + one, + stream()); + return {multiply( + fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), + array(n_elements, in.dtype()), + stream())}; } else if (inverse_) { - return {fft::ifftn(cotangents[0], axes, stream())}; + return {multiply( + fft::fftn(cotangents[0], axes, stream()), + array(1 / n_elements, complex64), + stream())}; } else { - return {fft::fftn(cotangents[0], axes, stream())}; + return {multiply( + fft::ifftn(cotangents[0], axes, stream()), + array(n_elements, complex64), + stream())}; } } @@ -2776,7 +2822,8 @@ std::vector Multiply::vjp( const std::vector&) { std::vector vjps; for (auto arg : argnums) { - vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); + vjps.push_back(multiply( + conjugate(primals[1 - arg], stream()), cotangents[0], stream())); } return vjps; } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index f644944c7..df9d25edc 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -7,6 +7,13 @@ import mlx.core as mx import mlx_tests import numpy as np +try: + import torch + + has_torch = True +except ImportError as e: + has_torch = False + class TestFFT(mlx_tests.MLXTestCase): def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs): @@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase): x = mx.array([]) self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + @unittest.skipIf(not has_torch, "requires 
PyTorch") + def test_fft_grads(self): + real = [True, False] + inverse = [True, False] + axes = [ + (-1,), + (-2, -1), + ] + shapes = [ + (4, 4), + (2, 4), + (2, 7), + (7, 7), + ] + + mxffts = { + (True, True): mx.fft.irfftn, + (True, False): mx.fft.rfftn, + (False, True): mx.fft.ifftn, + (False, False): mx.fft.fftn, + } + tffts = { + (True, True): torch.fft.irfftn, + (True, False): torch.fft.rfftn, + (False, True): torch.fft.ifftn, + (False, False): torch.fft.fftn, + } + + for r, i, ax, sh in itertools.product(real, inverse, axes, shapes): + + def f(x): + y = mxffts[r, i](x) + return (mx.abs(y) ** 2).sum() + + def g(x): + y = tffts[r, i](x) + return (torch.abs(y) ** 2).sum() + + if r and not i: + x = mx.random.normal(sh) + else: + x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze() + fx = f(x) + gx = g(torch.tensor(x)) + self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4) + + dfdx = mx.grad(f)(x) + dgdx = torch.func.grad(g)(torch.tensor(x)) + self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index c992c3c6d..5b3454bfc 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") { } { + auto multiply_fn = + [](const std::vector& inputs) -> std::vector { + return {multiply(inputs[0], inputs[1])}; + }; + // Compute jvp auto x = array(complex64_t{2.0, 4.0}); auto y = array(3.0f); - auto x_tan = array(complex64_t{1.0, 2.0}); auto y_tan = array(2.0f); + auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ(jvp_out[0].item(), complex64_t{7.0, 14.0}); - auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second; - CHECK_EQ(out.item(), complex64_t{4.0, 8.0}); - - out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second; - CHECK_EQ(out.item(), complex64_t{3.0, 6.0}); - + // Compute vjp auto cotan = array(complex64_t{2.0, 3.0}); - out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second; - CHECK_EQ(out.dtype(), float32); - CHECK_EQ(out.item(), -8.0); + auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].dtype(), complex64); + CHECK_EQ(vjp_out[0].item(), complex64_t{6.0, 9.0}); + CHECK_EQ(vjp_out[1].dtype(), float32); + CHECK_EQ(vjp_out[1].item(), 16); + } - out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second; - CHECK_EQ(out.item(), complex64_t{6.0, 9.0}); + { + auto divide_fn = + [](const std::vector& inputs) -> std::vector { + return {divide(inputs[0], inputs[1])}; + }; + + // Compute jvp + auto x = array(complex64_t{2.0, 3.0}); + auto y = array(complex64_t{1.0, 2.0}); + auto x_tan = array(complex64_t{3.0, 4.0}); + auto y_tan = array(complex64_t{4.0, -2.0}); + auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ( + jvp_out[0].item(), doctest::Approx(complex64_t{2.6, 2.8})); + + // Compute vjp + auto cotan = array(complex64_t{2.0, -4.0}); + auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].item(), complex64_t{2.0, 0.0}); + CHECK_EQ(vjp_out[1].item(), complex64_t{-3.2, -0.4}); } } diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 0db3999c8..b9e2d1bcc 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -243,7 +243,7 @@ TEST_CASE("test fft grads") { auto fft_fn = [](array x) { return fft::fft(x); }; auto cotangent = astype(arange(10), complex64); auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; - 
CHECK(array_equal(fft::fft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item()); auto tangent = astype(arange(10), complex64); auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; @@ -252,7 +252,7 @@ TEST_CASE("test fft grads") { // Inverse auto ifft_fn = [](array x) { return fft::ifft(x); }; vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::ifft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item()); jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; CHECK(array_equal(fft::ifft(tangent), jvp_out).item()); @@ -261,7 +261,8 @@ TEST_CASE("test fft grads") { auto rfft_fn = [](array x) { return fft::rfft(x); }; cotangent = astype(arange(6), complex64); vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; - auto expected = astype(fft::fft(cotangent, 10, 0), float32); + array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64); + auto expected = fft::irfft(cotangent * mask, 10, 0) * 10; CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), float32); @@ -272,12 +273,9 @@ TEST_CASE("test fft grads") { auto irfft_fn = [](array x) { return fft::irfft(x); }; cotangent = astype(arange(10), float32); vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; - expected = fft::fft(cotangent, 10, 0); - auto o_splits = split(vjp_out, {1, 5}); - auto e_splits = split(expected, {1, 5, 6}); - CHECK_EQ(e_splits[0].item(), o_splits[0].item()); - CHECK(array_equal(2 * e_splits[1], o_splits[1]).item()); - CHECK_EQ(e_splits[2].item(), o_splits[2].item()); + mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32); + expected = fft::rfft(cotangent) * mask; + CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), complex64); jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second; From c1eb9d05d98a16e1e22f5c9b5c683d50c4188e54 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 13:01:44 -0700 Subject: [PATCH 20/37] non-symmetric eig and eigh (#2188) --- docs/src/python/linalg.rst | 2 + mlx/backend/cpu/CMakeLists.txt | 1 + mlx/backend/cpu/eig.cpp | 174 ++++++++++++++++++++++++++++++ mlx/backend/cpu/lapack.h | 1 + mlx/backend/metal/primitives.cpp | 8 +- mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_gpu/primitives.cpp | 1 + mlx/export.cpp | 1 + mlx/linalg.cpp | 26 ++++- mlx/linalg.h | 4 + mlx/primitives.cpp | 37 +++++++ mlx/primitives.h | 23 ++++ python/src/linalg.cpp | 72 ++++++++++++- python/tests/test_linalg.py | 77 +++++++++++++ 14 files changed, 423 insertions(+), 5 deletions(-) create mode 100644 mlx/backend/cpu/eig.cpp diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index b01f74117..495380c46 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -16,6 +16,8 @@ Linear Algebra cross qr svd + eigvals + eig eigvalsh eigh lu diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt index 96b3f1313..9d322c4c4 100644 --- a/mlx/backend/cpu/CMakeLists.txt +++ b/mlx/backend/cpu/CMakeLists.txt @@ -46,6 +46,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp diff --git a/mlx/backend/cpu/eig.cpp b/mlx/backend/cpu/eig.cpp new file mode 100644 index 000000000..c89003fc0 --- /dev/null +++ b/mlx/backend/cpu/eig.cpp @@ 
-0,0 +1,174 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/lapack.h" +#include "mlx/linalg.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void eig_impl( + array& a, + array& vectors, + array& values, + bool compute_eigenvectors, + Stream stream) { + using OT = std::complex; + auto a_ptr = a.data(); + auto eig_ptr = values.data(); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(values); + OT* vec_ptr = nullptr; + if (compute_eigenvectors) { + encoder.set_output_array(vectors); + vec_ptr = vectors.data(); + } + encoder.dispatch([a_ptr, + vec_ptr, + eig_ptr, + compute_eigenvectors, + N = vectors.shape(-1), + size = vectors.size()]() mutable { + // Work query + char jobr = 'N'; + char jobl = compute_eigenvectors ? 'V' : 'N'; + int n_vecs_r = 1; + int n_vecs_l = compute_eigenvectors ? N : 1; + int lwork = -1; + int info; + { + T work; + int iwork; + geev( + &jobl, + &jobr, + &N, + nullptr, + &N, + nullptr, + nullptr, + nullptr, + &n_vecs_l, + nullptr, + &n_vecs_r, + &work, + &lwork, + &info); + lwork = static_cast(work); + } + + auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)}; + auto vec_tmp_data = + array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)}; + auto eig_tmp = static_cast(eig_tmp_data.buffer.raw_ptr()); + auto vec_tmp = static_cast(vec_tmp_data.buffer.raw_ptr()); + auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; + for (size_t i = 0; i < size / (N * N); ++i) { + geev( + &jobl, + &jobr, + &N, + a_ptr, + &N, + eig_tmp, + eig_tmp + N, + vec_tmp, + &n_vecs_l, + nullptr, + &n_vecs_r, + static_cast(work_buf.buffer.raw_ptr()), + &lwork, + &info); + for (int i = 0; i < N; ++i) { + eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]}; + } + if (vec_ptr) { + for (int i = 0; i < N; ++i) { + if (eig_ptr[i].imag() != 0) { + // This vector and the next are a pair + for (int j = 0; j < N; ++j) { + vec_ptr[i * N + j] = { + vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]}; + vec_ptr[(i + 1) * N + j] = { + vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]}; + } + i += 1; + } else { + for (int j = 0; j < N; ++j) { + vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0}; + } + } + } + vec_ptr += N * N; + } + a_ptr += N * N; + eig_ptr += N; + if (info != 0) { + std::stringstream msg; + msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " + << info; + throw std::runtime_error(msg.str()); + } + } + }); + encoder.add_temporary(a); +} + +} // namespace + +void Eig::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + const auto& a = inputs[0]; + auto& values = outputs[0]; + + auto vectors = compute_eigenvectors_ + ? outputs[1] + : array(a.shape(), complex64, nullptr, {}); + + auto a_copy = array(a.shape(), a.dtype(), nullptr, {}); + copy( + a, + a_copy, + a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General, + stream()); + + values.set_data(allocator::malloc(values.nbytes())); + + if (compute_eigenvectors_) { + // Set the strides and flags so the eigenvectors + // are in the columns of the output + auto flags = vectors.flags(); + auto strides = vectors.strides(); + auto ndim = a.ndim(); + std::swap(strides[ndim - 1], strides[ndim - 2]); + + if (a.size() > 1) { + flags.row_contiguous = false; + if (ndim > 2) { + flags.col_contiguous = false; + } else { + flags.col_contiguous = true; + } + } + vectors.set_data( + allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags); + } + switch (a.dtype()) { + case float32: + eig_impl(a_copy, vectors, values, compute_eigenvectors_, stream()); + break; + default: + throw std::runtime_error("[Eig::eval_cpu] only supports float32."); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index 2911c63f8..411742d56 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -45,6 +45,7 @@ INSTANTIATE_LAPACK_TYPES(geqrf) INSTANTIATE_LAPACK_TYPES(orgqr) INSTANTIATE_LAPACK_TYPES(syevd) +INSTANTIATE_LAPACK_TYPES(geev) INSTANTIATE_LAPACK_TYPES(potrf) INSTANTIATE_LAPACK_TYPES(gesvdx) INSTANTIATE_LAPACK_TYPES(getrf) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 860e9ddd7..6e42b29c9 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -378,10 +378,16 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } +void Eig::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[Eig::eval_gpu] Metal Eig NYI."); +} + void Eigh::eval_gpu( const std::vector& inputs, std::vector& outputs) { - throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); + throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); } void LUF::eval_gpu( diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 84372b096..1a180bfe0 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -55,6 +55,7 @@ NO_CPU(DynamicSlice) NO_CPU(DynamicSliceUpdate) NO_CPU(NumberOfElements) NO_CPU(Remainder) +NO_CPU_MULTI(Eig) NO_CPU_MULTI(Eigh) NO_CPU(Equal) NO_CPU(Erf) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 6826c97f6..676a6e550 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -126,6 +126,7 @@ NO_GPU(Unflatten) NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eigh) +NO_GPU_MULTI(Eig) NO_GPU(View) namespace fast { diff --git a/mlx/export.cpp b/mlx/export.cpp index c9139e156..bd2f24ba2 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -331,6 +331,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(SVD), SERIALIZE_PRIMITIVE(Inverse), SERIALIZE_PRIMITIVE(Cholesky), + SERIALIZE_PRIMITIVE(Eig), SERIALIZE_PRIMITIVE(Eigh), SERIALIZE_PRIMITIVE(AffineQuantize), SERIALIZE_PRIMITIVE(RMSNorm), diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 53f13486a..e0f4ec2e6 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -488,7 +488,7 @@ array cross( return concatenate(outputs, axis, s); } -void validate_eigh( +void validate_eig( const array& a, const StreamOrDevice& stream, const std::string fname) { @@ -511,7 +511,7 @@ array eigvalsh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, s, "[linalg::eigvalsh]"); + validate_eig(a, s, 
"[linalg::eigvalsh]"); Shape out_shape(a.shape().begin(), a.shape().end() - 1); return array( std::move(out_shape), @@ -524,7 +524,7 @@ std::pair eigh( const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { - validate_eigh(a, s, "[linalg::eigh]"); + validate_eig(a, s, "[linalg::eigh]"); auto out = array::make_arrays( {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, {a.dtype(), a.dtype()}, @@ -533,6 +533,26 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } +array eigvals(const array& a, StreamOrDevice s /* = {} */) { + validate_eig(a, s, "[linalg::eigvals]"); + Shape out_shape(a.shape().begin(), a.shape().end() - 1); + return array( + std::move(out_shape), + complex64, + std::make_shared(to_stream(s), false), + {a}); +} + +std::pair eig(const array& a, StreamOrDevice s /* = {} */) { + validate_eig(a, s, "[linalg::eig]"); + auto out = array::make_arrays( + {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, + {complex64, complex64}, + std::make_shared(to_stream(s), true), + {a}); + return std::make_pair(out[0], out[1]); +} + void validate_lu( const array& a, const StreamOrDevice& stream, diff --git a/mlx/linalg.h b/mlx/linalg.h index 8c3a2070a..0690fba95 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -99,6 +99,10 @@ array cross( int axis = -1, StreamOrDevice s = {}); +std::pair eig(const array& a, StreamOrDevice s = {}); + +array eigvals(const array& a, StreamOrDevice s = {}); + array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); std::pair diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index e1924e66c..87b2bc924 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -875,6 +875,43 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } +std::pair, std::vector> Eig::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + bool needs_move = axes[0] >= (inputs[0].ndim() - 2); + auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + auto ax = needs_move ? 
0 : axes[0]; + + std::vector outputs; + if (compute_eigenvectors_) { + auto [values, vectors] = linalg::eig(a, stream()); + outputs = {values, vectors}; + } else { + outputs = {linalg::eigvals(a, stream())}; + } + + return {outputs, std::vector(outputs.size(), ax)}; +} + +std::vector Eig::output_shapes(const std::vector& inputs) { + auto shape = inputs[0].shape(); + shape.pop_back(); // Remove last dimension for eigenvalues + if (compute_eigenvectors_) { + return { + std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors + } else { + return {std::move(shape)}; // Only eigenvalues + } +} + +bool Eig::is_equivalent(const Primitive& other) const { + auto& e_other = static_cast(other); + return compute_eigenvectors_ == e_other.compute_eigenvectors_; +} + std::pair, std::vector> Eigh::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 2caed8477..c0fbfc84d 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2381,6 +2381,29 @@ class Cholesky : public UnaryPrimitive { bool upper_; }; +class Eig : public Primitive { + public: + explicit Eig(Stream stream, bool compute_eigenvectors) + : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_PRINT(Eig) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return compute_eigenvectors_; + } + + private: + bool compute_eigenvectors_; +}; + class Eigh : public Primitive { public: explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors) diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 3bc0e5b1b..cc8e79db6 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -236,7 +236,7 @@ void init_linalg(nb::module_& parent_module) { Returns: Union[tuple(array, ...), array]: - If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that + If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``. )pbdoc"); m.def( @@ -407,6 +407,76 @@ void init_linalg(nb::module_& parent_module) { Returns: array: The cross product of ``a`` and ``b`` along the specified axis. )pbdoc"); + m.def( + "eigvals", + &mx::linalg::eigvals, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"pbdoc( + Compute the eigenvalues of a square matrix. + + This function differs from :func:`numpy.linalg.eigvals` in that the + return type is always complex even if the eigenvalues are all real. + + This function supports arrays with at least 2 dimensions. When the + input has more than two dimensions, the eigenvalues are computed for + each matrix in the last two dimensions. + + Args: + a (array): The input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The eigenvalues (not necessarily in order). 
+ + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu) + >>> eigenvalues + array([3+0j, -1+0j], dtype=complex64) + )pbdoc"); + m.def( + "eig", + [](const mx::array& a, mx::StreamOrDevice s) { + auto result = mx::linalg::eig(a, s); + return nb::make_tuple(result.first, result.second); + }, + "a"_a, + nb::kw_only(), + "stream"_a = nb::none(), + R"pbdoc( + Compute the eigenvalues and eigenvectors of a square matrix. + + This function differs from :func:`numpy.linalg.eig` in that the + return type is always complex even if the eigenvalues are all real. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the eigenvalues and eigenvectors are + computed for each matrix in the last two dimensions. + + Args: + a (array): The input array. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + Tuple[array, array]: + A tuple containing the eigenvalues and the normalized right + eigenvectors. The column ``v[:, i]`` is the eigenvector + corresponding to the i-th eigenvalue. + + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> w, v = mx.linalg.eig(A, stream=mx.cpu) + >>> w + array([3+0j, -1+0j], dtype=complex64) + >>> v + array([[0.707107+0j, 0.707107+0j], + [-0.707107+0j, 0.707107+0j]], dtype=complex64) + )pbdoc"); + m.def( "eigvalsh", &mx::linalg::eigvalsh, diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index a9fe572af..f65da1ff7 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -312,6 +312,83 @@ class TestLinalg(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.linalg.cross(a, b) + def test_eig(self): + tols = {"atol": 1e-5, "rtol": 1e-5} + + def check_eigs_and_vecs(A_np, kwargs={}): + A = mx.array(A_np) + eig_vals, eig_vecs = mx.linalg.eig(A, stream=mx.cpu, **kwargs) + self.assertTrue( + mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols) + ) + eig_vals_only = mx.linalg.eigvals(A, stream=mx.cpu, **kwargs) + self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols)) + + # Test a simple 2x2 matrix + A_np = np.array([[1.0, 1.0], [3.0, 4.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) + + # Test complex eigenvalues + A_np = np.array([[1.0, -1.0], [1.0, 1.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) + + # Test a larger random symmetric matrix + n = 5 + np.random.seed(1) + A_np = np.random.randn(n, n).astype(np.float32) + check_eigs_and_vecs(A_np) + + # Test with batched input + A_np = np.random.randn(3, n, n).astype(np.float32) + check_eigs_and_vecs(A_np) + + # Test error cases + with self.assertRaises(ValueError): + mx.linalg.eig(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eig( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix + + with self.assertRaises(ValueError): + mx.linalg.eigvals(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eigvals( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix + + def test_lu(self): + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array(0.0), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu) + + # Test 3x3 matrix + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + P, L, 
U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + + # Test batch dimension + a = mx.broadcast_to(a, (5, 5, 3, 3)) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + L = mx.take_along_axis(L, P[..., None], axis=-2) + self.assertTrue(mx.allclose(L @ U, a)) + + # Test non-square matrix + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]]) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + + a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]]) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + def test_eigh(self): tols = {"atol": 1e-5, "rtol": 1e-5} From a2cadb8218a6b350557a1a06954b65834e6cd446 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 18:17:50 -0700 Subject: [PATCH 21/37] real and imag properties (#2189) --- docs/src/python/array.rst | 2 ++ python/src/array.cpp | 12 ++++++++++++ python/tests/test_array.py | 9 +++++++++ 3 files changed, 23 insertions(+) diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index 7e1c3339d..e68524d5a 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -19,6 +19,8 @@ Array array.ndim array.shape array.size + array.real + array.imag array.abs array.all array.any diff --git a/python/src/array.cpp b/python/src/array.cpp index 5f8dbe021..5ba0aaedc 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -319,6 +319,18 @@ void init_array(nb::module_& m) { R"pbdoc( The array's :class:`Dtype`. )pbdoc") + .def_prop_ro( + "real", + [](const mx::array& a) { return mx::real(a); }, + R"pbdoc( + The real part of a complex array. + )pbdoc") + .def_prop_ro( + "imag", + [](const mx::array& a) { return mx::imag(a); }, + R"pbdoc( + The imaginary part of a complex array. + )pbdoc") .def( "item", &to_scalar, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 792e666d6..e63da17df 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2022,6 +2022,15 @@ class TestArray(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.add(y, x) + def test_real_imag(self): + x = mx.array([1.0]) + self.assertEqual(x.real.item(), 1.0) + self.assertEqual(x.imag.item(), 0.0) + + x = mx.array([1.0 + 1.0j]) + self.assertEqual(x.imag.item(), 1.0) + self.assertEqual(x.real.item(), 1.0) + if __name__ == "__main__": unittest.main() From 602f43e3d1f75a1036a3008024afa8f27c3140d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 19:20:36 -0700 Subject: [PATCH 22/37] fix conv grad (#2187) --- mlx/primitives.cpp | 18 +++++++++++------- python/tests/test_conv.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 87b2bc924..c2bb59c05 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1116,13 +1116,11 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_lo_(padding_lo.begin(), padding_lo.end()); - Shape padding_hi_(padding_hi.begin(), padding_hi.end()); auto in_padded = pad(in, padded_axes, - padding_lo_, - padding_hi_, + Shape(padding_lo), + Shape(padding_hi), array(0, in.dtype()), "constant", s); @@ -1274,8 +1272,14 @@ std::vector Convolution::vjp( in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_lo_; - std::vector padding_hi = padding_hi_; + auto padding_hi = padding_lo_; + + for (int 
i = 0; i < padding_hi.size(); ++i) { + int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); + int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); + int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); + padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1; + } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1284,7 +1288,7 @@ std::vector Convolution::vjp( /* const array& input = */ in_trans, /* const array& weight = */ cotan_trans, /* std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_lo = */ padding_lo_, /* std::vector padding_hi = */ padding_hi, /* std::vector kernel_dilation = */ kernel_strides_, /* std::vector input_dilation = */ input_dilation_, diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 35dcf42ac..7d63e4751 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1130,6 +1130,28 @@ class TestConv(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + def test_basic_grad_shapes(self): + def loss_fn(kernel, inputs, strides, groups): + return mx.sum( + mx.conv_general( + inputs, + kernel, + stride=strides, + groups=groups, + ) + ) + + for in_shape, k_shape, strides, groups in [ + ((3, 5, 4), (6, 2, 2), (2,), 2), + ((3, 5, 4), (24, 2, 1), (2,), 4), + ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2), + ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4), + ]: + grads = mx.grad(loss_fn)( + mx.zeros(k_shape), mx.zeros(in_shape), strides, groups + ) + self.assertEqual(grads.shape, k_shape) + if __name__ == "__main__": unittest.main() From 7ff5c41e061a27265e0fe793dfc5dda3f4b55e46 Mon Sep 17 00:00:00 2001 From: Jack Wind Date: Fri, 16 May 2025 03:28:03 -0400 Subject: [PATCH 23/37] Add set_threadgroup_memory_length to CommandEncoder (#2183) --- mlx/backend/metal/device.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 26c9a0a28..660ba65e2 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -95,6 +95,10 @@ struct CommandEncoder { return enc_->setBytes(&v, sizeof(T), idx); } + void set_threadgroup_memory_length(size_t length, int idx) { + enc_->setThreadgroupMemoryLength(length, idx); + } + ConcurrentContext start_concurrent() { return ConcurrentContext(*this); } From 7d4b378952489b5c19b8d3ca5c028bf46a6ae86c Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 16 May 2025 22:44:42 +0900 Subject: [PATCH 24/37] Include cuda_bf16.h for bfloat16 overloads (#2192) * Include cuda_bf16.h for bfloat16 overloads * Add NO_GPU_MULTI(Eig) in cuda backend --- mlx/backend/cuda/kernels/fp16_math.cuh | 33 +------------------------- mlx/backend/cuda/primitives.cu | 1 + 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index 931c55ff7..edbd953de 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -2,44 +2,13 @@ #pragma once +#include #include #include #include namespace mlx::core::cu { -/////////////////////////////////////////////////////////////////////////////// -// Missing C++ operator overrides for CUDA 7. 
-/////////////////////////////////////////////////////////////////////////////// - -#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - -#define MLX_DEFINE_BF16_OP(OP) \ - __forceinline__ __device__ __nv_bfloat16 operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -#define MLX_DEFINE_BF16_CMP(OP) \ - __forceinline__ __device__ bool operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -MLX_DEFINE_BF16_OP(+) -MLX_DEFINE_BF16_OP(-) -MLX_DEFINE_BF16_OP(*) -MLX_DEFINE_BF16_OP(/) -MLX_DEFINE_BF16_CMP(>) -MLX_DEFINE_BF16_CMP(<) -MLX_DEFINE_BF16_CMP(>=) -MLX_DEFINE_BF16_CMP(<=) - -#undef MLX_DEFINE_BF16_OP -#undef MLX_DEFINE_BF16_CMP - -#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index dc6edf606..defdc746a 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -140,6 +140,7 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Inverse) NO_GPU(Cholesky) +NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast { From 48ef3e74e27a3ea620adc5fb5ae22be15613e67f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 16 May 2025 08:38:49 -0700 Subject: [PATCH 25/37] reduce vjp for all and any (#2193) --- mlx/primitives.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c2bb59c05..5f2bfdda4 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3548,7 +3548,7 @@ std::vector Reduce::vjp( } else { - throw std::runtime_error("Reduce type VJP not yet implemented."); + return {zeros_like(in, stream())}; } } From 0654543dcca1c69b4fa745eeee981fa8394dae89 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 18 May 2025 00:18:43 -0700 Subject: [PATCH 26/37] Add complex eigh (#2191) --- mlx/array.h | 4 + mlx/backend/cpu/eigh.cpp | 177 ++++++++++++++++++++++++++++-------- mlx/backend/cpu/lapack.h | 40 +++++--- mlx/linalg.cpp | 17 +++- python/tests/test_linalg.py | 7 ++ 5 files changed, 190 insertions(+), 55 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index d9fcfc58e..98eef2e33 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -224,6 +224,10 @@ class array { // Not copyable Data(const Data& d) = delete; Data& operator=(const Data& d) = delete; + Data(Data&& o) : buffer(o.buffer), d(o.d) { + o.buffer = allocator::Buffer(nullptr); + o.d = [](allocator::Buffer) {}; + } ~Data() { d(buffer); } diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp index b50f2c722..58d3634e8 100644 --- a/mlx/backend/cpu/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -12,6 +12,133 @@ namespace mlx::core { namespace { +template +struct EighWork {}; + +template +struct EighWork< + T, + typename std::enable_if::value>::type> { + using R = T; + + char jobz; + char uplo; + int N; + int lwork; + int liwork; + int info; + std::vector buffers; + + EighWork(char jobz_, char uplo_, int N_) + : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) { + T work; + int iwork; + syevd( + &jobz, + &uplo, + &N, + nullptr, + &N, + nullptr, + &work, + &lwork, + &iwork, + &liwork, + &info); + lwork = static_cast(work); + liwork = iwork; + buffers.emplace_back(allocator::malloc(sizeof(T) 
* lwork)); + buffers.emplace_back(allocator::malloc(sizeof(int) * liwork)); + } + + void run(T* vectors, T* values) { + syevd( + &jobz, + &uplo, + &N, + vectors, + &N, + values, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &liwork, + &info); + } +}; + +template <> +struct EighWork> { + using T = std::complex; + using R = float; + + char jobz; + char uplo; + int N; + int lwork; + int lrwork; + int liwork; + int info; + std::vector buffers; + + EighWork(char jobz_, char uplo_, int N_) + : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) { + T work; + R rwork; + int iwork; + heevd( + &jobz, + &uplo, + &N, + nullptr, + &N, + nullptr, + &work, + &lwork, + &rwork, + &lrwork, + &iwork, + &liwork, + &info); + lwork = static_cast(work.real()); + lrwork = static_cast(rwork); + liwork = iwork; + buffers.emplace_back(allocator::malloc(sizeof(T) * lwork)); + buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork)); + buffers.emplace_back(allocator::malloc(sizeof(int) * liwork)); + } + + void run(T* vectors, R* values) { + heevd( + &jobz, + &uplo, + &N, + vectors, + &N, + values, + static_cast(buffers[0].buffer.raw_ptr()), + &lwork, + static_cast(buffers[1].buffer.raw_ptr()), + &lrwork, + static_cast(buffers[2].buffer.raw_ptr()), + &liwork, + &info); + if (jobz == 'V') { + // We have pre-transposed the vectors but we also must conjugate them + // when they are complex. + // + // We could vectorize this but it is so fast in comparison to heevd that + // it doesn't really matter. + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + *vectors = std::conj(*vectors); + vectors++; + } + } + } + } +}; + template void eigh_impl( array& vectors, @@ -19,8 +146,10 @@ void eigh_impl( const std::string& uplo, bool compute_eigenvectors, Stream stream) { + using R = typename EighWork::R; + auto vec_ptr = vectors.data(); - auto eig_ptr = values.data(); + auto eig_ptr = values.data(); char jobz = compute_eigenvectors ? 
'V' : 'N'; auto& encoder = cpu::get_command_encoder(stream); @@ -33,49 +162,17 @@ void eigh_impl( N = vectors.shape(-1), size = vectors.size()]() mutable { // Work query - int lwork = -1; - int liwork = -1; - int info; - { - T work; - int iwork; - syevd( - &jobz, - &uplo, - &N, - nullptr, - &N, - nullptr, - &work, - &lwork, - &iwork, - &liwork, - &info); - lwork = static_cast(work); - liwork = iwork; - } + EighWork work(jobz, uplo, N); - auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; - auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)}; + // Work loop for (size_t i = 0; i < size / (N * N); ++i) { - syevd( - &jobz, - &uplo, - &N, - vec_ptr, - &N, - eig_ptr, - static_cast(work_buf.buffer.raw_ptr()), - &lwork, - static_cast(iwork_buf.buffer.raw_ptr()), - &liwork, - &info); + work.run(vec_ptr, eig_ptr); vec_ptr += N * N; eig_ptr += N; - if (info != 0) { + if (work.info != 0) { std::stringstream msg; msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code " - << info; + << work.info; throw std::runtime_error(msg.str()); } } @@ -131,6 +228,10 @@ void Eigh::eval_cpu( eigh_impl( vectors, values, uplo_, compute_eigenvectors_, stream()); break; + case complex64: + eigh_impl>( + vectors, values, uplo_, compute_eigenvectors_, stream()); + break; default: throw std::runtime_error( "[Eigh::eval_cpu] only supports float32 or float64."); diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index 411742d56..b242093ff 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -2,14 +2,14 @@ #pragma once -// Required for Visual Studio. -// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md -#ifdef _MSC_VER #include #define LAPACK_COMPLEX_CUSTOM #define lapack_complex_float std::complex #define lapack_complex_double std::complex -#endif +#define lapack_complex_float_real(z) ((z).real()) +#define lapack_complex_float_imag(z) ((z).imag()) +#define lapack_complex_double_real(z) ((z).real()) +#define lapack_complex_double_imag(z) ((z).imag()) #ifdef MLX_USE_ACCELERATE #include @@ -32,7 +32,7 @@ #endif -#define INSTANTIATE_LAPACK_TYPES(FUNC) \ +#define INSTANTIATE_LAPACK_REAL(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v) { \ @@ -42,12 +42,24 @@ } \ } -INSTANTIATE_LAPACK_TYPES(geqrf) -INSTANTIATE_LAPACK_TYPES(orgqr) -INSTANTIATE_LAPACK_TYPES(syevd) -INSTANTIATE_LAPACK_TYPES(geev) -INSTANTIATE_LAPACK_TYPES(potrf) -INSTANTIATE_LAPACK_TYPES(gesvdx) -INSTANTIATE_LAPACK_TYPES(getrf) -INSTANTIATE_LAPACK_TYPES(getri) -INSTANTIATE_LAPACK_TYPES(trtri) +INSTANTIATE_LAPACK_REAL(geqrf) +INSTANTIATE_LAPACK_REAL(orgqr) +INSTANTIATE_LAPACK_REAL(syevd) +INSTANTIATE_LAPACK_REAL(geev) +INSTANTIATE_LAPACK_REAL(potrf) +INSTANTIATE_LAPACK_REAL(gesvdx) +INSTANTIATE_LAPACK_REAL(getrf) +INSTANTIATE_LAPACK_REAL(getri) +INSTANTIATE_LAPACK_REAL(trtri) + +#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ + template \ + void FUNC(Args... 
args) { \ + if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_COMPLEX(heevd) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index e0f4ec2e6..144f9a880 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -27,6 +27,15 @@ void check_float(Dtype dtype, const std::string& prefix) { } } +void check_float_or_complex(Dtype dtype, const std::string& prefix) { + if (dtype != float32 && dtype != float64 && dtype != complex64) { + std::ostringstream msg; + msg << prefix << " Arrays must have type float32, float64 or complex64. " + << "Received array with type " << dtype << "."; + throw std::invalid_argument(msg.str()); + } +} + Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } @@ -493,7 +502,7 @@ void validate_eig( const StreamOrDevice& stream, const std::string fname) { check_cpu_stream(stream, fname); - check_float(a.dtype(), fname); + check_float_or_complex(a.dtype(), fname); if (a.ndim() < 2) { std::ostringstream msg; @@ -513,9 +522,10 @@ array eigvalsh( StreamOrDevice s /* = {} */) { validate_eig(a, s, "[linalg::eigvalsh]"); Shape out_shape(a.shape().begin(), a.shape().end() - 1); + Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype(); return array( std::move(out_shape), - a.dtype(), + eigval_type, std::make_shared(to_stream(s), UPLO, false), {a}); } @@ -525,9 +535,10 @@ std::pair eigh( std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) { validate_eig(a, s, "[linalg::eigh]"); + Dtype eigval_type = a.dtype() == complex64 ? float32 : a.dtype(); auto out = array::make_arrays( {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()}, - {a.dtype(), a.dtype()}, + {eigval_type, a.dtype()}, std::make_shared(to_stream(s), UPLO, true), {a}); return std::make_pair(out[0], out[1]); diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index f65da1ff7..f5eeda837 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -423,6 +423,13 @@ class TestLinalg(mlx_tests.MLXTestCase): A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2 check_eigs_and_vecs(A_np) + # Test with complex inputs + A_np = ( + np.random.randn(8, 8, 2).astype(np.float32).view(np.complex64).squeeze(-1) + ) + A_np = A_np + A_np.T.conj() + check_eigs_and_vecs(A_np) + # Test error cases with self.assertRaises(ValueError): mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array From 8576e6fe3606bf5b805162fd5f4a7803a9a0d349 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 18 May 2025 06:05:11 -0700 Subject: [PATCH 27/37] fix conv2d bug + faster conv 1d (#2195) * fix conv2d bug + faster conv 1d * revert sort + flaky test --- mlx/backend/metal/conv.cpp | 268 +++++++++--------- .../steel/conv/loaders/loader_channel_l.h | 8 +- .../steel/conv/loaders/loader_channel_n.h | 4 +- mlx/ops.cpp | 6 +- python/tests/test_conv.py | 21 ++ python/tests/test_vmap.py | 1 + 6 files changed, 170 insertions(+), 138 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 35ed3d44e..6b4b70d47 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. 
- #include #include #include @@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu( /*copies = */ copies); } -void conv_1D_gpu( - const Stream& s, - metal::Device& d, - const array& in, - const array& wt, - array out, - const std::vector& padding, - const std::vector& wt_strides, - const std::vector& wt_dilation, - const std::vector& in_dilation, - int groups, - bool flip) { - // Make conv params - MLXConvParams<1> conv_params{ - /* const int N = */ static_cast(in.shape(0)), - /* const int C = */ static_cast(in.shape(2)), - /* const int O = */ static_cast(wt.shape(0)), - /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, - /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, - /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, - /* const int str[NDIM] = */ {wt_strides[0]}, - /* const int pad[NDIM] = */ {padding[0]}, - /* const int kdil[NDIM] = */ {wt_dilation[0]}, - /* const int idil[NDIM] = */ {in_dilation[0]}, - /* const size_t in_strides[NDIM + 2] = */ - {in.strides()[0], in.strides()[1], in.strides()[2]}, - /* const size_t wt_strides[NDIM + 2] = */ - {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, - /* const size_t out_strides[NDIM + 2] = */ - {out.strides()[0], out.strides()[1], out.strides()[2]}, - /* const int groups = */ groups, - /* const bool flip = */ flip}; - - // Direct to explicit gemm conv - if (groups > 1) { - return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); - } else { - return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); - } -} - -void slow_conv_2D_gpu( - const Stream& s, - metal::Device& d, - const array& in, - const array& wt, - array out, - const MLXConvParams<2>& conv_params) { - int bm = 16, bn = 8; - int tm = 4, tn = 4; - - std::ostringstream kname; - kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn - << "_tm" << tm << "_tn" << tn; - - // Encode and dispatch kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder.set_compute_pipeline_state(kernel); - - size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; - - size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm); - size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn); - size_t grid_dim_z = conv_params.N; - - MTL::Size group_dims = MTL::Size(bm, bn, 1); - MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); - - compute_encoder.set_input_array(in, 0); - compute_encoder.set_input_array(wt, 1); - compute_encoder.set_output_array(out, 2); - - compute_encoder.set_bytes(conv_params, 3); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); -} - void implicit_gemm_conv_2D_gpu( const Stream& s, metal::Device& d, @@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void dispatch_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params, + std::vector& copies) { + bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; + bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; + bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; + + if (is_idil_one && conv_params.groups > 1) { + const int C_per_group = conv_params.C / conv_params.groups; + const int O_per_group = conv_params.O / conv_params.groups; + + if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && + conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && + conv_params.str[0] 
<= 2 && conv_params.str[1] <= 2 && + conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && + conv_params.wt_strides[1] == conv_params.wS[1] && + conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { + return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + if ((C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } + } + + // Direct to winograd conv + bool inp_large = + (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; + bool channels_large = (conv_params.C + conv_params.O) >= 256; + if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one && + conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && + channels_large) { + return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + } + + // Direct to implicit gemm conv + if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && + (conv_params.O <= 16 || conv_params.O % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { + return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); + } + + // Direct to explicit gemm conv + else { + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); + } +} + +void conv_1D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation, + const std::vector& in_dilation, + int groups, + bool flip, + std::vector& copies) { + bool is_idil_one = in_dilation[0] == 1; + int C = in.shape(2); + int O = wt.shape(0); + const int C_per_group = in.shape(2) / groups; + const int O_per_group = wt.shape(0) / groups; + + // Direct to implicit gemm conv + if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + MLXConvParams<2> conv_params{ + /* const int N = */ static_cast(in.shape(0)), + /* const int C = */ C, + /* const int O = */ O, + /* const int iS[NDIM] = */ {static_cast(in.shape(1)), 1}, + /* const int wS[NDIM] = */ {static_cast(wt.shape(1)), 1}, + /* const int oS[NDIM] = */ {static_cast(out.shape(1)), 1}, + /* const int str[NDIM] = */ {wt_strides[0], 1}, + /* const int pad[NDIM] = */ {padding[0], 0}, + /* const int kdil[NDIM] = */ {wt_dilation[0], 1}, + /* const int idil[NDIM] = */ {in_dilation[0], 1}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], 0, in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], 0, out.strides()[2]}, + /* const int groups = */ groups, + /* const bool flip = */ flip}; + + dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + return; + } + + // Make conv params + MLXConvParams<1> conv_params{ + /* const int N = */ static_cast(in.shape(0)), + /* const int C = */ static_cast(in.shape(2)), + /* const int O = */ static_cast(wt.shape(0)), + /* const int iS[NDIM] = */ {static_cast(in.shape(1))}, + /* const int wS[NDIM] = */ {static_cast(wt.shape(1))}, + /* const int oS[NDIM] = */ {static_cast(out.shape(1))}, + /* const int str[NDIM] = */ {wt_strides[0]}, + /* 
const int pad[NDIM] = */ {padding[0]}, + /* const int kdil[NDIM] = */ {wt_dilation[0]}, + /* const int idil[NDIM] = */ {in_dilation[0]}, + /* const size_t in_strides[NDIM + 2] = */ + {in.strides()[0], in.strides()[1], in.strides()[2]}, + /* const size_t wt_strides[NDIM + 2] = */ + {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, + /* const size_t out_strides[NDIM + 2] = */ + {out.strides()[0], out.strides()[1], out.strides()[2]}, + /* const int groups = */ groups, + /* const bool flip = */ flip}; + + // Direct to explicit gemm conv + if (groups > 1) { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); + } +} + void conv_2D_gpu( const Stream& s, metal::Device& d, @@ -808,57 +865,7 @@ void conv_2D_gpu( /* const int groups = */ groups, /* const bool flip = */ flip, }; - - bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; - bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; - bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - - if (is_idil_one && groups > 1) { - const int C_per_group = conv_params.C / groups; - const int O_per_group = conv_params.O / groups; - - if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && - conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && - conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && - conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && - conv_params.wt_strides[1] == conv_params.wS[1] && - conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { - return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - if ((C_per_group <= 4 || C_per_group % 16 == 0) && - (O_per_group <= 16 || O_per_group % 16 == 0)) { - return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } else { - return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); - } - } - - // Direct to winograd conv - bool inp_large = - (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; - bool channels_large = (conv_params.C + conv_params.O) >= 256; - if (!flip && is_stride_one && is_kdil_one && is_idil_one && - conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && - conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && - channels_large) { - return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); - } - - // Direct to implicit gemm conv - if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && - (conv_params.O <= 16 || conv_params.O % 16 == 0)) { - return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { - return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params); - } - - // Direct to explicit gemm conv - else { - return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); - } + dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); } void conv_3D_gpu( @@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { kernel_dilation_, input_dilation_, groups_, - flip_); + flip_, + copies); } // Throw error else { diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h index dad496e81..d52642b73 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h @@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader { const constant 
MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; @@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} @@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader { /* Iteration helper */ METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h index 56027916e..b0b98d21a 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { STEEL_PRAGMA_UNROLL @@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0c18cccfe..a72c2bc85 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3584,21 +3584,21 @@ Shape conv_out_shape( if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for " + msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (kernel_dilation.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for " + msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } if (input_dilation.size() != spatial_dims) { std::ostringstream msg; - msg << "[conv] Invalid input dilation " << input_dilation << "for " + msg << "[conv] Invalid input dilation " << input_dilation << " for " << spatial_dims << "D convolution."; throw std::invalid_argument(msg.str()); } diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 7d63e4751..9fe11286d 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase): ) self.assertEqual(grads.shape, k_shape) + def test_1d_conv_with_2d(self): + x = mx.random.uniform(shape=(2, 10, 16)) + y = mx.random.normal(shape=(16, 3, 16)) + + out = mx.conv1d(x, y, padding=1) + out_2d = mx.conv2d( + mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0) + ) + + self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + + x = mx.random.uniform(shape=(2, 10, 4)) + y = mx.random.normal(shape=(4, 3, 4)) + + out = mx.conv1d(x, y, padding=1) + out_2d = mx.conv2d( + mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0) + ) + + self.assertTrue(mx.allclose(out, out_2d.squeeze(2))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py 
b/python/tests/test_vmap.py index e571678d3..ddfceb0a1 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertEqual(fy.shape, (4, 5, 6, 7)) def test_leaks(self): + mx.synchronize() if mx.metal.is_available(): mem_pre = mx.get_active_memory() else: From 237f9e58a892798aa9a4bfd6e83a864fa3358904 Mon Sep 17 00:00:00 2001 From: Cheng Date: Mon, 19 May 2025 22:10:44 +0900 Subject: [PATCH 28/37] Fix BEFORE keyword in target_include_directories (#2204) --- mlx/backend/cuda/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 54d651005..f9695f66a 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -36,7 +36,7 @@ FetchContent_Declare( cccl URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") FetchContent_MakeAvailable(cccl) -target_include_directories(mlx PRIVATE BEFORE "${cccl_SOURCE_DIR}/include") +target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include") # Use fixed version of NVTX. FetchContent_Declare( From 0359bf02c99f4beff9a596431865d8211b654714 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 19 May 2025 11:23:38 -0700 Subject: [PATCH 29/37] Nearest upsample (#2202) --- python/mlx/nn/layers/upsample.py | 11 ++++++++++- python/tests/test_upsample.py | 11 ++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index 1f2ffd3da..e6bd282af 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims): def _nearest_indices(N, scale, dim, ndims): - return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32) + M = int(scale * N) + indices = mx.arange(M, dtype=mx.float32) + if M > N: + indices = (indices + 0.5) * (N / M) - 0.5 + indices = indices.round() + else: + indices = indices * (N / M) + shape = [1] * ndims + shape[dim] = -1 + return indices.astype(mx.uint32).reshape(shape) def _linear_indices(N, scale, align_corners, dim, ndims): diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 402c7b0ca..86f41b6e8 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase): align_corners=align_corner, )(in_mx) mode_pt = { + "nearest": "nearest", "linear": "bilinear", "cubic": "bicubic", }[mode] @@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase): in_pt, scale_factor=scale_factor, mode=mode_pt, - align_corners=align_corner, + align_corners=align_corner if mode != "nearest" else None, ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) self.assertEqual(out_pt.shape, out_mx.shape) @@ -76,14 +77,14 @@ class TestUpsample(mlx_tests.MLXTestCase): ((4, 4), (0.5, 0.5)), ((7, 7), (2.0, 2.0)), ((10, 10), (0.2, 0.2)), + ((10, 10), (0.3, 0.3)), ((11, 21), (3.0, 3.0)), ((11, 21), (3.0, 2.0)), ): - # only test linear and cubic interpolation - # there will be numerical difference in nearest - # due to different indices selection. 
- for mode in ("cubic", "linear"): + for mode in ("cubic", "linear", "nearest"): for align_corner in (False, True): + if mode == "nearest" and align_corner: + continue run_upsample( N, C, From eebe73001affcb424171e9d49657e508f70a9201 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 May 2025 13:10:44 -0700 Subject: [PATCH 30/37] fix large arg reduce (#2206) --- mlx/backend/metal/kernels/arg_reduce.metal | 20 +++++++++++--------- mlx/backend/metal/primitives.cpp | 4 ++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index 7f1075ad9..4a83d8e57 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -80,9 +80,10 @@ template const constant size_t& ndim [[buffer(5)]], const constant int64_t& axis_stride [[buffer(6)]], const constant size_t& axis_size [[buffer(7)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 gsize [[threads_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_size [[threads_per_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { @@ -104,17 +105,18 @@ template // Compute the input/output index. There is one beginning and one output for // the whole threadgroup. - auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z; + auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim); + auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim); IndexValPair best{0, Op::init}; threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { // Read the current value - uint32_t current_index = r * lsize * N_READS + lid * N_READS; + uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS; uint32_t offset = current_index; const device T* current_in = in + in_idx + current_index * axis_stride; T vals[N_READS]; @@ -144,7 +146,7 @@ template } // Read the appropriate value from local data and perform one simd reduction - uint simd_groups = ceildiv(lsize, simd_size); + uint simd_groups = ceildiv(lsize.x, simd_size); if (simd_lane_id < simd_groups) { best = local_data[simd_lane_id]; } @@ -154,7 +156,7 @@ template } // Finally write the output - if (lid == 0) { + if (lid.x == 0) { out[out_idx] = best.index; } } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6e42b29c9..705c3ea76 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -182,8 +182,8 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { (thread_group_size + simd_size - 1) / simd_size * simd_size; assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); - size_t n_threads = out.size() * thread_group_size; - MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + auto gd = get_2d_grid_dims(out.shape(), out.strides()); + MTL::Size grid_dims = MTL::Size(thread_group_size, gd.width, gd.height); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); 
compute_encoder.set_input_array(in, 0); From ab8883dd55745a12d385b91ef26ea794d2a45bdb Mon Sep 17 00:00:00 2001 From: Clement Liaw Date: Tue, 20 May 2025 07:39:11 -0700 Subject: [PATCH 31/37] include mlx::core::version() symbols in the mlx static library (#2207) --- mlx/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 4ba9b33dd..ce921b276 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -21,7 +21,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) # Define MLX_VERSION only in the version.cpp file. -add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) +add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") target_link_libraries(mlx PRIVATE $) From 4cbe6052147420ee09ae855d57a81cef3467af15 Mon Sep 17 00:00:00 2001 From: Jack Wind Date: Tue, 20 May 2025 13:22:26 -0400 Subject: [PATCH 32/37] Feat: Allow per-target Metal debug flags (#2201) * feat: allow per-target Metal debug flags * formatting fix --- cmake/extension.cmake | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cmake/extension.cmake b/cmake/extension.cmake index 3270b0056..13db804a1 100644 --- a/cmake/extension.cmake +++ b/cmake/extension.cmake @@ -11,13 +11,14 @@ include(CMakeParseArguments) # Args: TARGET: Custom target to be added for the metal library TITLE: Name of # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency -# files (like headers) +# files (like headers) DEBUG: Boolean, if true, enables debug compile options +# for this specific library. If not provided, uses global MLX_METAL_DEBUG. # # clang format on macro(mlx_build_metallib) # Parse args - set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) + set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -26,6 +27,10 @@ macro(mlx_build_metallib) # Collect compile options set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) + if(MLX_METAL_DEBUG OR MTLLIB_DEBUG) + set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only + -frecord-sources) + endif() # Prepare metallib build command add_custom_command( From 35c87741cf2450c96c6e52afead61eec81c45e2a Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 11:42:48 +0900 Subject: [PATCH 33/37] Build for compute capability 70 instead of 75 (#2209) --- mlx/backend/cuda/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index f9695f66a..d62a69846 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -25,7 +25,7 @@ target_compile_options(mlx # Compute capability 7 is required for synchronization between CPU/GPU with # managed memory. TODO: Add more architectures for potential performance gain. 
set(MLX_CUDA_ARCHITECTURES - "75;80" + "70;80" CACHE STRING "CUDA architectures") message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES From 7774b87cbda51c5e34f1471d8e76767350368a05 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 23:25:03 +0900 Subject: [PATCH 34/37] Remove redundant simd_sum in logsumexp (#2210) --- mlx/backend/metal/kernels/logsumexp.h | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index b6898e31e..93744e15d 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -134,10 +134,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } From 79071bfba4f012517859bcfdd9032123d16cc6b6 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 23:25:16 +0900 Subject: [PATCH 35/37] Fix out-of-bounds default value in logsumexp/softmax (#2213) --- mlx/backend/metal/kernels/logsumexp.h | 4 ++-- mlx/backend/metal/kernels/softmax.h | 4 ++-- tests/ops_tests.cpp | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index 93744e15d..c746050b3 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index b36b73bd8..6ea4ac732 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? 
AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 5e2bae5a0..8833424a6 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1036,6 +1036,9 @@ TEST_CASE("test reduction ops") { x = array({-inf, -inf}); CHECK_EQ(logsumexp(x).item(), -inf); + x = repeat(array(-inf), 5000); + CHECK_EQ(logsumexp(x).item(), -inf); + x = array({0.0f, -inf}); CHECK_EQ(logsumexp(x).item(), 0.0f); From 55b4062dd8c71d4499a430012b49f676da91818a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 21 May 2025 17:13:04 -0700 Subject: [PATCH 36/37] copyright in docs (#2214) --- docs/src/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/conf.py b/docs/src/conf.py index abc68c3a2..d9dd32ad1 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -10,7 +10,7 @@ import mlx.core as mx # -- Project information ----------------------------------------------------- project = "MLX" -copyright = "2023, MLX Contributors" +copyright = "2023, Apple" author = "MLX Contributors" version = ".".join(mx.__version__.split(".")[:3]) release = version From 54a71f270a671d2b31c493c98f27a49fe217a6f1 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 23 May 2025 22:14:58 +0900 Subject: [PATCH 37/37] Remove unused defines (#2217) --- CMakeLists.txt | 3 +++ mlx/backend/cuda/CMakeLists.txt | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ab8aea443..4bf8d2d3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,6 +231,9 @@ target_include_directories( mlx PUBLIC $ $) +# Do not add mlx_EXPORTS define for shared library. +set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "") + FetchContent_Declare( fmt GIT_REPOSITORY https://github.com/fmtlib/fmt.git diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index d62a69846..7ebe68324 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -16,8 +16,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) -target_compile_definitions(mlx PUBLIC MLX_USE_CUDA) - # Enable defining device lambda functions. target_compile_options(mlx PRIVATE "$<$:--extended-lambda>")
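
A quick way to exercise the all -inf reduction path that [PATCH 35/37] fixes (the padding value for out-of-bounds reads moves from Limits::finite_min to Limits::min, so a row consisting entirely of -inf now reduces to -inf, as the new ops_tests case asserts) is the small Python sketch below. It is only an illustration: it assumes an MLX build that already contains this patch, and the 5000-element length simply mirrors the added C++ test, presumably to force the looped, padded read path in the kernel.

    import math
    import mlx.core as mx

    # An all -inf row long enough to hit the padded out-of-bounds reads.
    x = mx.full((5000,), -math.inf)
    assert mx.logsumexp(x).item() == -math.inf

    # The existing mixed case from the test suite is unchanged: the finite
    # entry dominates and the -inf entry contributes nothing.
    assert mx.logsumexp(mx.array([0.0, -math.inf])).item() == 0.0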