From d2a94f9e6adf202b5faf4fd7596f5a77ff0b77a6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 12 Mar 2025 13:08:19 -0700 Subject: [PATCH 01/14] Only compile warnings as errors for circle (#1957) --- .circleci/config.yml | 10 +++++++--- CMakeLists.txt | 1 - 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index fc64b004b..9c8cb31a3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -92,10 +92,12 @@ jobs: - run: name: Install Python package command: | - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ + CMAKE_ARGS="-DMLX_BUILD_METAL=OFF + CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python3 setup.py build_ext --inplace - CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ + CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \ + CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ python3 setup.py develop - run: @@ -146,7 +148,9 @@ jobs: name: Install Python package command: | source env/bin/activate - DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v + DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ + CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \ + pip install -e . -v - run: name: Generate package stubs command: | diff --git a/CMakeLists.txt b/CMakeLists.txt index 50d12cdfc..0601ea292 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,7 +77,6 @@ include(FetchContent) cmake_policy(SET CMP0135 NEW) add_library(mlx) -set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON) if(MLX_BUILD_METAL) set(METAL_LIB "-framework Metal") From 2770a1024082eb10cce6bc0ac589ad089e7be611 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 13 Mar 2025 19:13:09 -0700 Subject: [PATCH 02/14] fix grad with inplace updates (#1961) --- python/src/transforms.cpp | 12 +++++++++++- python/tests/test_autograd.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 8d78a1bde..4a5e2e6ac 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -176,7 +176,17 @@ auto py_value_and_grad( // Call the python function py_value_out = fun(*tree[0], **tree[1]); - tree_fill(tree, arrays); + // Replace the tracers with the originals. 
Don't overwrite + // locations which were written to during the call to fun + int index = 0; + tree_visit_update(tree, [&](nb::handle node) { + auto replace_arr = nb::cast(node); + if (replace_arr.id() == a[index].id()) { + return nb::cast(arrays[index++]); + } else { + return nb::cast(replace_arr); + } + }); // Validate the return value of the python function if (!nb::isinstance(py_value_out)) { diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index ffccd85fc..350b09837 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -746,6 +746,7 @@ class TestAutograd(mlx_tests.MLXTestCase): mx.checkpoint, ]: if mx.metal.is_available(): + mx.synchronize(mx.default_stream(mx.default_device())) mem_pre = mx.metal.get_active_memory() else: mem_pre = 0 @@ -790,6 +791,20 @@ class TestAutograd(mlx_tests.MLXTestCase): mx.grad(fun)(arrs) self.assertEqual(init_id, id(arrs[0])) + def test_grad_with_inplace_update(self): + def loss_fn(model): + model[1] = mx.array(2.0) + return model[0] + + model = [ + mx.array(0.0), + mx.array(1.0), + ] + + grad_fn = mx.grad(loss_fn) + grad_fn(model) + self.assertEqual(model[1].item(), 2.0) + if __name__ == "__main__": unittest.main() From c6ea2ba329ddf204c52500ba68d3f79bbb0e7283 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 16 Mar 2025 07:13:24 -0700 Subject: [PATCH 03/14] Use same accumulation precision in gemv as gemm (#1962) * use same accumulation precision in gemv as gemm * faster * fix compile --- mlx/backend/metal/kernels/gemv.metal | 68 ++++++++++++++----------- mlx/backend/metal/kernels/gemv_masked.h | 52 ++++++++++--------- python/tests/test_blas.py | 12 +++++ 3 files changed, 79 insertions(+), 53 deletions(-) diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index 28cadd50a..f21c35d97 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -23,7 +23,8 @@ template < const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ + const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ + typename AccT = float> struct GEMVKernel { MLX_MTL_CONST int threadsM = BM * SM; MLX_MTL_CONST int threadsN = BN * SN; @@ -60,28 +61,32 @@ struct GEMVKernel { MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + template static METAL_FUNC void - load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { - dst[tn] = src[src_offset + tn]; + dst[tn] = static_cast(src[src_offset + tn]); } } + template static METAL_FUNC void load_safe( const device T* src, - thread T dst[TN], + thread U dst[TN], const int src_offset = 0, const int src_size = TN) { if (src_offset + TN <= src_size) { MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { - dst[tn] = src[src_offset + tn]; + dst[tn] = static_cast(src[src_offset + tn]); } } else { // Edgecase MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { - dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0; + dst[tn] = src_offset + tn < src_size + ? 
static_cast(src[src_offset + tn]) + : U(0); } } } @@ -97,7 +102,7 @@ struct GEMVKernel { const constant float& alpha [[buffer(7)]], const constant float& beta [[buffer(8)]], const constant int& bias_stride [[buffer(14)]], - threadgroup T* tgp_memory [[threadgroup(0)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -106,9 +111,9 @@ struct GEMVKernel { (void)lid; // Thread local accumulation results - thread T result[TM] = {0}; + thread AccT result[TM] = {0}; thread T inter[TN]; - thread T v_coeff[TN]; + thread AccT v_coeff[TN]; const int thrM = SN != 32 ? simd_lid / SN : 0; const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); @@ -142,7 +147,7 @@ struct GEMVKernel { // Loop over in_vec in blocks of blockN for (int i = 0; i < n_iter; ++i) { - load_unsafe(in_vec, v_coeff, bn); + load_unsafe(in_vec, v_coeff, bn); // Per thread work loop int mat_offset = 0; @@ -164,7 +169,7 @@ struct GEMVKernel { } if (leftover > 0) { - load_safe(in_vec, v_coeff, bn, in_size); + load_safe(in_vec, v_coeff, bn, in_size); // Per thread work loop MLX_MTL_PRAGMA_UNROLL @@ -191,7 +196,7 @@ struct GEMVKernel { // Threadgroup accumulation results if (needs_tgp_reduction) { - threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; if (thrN == 0) { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { @@ -217,10 +222,11 @@ struct GEMVKernel { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { if (kDoAxpby) { - out_vec[out_row + tm] = static_cast(alpha) * result[tm] + + out_vec[out_row + tm] = + static_cast(alpha) * static_cast(result[tm]) + static_cast(beta) * bias[(out_row + tm) * bias_stride]; } else { - out_vec[out_row + tm] = result[tm]; + out_vec[out_row + tm] = static_cast(result[tm]); } } } @@ -239,7 +245,8 @@ template < const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ + const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ + typename AccT = float> struct GEMVTKernel { MLX_MTL_CONST int threadsM = BM * SM; MLX_MTL_CONST int threadsN = BN * SN; @@ -282,7 +289,7 @@ struct GEMVTKernel { const constant float& alpha [[buffer(7)]], const constant float& beta [[buffer(8)]], const constant int& bias_stride [[buffer(14)]], - threadgroup T* tgp_memory [[threadgroup(0)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -291,10 +298,9 @@ struct GEMVTKernel { (void)lid; // Thread local accumulation results - T result[TN] = {0}; + AccT result[TN] = {0}; T inter[TN]; - T v_coeff[TM]; - + AccT v_coeff[TM]; const int thrM = SN != 32 ? simd_lid / SN : 0; const int thrN = SN != 32 ? 
simd_lid % SN : int(simd_lid); @@ -330,16 +336,17 @@ struct GEMVTKernel { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = in_vec[bm + tm]; + v_coeff[tm] = static_cast(in_vec[bm + tm]); } MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { + auto vc = float(v_coeff[tm]); for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; + result[tn] += vc * inter[tn]; } } @@ -348,7 +355,7 @@ struct GEMVTKernel { if (leftover > 0) { for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = in_vec[bm + tm]; + v_coeff[tm] = static_cast(in_vec[bm + tm]); MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { @@ -374,7 +381,7 @@ struct GEMVTKernel { // Threadgroup accumulation results if (needs_tgp_reduction) { - threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; if (thrM == 0) { MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { @@ -400,10 +407,11 @@ struct GEMVTKernel { MLX_MTL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { if (kDoAxpby) { - out_vec[out_col + j] = static_cast(alpha) * result[j] + + out_vec[out_col + j] = + static_cast(alpha) * static_cast(result[j]) + static_cast(beta) * bias[(out_col + j) * bias_stride]; } else { - out_vec[out_col + j] = result[j]; + out_vec[out_col + j] = static_cast(result[j]); } } } @@ -445,7 +453,7 @@ template < uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { using gemv_kernel = GEMVKernel; - threadgroup T tgp_memory + threadgroup float tgp_memory [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; // Update batch offsets @@ -553,7 +561,7 @@ template < uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { using gemv_kernel = GEMVKernel; - threadgroup T tgp_memory + threadgroup float tgp_memory [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; uint32_t indx_vec; @@ -660,7 +668,7 @@ template < uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { using gemv_kernel = GEMVTKernel; - threadgroup T tgp_memory + threadgroup float tgp_memory [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; // Update batch offsets @@ -761,8 +769,8 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup T tgp_memory + using gemv_kernel = GEMVTKernel; + threadgroup float tgp_memory [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; uint32_t indx_vec; diff --git a/mlx/backend/metal/kernels/gemv_masked.h b/mlx/backend/metal/kernels/gemv_masked.h index 48acf1d61..75bc7354c 100644 --- a/mlx/backend/metal/kernels/gemv_masked.h +++ b/mlx/backend/metal/kernels/gemv_masked.h @@ -44,7 +44,8 @@ template < const int SM, /* Simdgroup rows (in threads) */ const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ + const int TN, /* Thread cols (in elements) */ + typename AccT = float> struct GEMVKernel { MLX_MTL_CONST int threadsM = BM * SM; MLX_MTL_CONST int threadsN = BN * SN; @@ -91,28 +92,32 @@ struct GEMVKernel { MLX_MTL_CONST short tgp_mem_size = BN > 1 ? 
BN*(blockM + TM) : 0; MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + template static METAL_FUNC void - load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { - dst[tn] = src[src_offset + tn]; + dst[tn] = static_cast(src[src_offset + tn]); } } + template static METAL_FUNC void load_safe( const device T* src, - thread T dst[TN], + thread U dst[TN], const int src_offset = 0, const int src_size = TN) { if (src_offset + TN <= src_size) { MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { - dst[tn] = src[src_offset + tn]; + dst[tn] = static_cast(src[src_offset + tn]); } } else { // Edgecase MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { - dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0; + dst[tn] = src_offset + tn < src_size + ? static_cast(src[src_offset + tn]) + : U(0); } } } @@ -128,7 +133,7 @@ struct GEMVKernel { const device op_mask_t* mat_mask [[buffer(21)]], const device op_mask_t* vec_mask [[buffer(22)]], const constant int* mask_strides [[buffer(23)]], - threadgroup T* tgp_memory [[threadgroup(0)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -137,9 +142,9 @@ struct GEMVKernel { (void)lid; // Thread local accumulation results - thread T result[TM] = {0}; + thread AccT result[TM] = {0}; thread T inter[TN]; - thread T v_coeff[TN]; + thread AccT v_coeff[TN]; const int thrM = SN != 32 ? simd_lid / SN : 0; const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); @@ -225,7 +230,7 @@ struct GEMVKernel { T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); } - load_unsafe(in_vec, v_coeff, bn); + load_unsafe(in_vec, v_coeff, bn); // Apply scale if (has_mul_operand_mask) { @@ -267,7 +272,7 @@ struct GEMVKernel { T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); } - load_safe(in_vec, v_coeff, bn, in_size); + load_safe(in_vec, v_coeff, bn, in_size); // Apply scale if (has_mul_operand_mask) { @@ -310,7 +315,7 @@ struct GEMVKernel { // Threadgroup accumulation results if (needs_tgp_reduction) { - threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; if (thrN == 0) { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { @@ -335,7 +340,7 @@ struct GEMVKernel { if (simdN == 0 && thrN == 0) { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { - out_vec[out_row + tm] = result[tm]; + out_vec[out_row + tm] = static_cast(result[tm]); } } } @@ -354,7 +359,8 @@ template < const int SM, /* Simdgroup rows (in threads) */ const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ + const int TN, /* Thread cols (in elements) */ + typename AccT = float> struct GEMVTKernel { MLX_MTL_CONST int threadsM = BM * SM; MLX_MTL_CONST int threadsN = BN * SN; @@ -405,7 +411,7 @@ struct GEMVTKernel { const device op_mask_t* mat_mask [[buffer(21)]], const device op_mask_t* vec_mask [[buffer(22)]], const constant int* mask_strides [[buffer(23)]], - threadgroup T* tgp_memory [[threadgroup(0)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -414,9 
+420,9 @@ struct GEMVTKernel { (void)lid; // Thread local accumulation results - T result[TN] = {0}; + AccT result[TN] = {0}; T inter[TN]; - T v_coeff[TM]; + AccT v_coeff[TM]; const int thrM = SN != 32 ? simd_lid / SN : 0; const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); @@ -511,7 +517,7 @@ struct GEMVTKernel { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = in_vec[bm + tm]; + v_coeff[tm] = static_cast(in_vec[bm + tm]); } // Apply scale @@ -549,7 +555,7 @@ struct GEMVTKernel { } for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = in_vec[bm + tm]; + v_coeff[tm] = static_cast(in_vec[bm + tm]); if (has_mul_operand_mask) { v_coeff[tm] *= block_scale; @@ -587,7 +593,7 @@ struct GEMVTKernel { // Threadgroup accumulation results if (needs_tgp_reduction) { - threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; if (thrM == 0) { MLX_MTL_PRAGMA_UNROLL for (int tn = 0; tn < TN; tn++) { @@ -612,7 +618,7 @@ struct GEMVTKernel { if (cm == 0 && out_col < out_vec_size) { MLX_MTL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { - out_vec[out_col + j] = result[j]; + out_vec[out_col + j] = static_cast(result[j]); } } } @@ -655,7 +661,7 @@ template < uint simd_lid [[thread_index_in_simdgroup]]) { using gemv_kernel = GEMVKernel; - threadgroup T tgp_memory + threadgroup float tgp_memory [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; constexpr bool has_operand_mask = !metal::is_same_v; @@ -755,7 +761,7 @@ template < uint simd_lid [[thread_index_in_simdgroup]]) { using gemv_kernel = GEMVTKernel; - threadgroup T tgp_memory + threadgroup float tgp_memory [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; constexpr bool has_operand_mask = !metal::is_same_v; diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index fdeaea98a..985ca5ffb 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1146,6 +1146,18 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertEqual(r.shape, t.shape) self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def test_gemv_gemm_same_precision(self): + mx.random.seed(0) + N = 256 + if mx.metal.is_available(): + t = mx.bfloat16 + a = mx.random.normal([1, N]).astype(t) + b = mx.concatenate([a, a], axis=0).astype(t) + c = mx.random.normal([N, 64]).astype(t) + out_gemv = a @ c + out_gemm = (b @ c)[0] + self.assertTrue(mx.allclose(out_gemv, out_gemm)) + if __name__ == "__main__": unittest.main() From 45ad06aac898a64abee55569b7594b5a3d2f228d Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Tue, 18 Mar 2025 22:12:24 +0800 Subject: [PATCH 04/14] Fix typo; Fix lint warning when reuse the same name (#1968) * Fix typo; Fix lint warning when reuse the same name * Add missing period --- python/mlx/optimizers/optimizers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index 36068403d..1c45865a1 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -50,19 +50,19 @@ class Optimizer: dict_keys(['step', 'learning_rate', 'weight', 'bias']) """ - # Iniatilize the optimizer state to match the parameter state + # Initialize the optimizer state to match the parameter state def update_state(params, state): if isinstance(params, (list, tuple)): state = list(state) for i in range(len(state)): state[i] = update_state(params[i], state[i]) if len(state) != len(params): 
- state.extend(tree_map(lambda x: {}, params[len(state) :])) + state.extend(tree_map(lambda _: {}, params[len(state) :])) return type(params)(state) elif isinstance(params, dict): for k, v in params.items(): if k not in state: - state[k] = tree_map(lambda x: {}, v) + state[k] = tree_map(lambda _: {}, v) else: state[k] = update_state(v, state[k]) return state @@ -79,6 +79,7 @@ class Optimizer: Args: parameter (mx.array): A single parameter that will be optimized. + state (dict): The optimizer's state. """ raise NotImplementedError() @@ -148,10 +149,10 @@ class Optimizer: """ if isinstance(param, Callable): self._schedulers[name] = param - param = param(self.step) + parameter = param(self.step) else: - param = mx.array(param) - self.state[name] = param + parameter = mx.array(param) + self.state[name] = parameter class MultiOptimizer(Optimizer): From 0a9777aa5cf67b7b1ed5dbe7d17eedd2228c7171 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 18 Mar 2025 23:12:40 +0900 Subject: [PATCH 05/14] Do not define MLX_VERSION globally (#1966) --- CMakeLists.txt | 3 +-- mlx/CMakeLists.txt | 6 +++++- mlx/version.cpp | 7 +------ 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0601ea292..672b9810c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ if(NOT MLX_VERSION) string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}") set(_patch ${CMAKE_MATCH_1}) set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}") + set(MLX_VERSION ${MLX_PROJECT_VERSION}) else() string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION ${MLX_VERSION}) @@ -41,8 +42,6 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF) option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) -add_compile_definitions("MLX_VERSION=${MLX_VERSION}") - # --------------------- Processor tests ------------------------- message( STATUS diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 8b12cc14f..76fe389d4 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -17,9 +17,13 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) +# Define MLX_VERSION only in the version.cpp file. +add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) +target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") +target_link_libraries(mlx PRIVATE $) + if(MSVC) # Disable some MSVC warnings to speed up compilation. 
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804) diff --git a/mlx/version.cpp b/mlx/version.cpp index 92d8b3ea6..2a8e3bc9b 100644 --- a/mlx/version.cpp +++ b/mlx/version.cpp @@ -2,15 +2,10 @@ #include -#include "mlx/version.h" - -#define STRINGIFY(x) #x -#define TOSTRING(x) STRINGIFY(x) - namespace mlx::core { std::string version() { - return TOSTRING(MLX_VERSION); + return MLX_VERSION; } } // namespace mlx::core From 377915075038a782f7b06652d051c6d2732e1a31 Mon Sep 17 00:00:00 2001 From: Chunyang Wen Date: Thu, 20 Mar 2025 02:24:04 +0800 Subject: [PATCH 06/14] refactor: all use schedule (#1973) --- python/mlx/optimizers/schedulers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py index a8c0354f3..67e4e29cd 100644 --- a/python/mlx/optimizers/schedulers.py +++ b/python/mlx/optimizers/schedulers.py @@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable: array(0.0999961, dtype=float32) """ - def scheduler(step): + def schedule(step): s = mx.minimum(step, decay_steps) decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) return end + decay * (init - end) - return scheduler + return schedule def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable: @@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable that indicates when to transition between schedules. Example: - >>> warmup = optim.linear_schedule(0, 1e-1, steps=10) + >>> linear = optim.linear_schedule(0, 1e-1, steps=10) >>> cosine = optim.cosine_decay(1e-1, 200) - >>> lr_schedule = optim.join_schedules([warmup, cosine], [10]) + >>> lr_schedule = optim.join_schedules([linear, cosine], [10]) >>> optimizer = optim.Adam(learning_rate=lr_schedule) >>> optimizer.learning_rate array(0.0, dtype=float32) @@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable: Example: - >>> warmup = optim.linear_schedule(0, 1e-1, 100) - >>> optimizer = optim.Adam(learning_rate=warmup) + >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100) + >>> optimizer = optim.Adam(learning_rate=lr_schedule) >>> optimizer.learning_rate array(0.0, dtype=float32) >>> for _ in range(101): optimizer.update({}, {}) @@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable: if steps < 1: raise ValueError(f"steps must be greater than 0, but got {steps}.") - def step_fn(step): + def schedule(step): step = mx.minimum(step, steps) return step * ((end - init) / steps) + init - return step_fn + return schedule From f90206ad74f00046399727ef11f532c85b647af1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 19 Mar 2025 16:24:10 -0700 Subject: [PATCH 07/14] Guard nullptr dereference (#1972) * guard nullptr dereference * comment --- mlx/backend/metal/allocator.cpp | 3 +++ mlx/backend/no_metal/allocator.cpp | 3 +++ 2 files changed, 6 insertions(+) diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index f2c95be20..8f5b28226 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -20,6 +20,9 @@ Allocator& allocator() { } void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } return static_cast(ptr_)->contents(); } diff --git a/mlx/backend/no_metal/allocator.cpp b/mlx/backend/no_metal/allocator.cpp index 27e2ea06f..0429ea53a 100644 --- a/mlx/backend/no_metal/allocator.cpp +++ b/mlx/backend/no_metal/allocator.cpp @@ -10,6 +10,9 @@ 
Allocator& allocator() { } void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } return static_cast(ptr_) + 1; } From 95e335db7b92fb3302e0456842975619dc8ccd9d Mon Sep 17 00:00:00 2001 From: jiyzhang Date: Thu, 20 Mar 2025 11:19:02 +0800 Subject: [PATCH 08/14] Update smooth_l1_loss in losses.py (#1974) According the definition of smooth_l1_loss, the line diff = predictions - targets Should be updated to diff = mx.abs(predictions - targets) After the modification, the result is consistent with PyTorch smooth_l1_loss --- python/mlx/nn/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index bccf45b16..58232363a 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -373,7 +373,7 @@ def smooth_l1_loss( f"targets shape {targets.shape}." ) - diff = predictions - targets + diff = mx.abs(predictions - targets) loss = mx.where( diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta ) From 3c164fca8cd0ee060a2e1ce916d5f5149959e18b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 20 Mar 2025 07:19:47 -0700 Subject: [PATCH 09/14] Fix multistream GPU deadlock (#1969) * fix multistream GPU deadlock * comments --- mlx/backend/cpu/encoder.h | 16 +++++++++++++++- mlx/backend/cpu/eval.cpp | 8 ++------ mlx/backend/metal/device.cpp | 5 +---- mlx/backend/metal/metal.cpp | 6 +----- mlx/transforms.cpp | 2 +- python/tests/test_eval.py | 11 +++++++++++ 6 files changed, 31 insertions(+), 17 deletions(-) diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h index aae64fb5e..b8e33ca81 100644 --- a/mlx/backend/cpu/encoder.h +++ b/mlx/backend/cpu/encoder.h @@ -9,6 +9,9 @@ namespace mlx::core::cpu { +// Number of dispatches per scheduler task +constexpr int DISPATCHES_PER_TASK = 10; + struct CommandEncoder { CommandEncoder(Stream stream) : stream_(stream) {} @@ -39,13 +42,24 @@ struct CommandEncoder { template void dispatch(F&& f, Args&&... 
args) { + num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK; auto task = std::bind(std::forward(f), std::forward(args)...); - scheduler::enqueue(stream_, std::move(task)); + if (num_ops_ == 0) { + scheduler::notify_new_task(stream_); + auto task_wrap = [s = stream_, task = std::move(task)]() mutable { + task(); + scheduler::notify_task_completion(s); + }; + scheduler::enqueue(stream_, std::move(task_wrap)); + } else { + scheduler::enqueue(stream_, std::move(task)); + } } private: Stream stream_; std::vector temporaries_; + int num_ops_{0}; }; CommandEncoder& get_command_encoder(Stream stream); diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp index 04811e737..b23c8d561 100644 --- a/mlx/backend/cpu/eval.cpp +++ b/mlx/backend/cpu/eval.cpp @@ -33,12 +33,8 @@ void eval(array& arr) { buffers.erase(it); } auto& encoder = cpu::get_command_encoder(s); - scheduler::notify_new_task(s); - encoder.dispatch([s, - buffers = std::move(buffers), - temps = std::move(encoder.temporaries())]() { - scheduler::notify_task_completion(s); - }); + encoder.dispatch([buffers = std::move(buffers), + temps = std::move(encoder.temporaries())]() {}); } } // namespace mlx::core::cpu diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e28989a5c..930e570e2 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -19,9 +19,6 @@ namespace mlx::core::metal { namespace { -// TODO nicer way to set this or possibly expose as an environment variable -constexpr int MAX_BUFFERS_PER_QUEUE = 12; - constexpr const char* default_mtllib_path = METAL_PATH; auto get_metal_version() { @@ -256,7 +253,7 @@ Device::~Device() { void Device::new_queue(int index) { auto thread_pool = metal::new_scoped_memory_pool(); - auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE); + auto q = device_->newCommandQueue(); debug_set_stream_queue_label(q, index); if (!q) { throw std::runtime_error( diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index ba13b4b59..a9a1bc4f6 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -75,11 +75,7 @@ void finalize(Stream s) { auto& d = metal::device(s.device); auto cb = d.get_command_buffer(s.index); d.end_encoding(s.index); - scheduler::notify_new_task(s); - cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); + cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); d.commit_command_buffer(s.index); d.get_command_buffer(s.index); } diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index f01082418..958899bec 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -21,7 +21,7 @@ namespace mlx::core { -static constexpr int MAX_ACTIVE_TASKS = 100; +static constexpr int MAX_ACTIVE_TASKS = 10; /* This class is only meant to be used in eval * for synchronizing with the main thread. 
*/ diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 37e31f80b..510402b06 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -174,6 +174,17 @@ class TestEval(mlx_tests.MLXTestCase): post = mx.metal.get_peak_memory() self.assertEqual(pre, post) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_multistream_deadlock(self): + s1 = mx.default_stream(mx.gpu) + s2 = mx.new_stream(mx.gpu) + + x = mx.array(1.0) + x = mx.abs(x, stream=s1) + for _ in range(1000): + x = mx.abs(x, stream=s2) + mx.eval(x) + if __name__ == "__main__": unittest.main() From 9adcd1a650b6f97b9d9e6e07e96c09e437184e9a Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 20 Mar 2025 11:01:32 -0700 Subject: [PATCH 10/14] Support fused masking in Attention (#1924) * Update API to allow mask='causal' in fast::sdpa * Add fallback * Update steel::AttnParams * Fix typo * WIP, basic causal * Update tests * Update benchmarking * Update masking loop limits * Add bool masking and update tests * Update additive mask * Update benchmarks * Update benchmarks * Update tests * Update for bfloat error * Update early exit * Add random seed to tests --- benchmarks/python/sdpa_bench.py | 194 ++++++++++-------- .../steel/attn/kernels/steel_attention.h | 103 +++++++++- .../steel/attn/kernels/steel_attention.metal | 37 ++-- mlx/backend/metal/kernels/steel/attn/mma.h | 2 +- mlx/backend/metal/kernels/steel/attn/params.h | 8 + .../metal/scaled_dot_product_attention.cpp | 38 +++- mlx/fast.cpp | 83 ++++++-- mlx/fast.h | 3 +- mlx/fast_primitives.h | 8 +- python/src/fast.cpp | 12 +- python/tests/test_fast_sdpa.py | 164 +++++++++++++++ 11 files changed, 504 insertions(+), 148 deletions(-) diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py index 23383475e..5eb789de0 100644 --- a/benchmarks/python/sdpa_bench.py +++ b/benchmarks/python/sdpa_bench.py @@ -28,11 +28,34 @@ def bench(f, *args): return (e - s) * 1e-9 -def mlx_sdpa_fused_inner(q, k, v, scale): - return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) +def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): + np_dtype = getattr(np, dtype) + + shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D) + shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D) + + scale = 1.0 / math.sqrt(D) + + q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype) + k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + + q_mx = mx.array(q_np) + k_mx = mx.array(k_np) + v_mx = mx.array(v_np) + + if mask is not None: + if mask == "additive": + mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype) + mask = mx.array(mask_np) + elif mask == "bool": + mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5 + mask = mx.array(mask_np) + + return q_mx, k_mx, v_mx, scale, mask -def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): +def mlx_ref_attn(q, k, v, scale=1.0, mask=None): q_dtype = q.dtype q = q * mx.array(scale, q_dtype) n_q_heads = q.shape[-3] @@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): B = q.shape[0] L = q.shape[2] + kL = k.shape[2] if n_repeats > 1: q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) @@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): v = mx.expand_dims(v, 2) scores = q @ mx.swapaxes(k, -1, -2) - if f32softmax: - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype) - 
else: - scores = mx.softmax(scores, axis=-1) + + if mask is not None: + + if mask == "causal": + q_offset = max(0, kL - L) + q_indices = mx.arange(q_offset, q_offset + L) + k_indices = mx.arange(kL) + mask = q_indices[:, None] >= k_indices[None] + + if n_repeats > 1 and mask.ndim >= 3: + if mask.shape[-3] == 1: + mask = mx.expand_dims(mask, -3) + else: + mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) + + if mask.dtype == mx.bool_: + scores = mx.where(mask, scores, -np.float32(np.inf)) + else: + scores += mask + + scores = mx.softmax(scores, axis=-1, precise=True) out = scores @ v if n_repeats > 1: @@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): return out -def mlx_spda_unfused(q, k, v, scale, transpose): - q_out = q +def mlx_fused_attn(q, k, v, scale, mask): + return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) + + +def do_attention(f, q, k, v, scale, mask=None, transpose=False): if transpose: - k = mx.transpose(k, (0, 2, 1, 3)) - v = mx.transpose(v, (0, 2, 1, 3)) + q_t = mx.transpose(q, (0, 2, 1, 3)) + k_t = mx.transpose(k, (0, 2, 1, 3)) + v_t = mx.transpose(v, (0, 2, 1, 3)) + o_t = f(q_t, k_t, v_t, scale=scale, mask=mask) + return mx.transpose(o_t, (0, 2, 1, 3)) + else: + return f(q, k, v, scale=scale, mask=mask) + + +def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False): + q_out = q for i in range(N_iter_func): - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) - q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale) - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) + q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose) mx.eval(q_out) return q_out -def mlx_spda_fused(q, k, v, scale, transpose): - q_out = q - if transpose: - k = mx.transpose(k, (0, 2, 1, 3)) - v = mx.transpose(v, (0, 2, 1, 3)) - - for i in range(N_iter_func): - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) - q_out = mlx_sdpa_fused_inner(q_out, k, v, scale) - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) - - mx.eval(q_out) - return q_out - - -def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True): - shape_q = ( - (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim) - ) - shape_kv = ( - (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim) +def bench_shape( + B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None +): + q_mx, k_mx, v_mx, scale, mask = prepare_inputs( + B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype ) - q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype) - k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) - v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) + time_mlx_unfused = bench( + do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose + ) + time_mlx_fused = bench( + do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose + ) - scale = math.sqrt(1.0 / head_dim) + o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose) + o_mlx_unfused = do_attention( + mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose + ) - q_mx = mx.array(q_np) - k_mx = mx.array(k_np) - v_mx = mx.array(v_np) + atol = 1e-5 if dtype == "float32" else 2e-4 - time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose) - time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, 
scale, transpose) - - if transpose: - q_mx = mx.transpose(q_mx, (0, 2, 1, 3)) - k_mx = mx.transpose(k_mx, (0, 2, 1, 3)) - v_mx = mx.transpose(v_mx, (0, 2, 1, 3)) - - o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale) - o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True) - - atol = 1e-5 if np_dtype == np.float32 else 1e-4 - - if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol): + if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol): print( - f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" + f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" ) return time_mlx_fused, time_mlx_unfused @@ -151,39 +173,51 @@ if __name__ == "__main__": ( 1, 128, 128, 64, 32, 32), ( 1, 256, 256, 64, 32, 32), ( 1, 512, 512, 64, 32, 32), - ( 1, 1024, 1024, 64, 32, 32), - ( 1, 2048, 2048, 64, 32, 32), - ( 1, 4096, 4096, 64, 32, 32), + ( 1, 1024, 1024, 64, 32, 8), + ( 1, 2048, 2048, 64, 32, 8), + ( 1, 4096, 4096, 64, 32, 8), ) shapes_80 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) - ( 1, 1024, 1024, 80, 32, 32), - ( 1, 2048, 2048, 80, 32, 32), - ( 1, 4096, 4096, 80, 32, 32), + ( 1, 1024, 1024, 80, 32, 8), + ( 1, 2048, 2048, 80, 32, 8), + ( 1, 4096, 4096, 80, 32, 8), ) shapes_128 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) - ( 1, 1024, 1024, 128, 32, 32), - ( 1, 2048, 2048, 128, 32, 32), - ( 1, 4096, 4096, 128, 32, 32), + ( 1, 1024, 1024, 128, 32, 8), + ( 1, 2048, 2048, 128, 32, 8), + ( 1, 4096, 4096, 128, 32, 8), ) # fmt: on shapes = shapes_64 + shapes_80 + shapes_128 - print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%") + masks = [None, "bool", "causal"] + + print( + " B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%" + ) for dtype in dtypes: for transpose in transposes: for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: - np_dtype = getattr(np, dtype) - time_mlx_fused, time_mlx_unfused = bench_shape( - B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose - ) - diff = time_mlx_unfused / time_mlx_fused - 1.0 - t_str = 1 if transpose else 0 - print( - f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" - ) + for mask_in in masks: + time_mlx_fused, time_mlx_unfused = bench_shape( + B, + qsl, + ksl, + head_dim, + n_q_heads, + n_kv_heads, + dtype, + transpose, + mask_in, + ) + diff = time_mlx_unfused / time_mlx_fused - 1.0 + t_str = 1 if transpose else 0 + print( + f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" + ) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index b2e70ef8d..a8469e0ff 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -1,4 +1,4 @@ -// Copyright © 2024 Apple Inc. +// Copyright © 2024-25 Apple Inc. 
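The causal path added to this kernel below does not start by masking every score element: it bounds the per-tile loop over KV blocks using qL_off = kL - qL, and only the last few blocks near the diagonal fall through to the element-wise masking branch. A minimal Python sketch of that block-limit computation (helper and variable names here are illustrative, not taken from the kernel):

import math

def kv_block_limit(q_tile, BQ, BK, qL, kL):
    q_off = kL - qL                    # params->qL_off: align last query with last key
    q_max = (q_tile + 1) * BQ + q_off  # one past the last (shifted) query row in this tile
    return (q_max + BK - 1) // BK      # ceil division, mirroring kb_lim in the kernel

# Block sizes from one of the instantiations, with unaligned sequence lengths.
BQ, BK, qL, kL = 32, 16, 65, 128
for q_tile in range(math.ceil(qL / BQ)):
    lim = kv_block_limit(q_tile, BQ, BK, qL, kL)
    last_q = min((q_tile + 1) * BQ, qL) - 1  # last query row actually in this tile
    last_k = last_q + (kL - qL)              # causal: keys up to the shifted query position
    assert last_k < lim * BK                 # no contributing KV block is skipped
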
using namespace mlx::steel; @@ -9,6 +9,9 @@ using namespace mlx::steel; constant bool align_Q [[function_constant(200)]]; constant bool align_K [[function_constant(201)]]; +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; + template struct TransformScale { T scale; @@ -69,6 +72,7 @@ template < int BD, int WM, int WN, + typename MaskType = float, typename AccumType = float> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( const device T* Q [[buffer(0)]], @@ -76,6 +80,8 @@ template < const device T* V [[buffer(2)]], device T* O [[buffer(3)]], const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -102,6 +108,11 @@ template < tidl.y * params->O_strides[1] + // Head tidl.x * BQ * params->O_strides[2]; // Seqeunce + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + // Prepare threadgroup memory constexpr short padQ = 16 / sizeof(T); constexpr short padK = 16 / sizeof(T); @@ -203,7 +214,7 @@ template < // Load Q blocks apply scale if (!align_Q && int(tid.x) == (params->NQ_aligned)) { - loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ)); + loader_q.load_safe(short2(BD, params->qL_rem)); } else { loader_q.load_unsafe(); } @@ -221,12 +232,19 @@ template < max_score[i] = Limits::min; } + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + } + // Loop over KV seq length - for (int kb = 0; kb < params->NK; kb++) { + for (int kb = 0; kb < kb_lim; kb++) { // Load K block and apply scale threadgroup_barrier(mem_flags::mem_threadgroup); if (!align_K && kb == (params->NK_aligned)) { - loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + loader_k.load_safe(short2(BD, params->kL_rem)); } else { loader_k.load_unsafe(); } @@ -250,12 +268,11 @@ template < tile_matmad(Stile, Qtile, Ktile, Stile); } - // Mask out of length sequence + // Mask out length sequence if (!align_K && kb == (params->NK_aligned)) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; constexpr auto neg_inf = -metal::numeric_limits::infinity(); - const short lim = params->kL - params->NK_aligned * BK; STEEL_PRAGMA_UNROLL for (short i = 0; i < stile_t::kTileRows; i++) { @@ -264,7 +281,7 @@ template < short col_pos = sn + (j * stile_t::kFragCols); STEEL_PRAGMA_UNROLL for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { - if ((col_pos + jj) >= lim) { + if ((col_pos + jj) >= params->kL_rem) { Stile.frag_at(i, j)[jj] = neg_inf; } } @@ -272,11 +289,78 @@ template < } } + // Mask out if causal + if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 
0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + } + } + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); // Load V blocks if (!align_K && kb == (params->NK_aligned)) { - loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + loader_v.load_safe(short2(BD, params->kL_rem)); } else { loader_v.load_unsafe(); } @@ -367,8 +451,7 @@ template < O += (tm + sm) * params->O_strides[2] + sn; if (!align_Q && int(tid.x) == (params->NQ_aligned)) { - auto dst_tile_dims = - short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm)); + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 0d05a6932..fee28fed1 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -1,4 +1,4 @@ -// Copyright © 2024 Apple Inc. +// Copyright © 2024-25 Apple Inc. 
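The instantiation macros below fold the mask element type into the kernel name, with every dtype specialized for both a same-dtype (additive) mask and a bool mask; the host code in scaled_dot_product_attention.cpp later in this patch assembles the same string, falling back to the query dtype for the suffix when no mask array is given. A small sketch of that naming scheme (the helper name is mine, and it assumes MLX's usual dtype-to-name spelling, e.g. bool arrays map to "bool_"):

def steel_attention_kernel_name(dtype, bq, bk, bd, wm, wn, mask_dtype=None):
    # "steel_attention_" tname "_bq" bq "_bk" bk "_bd" bd "_wm" wm "_wn" wn "_mask" mname
    mask = mask_dtype if mask_dtype is not None else dtype
    return f"steel_attention_{dtype}_bq{bq}_bk{bk}_bd{bd}_wm{wm}_wn{wn}_mask{mask}"

# One of the float16 instantiations below, specialized for a boolean mask.
assert (steel_attention_kernel_name("float16", 32, 32, 64, 4, 1, "bool_")
        == "steel_attention_float16_bq32_bk32_bd64_wm4_wn1_maskbool_")
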
// clang-format off #include "mlx/backend/metal/kernels/utils.h" @@ -6,26 +6,23 @@ #include "mlx/backend/metal/kernels/steel/attn/attn.h" #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" -#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \ - template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \ - [[kernel]] void attention( \ - const device dtype* Q [[buffer(0)]], \ - const device dtype* K [[buffer(1)]], \ - const device dtype* V [[buffer(2)]], \ - device dtype* O [[buffer(3)]],\ - const constant AttnParams* params [[buffer(4)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention, dtype, bq, bk, bd, wm, wn, mtype, float) -#define instantiate_attn_shapes_helper(iname, itype) \ - instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \ - instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \ - instantiate_attn(iname, itype, 32, 32, 64, 4, 1) +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) -instantiate_attn_shapes_helper(float16, half); -instantiate_attn_shapes_helper(bfloat16, bfloat16_t); +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) -instantiate_attn_shapes_helper(float32, float); +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat16_t); + +instantiate_attn_mask_helper(float32, float); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h index 525c50e8f..db5127c33 100644 --- a/mlx/backend/metal/kernels/steel/attn/mma.h +++ b/mlx/backend/metal/kernels/steel/attn/mma.h @@ -111,7 +111,7 @@ struct BaseMMAFrag { for (short j = 0; j < kElemCols; j++) { if ((off_x + i) < lim_x && (off_y + j) < lim_y) { dst[i * kElemCols + j] = - static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + static_cast(src[(off_x + i) * str_x + (off_y + j) * str_y]); } else { dst[i * kElemCols + j] = T(0); } diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h index 4f9680412..f1cf09fad 100644 --- a/mlx/backend/metal/kernels/steel/attn/params.h +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -26,11 +26,19 @@ struct AttnParams { int NQ_aligned; ///< Number of full query blocks int NK_aligned; ///< Number of full key/value blocks + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) }; +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + } // namespace steel } // namespace mlx diff --git 
a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 7fbd63022..f7ec004a6 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal( const array& k, const array& v, const float scale, - array& o) { + array& o, + bool do_causal_ = false, + const std::optional& mask = std::nullopt) { using namespace mlx::steel; int wm = 4; @@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal( const bool align_Q = (qL % bq) == 0; const bool align_K = (kL % bk) == 0; + const bool has_mask = !!mask; + const bool do_causal = do_causal_; metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, - }; + {&has_mask, MTL::DataType::DataTypeBool, 300}, + {&do_causal, MTL::DataType::DataTypeBool, 301}}; std::ostringstream kname; // clang-format off @@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal( << "_bq" << bq << "_bk" << bk << "_bd" << bd - << "_wm" << wm << "_wn" << wn; // clang-format on + << "_wm" << wm + << "_wn" << wn + << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on std::string base_name = kname.str(); // clang-format off kname << "_align_Q_" << (align_Q ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n') + << "_has_mask_" << (has_mask ? 't' : 'n') + << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal( /* int NQ_aligned = */ NQ_aligned, /* int NK_aligned = */ NK_aligned, + /* int qL_rem = */ (qL - NQ_aligned * bq), + /* int kL_rem = */ (kL - NK_aligned * bk), + /* int qL_off = */ (kL - qL), + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, @@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal( compute_encoder.set_output_array(o, 3); compute_encoder.set_bytes(params, 4); + if (mask) { + auto m = *mask; + AttnMaskParams mask_params{/* int64_t M_strides[3] = */ { + m.strides(0), m.strides(1), m.strides(2)}}; + + compute_encoder.set_bytes(mask_params, 5); + compute_encoder.set_input_array(m, 6); + } + MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); @@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu( // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { - return arr.strides(3) == 1; + return arr.strides(-1) == 1; }; // We are in vector mode ie single query @@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu( {str_oB, str_oH, str_oL, str_oD}, flags); - sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); + auto mask = inputs.size() > 3 + ? 
std::optional{copy_unless(is_matrix_contiguous, inputs[3])} + : std::nullopt; + + sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 136c7796a..342078a24 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -567,7 +567,7 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::optional& mask, + const std::variant& mask /* = {}*/, const std::optional memory_efficient_threshold, StreamOrDevice s) { for (const auto& tensor : {queries, keys, values}) { @@ -578,10 +578,29 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } } - if (mask && (*mask).ndim() > 4) { + + bool do_causal = false; + bool has_mask = !std::holds_alternative(mask); + bool has_str_mask = has_mask && std::holds_alternative(mask); + bool has_arr_mask = has_mask && std::holds_alternative(mask); + bool has_bool_mask = false; + + if (has_str_mask) { + if (std::get(mask) != "causal") { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] invalid mask option '" + << std::get(mask) << "'. Must be 'causal', or an array."; + throw std::invalid_argument(msg.str()); + } else { + do_causal = true; + } + } + + if (has_arr_mask && (std::get(mask)).ndim() > 4) { std::ostringstream msg; msg << "[scaled_dot_product_attention] the mask with shape " - << (*mask).shape() << " expected to have at most rank 4"; + << (std::get(mask)).shape() + << " expected to have at most rank 4"; throw std::invalid_argument(msg.str()); } @@ -631,9 +650,11 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (mask) { + if (has_arr_mask) { // Check type - if (promote_types(mask->dtype(), final_type) != final_type) { + auto mask_arr = std::get(mask); + has_bool_mask = mask_arr.dtype() == bool_; + if (promote_types(mask_arr.dtype(), final_type) != final_type) { std::ostringstream msg; msg << "[scaled_dot_product_attention] Mask type must promote to output type. " << final_type << "."; @@ -642,9 +663,10 @@ array scaled_dot_product_attention( // Check shape auto mask_shape = queries.shape(); mask_shape.back() = keys.shape(-2); - if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) { + if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) { std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape() + msg << "[scaled_dot_product_attention] Mask with shape " + << mask_arr.shape() << " does not broadcast to implicit scores with shape " << mask_shape << "."; throw std::invalid_argument(msg.str()); @@ -662,7 +684,7 @@ array scaled_dot_product_attention( threshold = std::max(1, memory_efficient_threshold.value()); } - auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s]( + auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s]( const std::vector& inputs) { auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); int n_repeats = n_q_heads / n_kv_heads; @@ -676,9 +698,21 @@ array scaled_dot_product_attention( v = expand_dims(v, 2, s); } auto scores = matmul(q, swapaxes(k, -1, -2, s), s); - if (inputs.size() > 3) { + if (inputs.size() > 3 || do_causal) { // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv] - auto mask = inputs[3]; + auto mask = inputs.back(); + + if (do_causal) { + int kL = k.shape(-2); + int qL = q.shape(-2); + int q_off = (kL - qL) < 0 ? 
0 : (kL - qL); + auto q_idx = arange(q_off, q_off + qL, s); + auto k_idx = arange(0, kL, s); + q_idx = expand_dims(q_idx, 1, s); + k_idx = expand_dims(k_idx, 0, s); + mask = greater_equal(q_idx, k_idx, s); + } + if (n_repeats > 1 && mask.ndim() >= 3) { if (mask.shape(-3) == 1) { mask = expand_dims(mask, -3, s); @@ -702,9 +736,10 @@ array scaled_dot_product_attention( }; auto stream = to_stream(s); - const size_t value_head_dim = v.shape(-1); - const size_t query_head_dim = q.shape(-1); - const size_t query_sequence_length = q.shape(2); + const int value_head_dim = v.shape(-1); + const int query_head_dim = q.shape(-1); + const int query_sequence_length = q.shape(2); + const int key_sequence_length = k.shape(2); const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && @@ -712,27 +747,33 @@ array scaled_dot_product_attention( const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); - const bool supports_sdpa_full = query_sequence_length >= threshold && !mask && - sdpa_full_supported_head_dim && stream.device == Device::gpu; + const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask); + const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || + (query_sequence_length <= key_sequence_length && do_causal); + + const bool supports_sdpa_full = query_sequence_length >= threshold && + sdpa_full_supported_mask && sdpa_full_supported_head_dim && + stream.device == Device::gpu; const bool supports_sdpa_vector = (query_sequence_length <= 8) && - (query_sequence_length <= k.shape(-2)) && - (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim && + (query_sequence_length <= key_sequence_length) && + sdpa_vector_supported_mask && sdpa_vector_supported_head_dim && stream.device == Device::gpu; const bool implementation_supports_use_case = supports_sdpa_full || supports_sdpa_vector; std::vector inputs = {q, k, v}; - if (mask) { - inputs.push_back(*mask); + if (has_arr_mask) { + inputs.push_back(std::get(mask)); } if (implementation_supports_use_case) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; return array( std::move(out_shape), final_type, - std::make_shared(stream, fallback, scale), + std::make_shared( + stream, fallback, scale, do_causal), std::move(inputs)); } return fallback(inputs)[0]; @@ -741,7 +782,7 @@ array scaled_dot_product_attention( bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); - return scale_ == a_other.scale_; + return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_; } array pack_and_quantize( diff --git a/mlx/fast.h b/mlx/fast.h index fe93de85e..b9db6d462 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "mlx/utils.h" @@ -47,7 +48,7 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::optional& mask = std::nullopt, + const std::variant& mask = {}, const std::optional memory_efficient_threshold = std::nullopt, StreamOrDevice s = {}); diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index ec97fe0ca..4d9e505ee 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -206,8 +206,9 @@ class ScaledDotProductAttention : public Custom { explicit ScaledDotProductAttention( Stream stream, std::function(std::vector)> fallback, - const float scale) - : Custom(stream, fallback), scale_(scale) {} 
+ const float scale, + const bool do_causal) + : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -225,12 +226,13 @@ class ScaledDotProductAttention : public Custom { DEFINE_PRINT(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { - return std::make_pair(nullptr, scale_); + return std::make_tuple(nullptr, scale_, do_causal_); } private: std::function(std::vector)> fallback_; float scale_; + bool do_causal_; }; class AffineQuantize : public Custom { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index fc2cbd41d..95b7dcc9a 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) { "memory_efficient_threshold"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. @@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) { k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (array, optional): A boolean or additive mask to apply to the - query-key scores. The mask can have at most 4 dimensions and must - be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an - additive mask is given its type must promote to the promoted - type of ``q``, ``k``, and ``v``. + mask (Union[None, str, array], optional): A causal, boolean or additive + mask to apply to the query-key scores. The mask can have at most 4 + dimensions and must be broadcast-compatible with the shape + ``[B, N, T_q, T_kv]``. If an additive mask is given its type must + promote to the promoted type of ``q``, ``k``, and ``v``. Returns: array: The output array. 
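            Example (a minimal usage sketch of the extended ``mask`` argument;
            the shapes and helper calls below are illustrative assumptions and
            not part of this patch)::

                import math
                import mlx.core as mx

                B, n_heads, T, D = 1, 8, 32, 64
                q = mx.random.normal(shape=(B, n_heads, T, D))
                k = mx.random.normal(shape=(B, n_heads, T, D))
                v = mx.random.normal(shape=(B, n_heads, T, D))
                scale = 1.0 / math.sqrt(D)

                # The new string option requests causal masking of the scores.
                causal_out = mx.fast.scaled_dot_product_attention(
                    q, k, v, scale=scale, mask="causal"
                )

                # Boolean (or additive) array masks broadcast to [B, N, T_q, T_kv].
                idx = mx.arange(T)
                bool_mask = idx[:, None] >= idx[None]  # lower-triangular boolean mask
                masked_out = mx.fast.scaled_dot_product_attention(
                    q, k, v, scale=scale, mask=bool_mask
                )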
)pbdoc"); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 5426ea236..a269847de 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -6,6 +6,91 @@ import mlx_tests import numpy as np +def mlx_ref_attn(q, k, v, scale=1.0, mask=None): + q_dtype = q.dtype + q = q * mx.array(scale, q_dtype) + n_q_heads = q.shape[-3] + n_kv_heads = k.shape[-3] + n_repeats = n_q_heads // n_kv_heads + + B = q.shape[0] + L = q.shape[2] + kL = k.shape[2] + + if n_repeats > 1: + q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) + k = mx.expand_dims(k, 2) + v = mx.expand_dims(v, 2) + + scores = q @ mx.swapaxes(k, -1, -2) + + if mask is not None: + + if mask == "causal": + q_offset = max(0, kL - L) + q_indices = mx.arange(q_offset, q_offset + L) + k_indices = mx.arange(kL) + mask = q_indices[:, None] >= k_indices[None] + + if n_repeats > 1 and mask.ndim >= 3: + if mask.shape[-3] == 1: + mask = mx.expand_dims(mask, -3) + else: + mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) + + if mask.dtype == mx.bool_: + scores = mx.where(mask, scores, -np.float32(np.inf)) + else: + scores += mask + + scores = mx.softmax(scores, axis=-1, precise=True) + + out = scores @ v + if n_repeats > 1: + out = mx.reshape(out, [B, n_q_heads, L, -1]) + + return out + + +def do_attention(f, q, k, v, scale, mask=None, transpose=False): + if transpose: + q_t = mx.transpose(q, (0, 2, 1, 3)) + k_t = mx.transpose(k, (0, 2, 1, 3)) + v_t = mx.transpose(v, (0, 2, 1, 3)) + o_t = f(q_t, k_t, v_t, scale=scale, mask=mask) + return mx.transpose(o_t, (0, 2, 1, 3)) + else: + return f(q, k, v, scale=scale, mask=mask) + + +def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): + np.random.seed(0) + np_dtype = getattr(np, dtype) + + shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D) + shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D) + + scale = 1.0 / math.sqrt(D) + + q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype) + k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + + q_mx = mx.array(q_np) + k_mx = mx.array(k_np) + v_mx = mx.array(v_np) + + if mask is not None: + if mask == "additive": + mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype) + mask = mx.array(mask_np) + elif mask == "bool": + mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5 + mask = mx.array(mask_np) + + return q_mx, k_mx, v_mx, scale, mask + + # SDPA for MHA (n_heads == n_kv_heads) def mlx_primitives_sdpa(q, k, v, scale, mask=None): p = (q * scale) @ k.transpose(0, 1, 3, 2) @@ -365,5 +450,84 @@ class TestFastSDPA(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) +class TestSDPA(mlx_tests.MLXTestCase): + @property + def dtypes(self): + return ["float32", "float16"] if mx.metal.is_available() else ["float32"] + + def test_sdpa(self): + if not mx.metal.is_available(): + return + + # fmt: off + shapes_64 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 128, 128, 64, 32, 32), + ( 1, 64, 128, 64, 32, 32), + ( 1, 65, 128, 64, 32, 8), + ( 1, 64, 127, 64, 32, 8), + ( 1, 65, 127, 64, 32, 8), + ( 1, 127, 65, 64, 32, 8), + ) + + shapes_128 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 128, 128, 128, 32, 8), + ( 1, 64, 128, 128, 32, 8), + ( 1, 65, 127, 128, 32, 8), + ( 1, 127, 65, 128, 32, 8), + ) + # fmt: on + + shapes = shapes_64 + shapes_128 + masks = [None, "additive", "bool", "causal"] + transposes = (False, True) + + for dtype in 
self.dtypes: + for t in transposes: + for mask_str in masks: + for B, qL, kL, D, qH, kH in shapes: + with self.subTest( + B=B, + qsl=qL, + ksl=kL, + head_dim=D, + n_q_heads=qH, + n_kv_heads=kH, + mask=mask_str, + transpose=t, + dtype=dtype, + ): + + np.random.seed(0) + q_mx, k_mx, v_mx, scale, mask = prepare_inputs( + B, qL, kL, D, qH, kH, mask_str, t, dtype + ) + + out_ref = do_attention( + mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t + ) + + out_fst = do_attention( + mx.fast.scaled_dot_product_attention, + q_mx, + k_mx, + v_mx, + scale, + mask, + t, + ) + + atol = 2e-5 if dtype == "float32" else 3e-4 + + self.assertListEqual( + list(out_ref.shape), list(out_fst.shape) + ) + + self.assertTrue( + mx.allclose(out_fst, out_ref, atol=atol, rtol=atol) + ) + + if __name__ == "__main__": unittest.main(failfast=True) From b42d13ec8443c299d32b4b254161413bfaa72acb Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 20 Mar 2025 14:25:38 -0700 Subject: [PATCH 11/14] Update attention tests to show diff, disable array masks (#1978) --- mlx/fast.cpp | 4 ++-- python/tests/test_fast_sdpa.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 342078a24..ed0d9fbe5 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -748,8 +748,8 @@ array scaled_dot_product_attention( (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask); - const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || - (query_sequence_length <= key_sequence_length && do_causal); + const bool sdpa_full_supported_mask = + !has_mask || (query_sequence_length <= key_sequence_length && do_causal); const bool supports_sdpa_full = query_sequence_length >= threshold && sdpa_full_supported_mask && sdpa_full_supported_head_dim && diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index a269847de..4ea573564 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -72,8 +72,8 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): scale = 1.0 / math.sqrt(D) - q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype) - k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + q_np = np.random.normal(0.0, 0.5, shape_q).astype(np_dtype) + k_np = np.random.normal(0.0, 0.5, shape_kv).astype(np_dtype) v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) q_mx = mx.array(q_np) @@ -524,9 +524,8 @@ class TestSDPA(mlx_tests.MLXTestCase): list(out_ref.shape), list(out_fst.shape) ) - self.assertTrue( - mx.allclose(out_fst, out_ref, atol=atol, rtol=atol) - ) + diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref) + self.assertLessEqual(mx.max(diff).item(), atol) if __name__ == "__main__": From 005e7efa647dcb7a73e77dd62d3467d54176a68e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 20 Mar 2025 14:53:12 -0700 Subject: [PATCH 12/14] fix mask in sdpa (#1980) * fix mask in sdpa * fix attention mask * Re-enable routing for array mask --------- Co-authored-by: Jagrit Digani --- .../metal/scaled_dot_product_attention.cpp | 6 +-- mlx/fast.cpp | 41 +++++++------------ python/tests/test_fast_sdpa.py | 16 ++++++++ 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index f7ec004a6..c5b544852 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,5 
+1,4 @@ // Copyright © 2024 Apple Inc. - #include #include "mlx/backend/common/compiled.h" @@ -59,7 +58,7 @@ void sdpa_full_self_attention_metal( << "_bq" << bq << "_bk" << bk << "_bd" << bd - << "_wm" << wm + << "_wm" << wm << "_wn" << wn << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on @@ -67,7 +66,7 @@ void sdpa_full_self_attention_metal( // clang-format off kname << "_align_Q_" << (align_Q ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n') << "_has_mask_" << (has_mask ? 't' : 'n') << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on @@ -117,6 +116,7 @@ void sdpa_full_self_attention_metal( if (mask) { auto m = *mask; + AttnMaskParams mask_params{/* int64_t M_strides[3] = */ { m.strides(0), m.strides(1), m.strides(2)}}; diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ed0d9fbe5..ac3cfe042 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -650,29 +650,6 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (has_arr_mask) { - // Check type - auto mask_arr = std::get(mask); - has_bool_mask = mask_arr.dtype() == bool_; - if (promote_types(mask_arr.dtype(), final_type) != final_type) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask type must promote to output type. " - << final_type << "."; - throw std::invalid_argument(msg.str()); - } - // Check shape - auto mask_shape = queries.shape(); - mask_shape.back() = keys.shape(-2); - if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask with shape " - << mask_arr.shape() - << " does not broadcast to implicit scores with shape " << mask_shape - << "."; - throw std::invalid_argument(msg.str()); - } - } - auto q = astype(queries, final_type, s); auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); @@ -748,8 +725,8 @@ array scaled_dot_product_attention( (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask); - const bool sdpa_full_supported_mask = - !has_mask || (query_sequence_length <= key_sequence_length && do_causal); + const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || + (query_sequence_length <= key_sequence_length && do_causal); const bool supports_sdpa_full = query_sequence_length >= threshold && sdpa_full_supported_mask && sdpa_full_supported_head_dim && @@ -765,7 +742,19 @@ array scaled_dot_product_attention( std::vector inputs = {q, k, v}; if (has_arr_mask) { - inputs.push_back(std::get(mask)); + // Check type + auto mask_arr = std::get(mask); + has_bool_mask = mask_arr.dtype() == bool_; + if (promote_types(mask_arr.dtype(), final_type) != final_type) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Mask type must promote to output type. 
" + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + // Broadcast mask + auto mask_shape = queries.shape(); + mask_shape.back() = keys.shape(-2); + inputs.push_back(broadcast_to(mask_arr, mask_shape, stream)); } if (implementation_supports_use_case) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 4ea573564..78e03159f 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -527,6 +527,22 @@ class TestSDPA(mlx_tests.MLXTestCase): diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref) self.assertLessEqual(mx.max(diff).item(), atol) + def test_sdpa_broadcast_mask(self): + mask = mx.array(True) + D = 64 + Nq = 4 + Nkv = 1 + scale = 1.0 + L = 256 + + mx.random.seed(0) + q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D)) + k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) + ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) + self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) + if __name__ == "__main__": unittest.main(failfast=True) From 1177d283954facd3d1c40ff1ea59929549045c12 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 20 Mar 2025 15:12:22 -0700 Subject: [PATCH 13/14] patch bump (#1981) --- mlx/version.h | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/version.h b/mlx/version.h index f244dcb16..275c74c73 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -3,8 +3,8 @@ #pragma once #define MLX_VERSION_MAJOR 0 -#define MLX_VERSION_MINOR 23 -#define MLX_VERSION_PATCH 2 +#define MLX_VERSION_MINOR 24 +#define MLX_VERSION_PATCH 0 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) diff --git a/setup.py b/setup.py index 72bc2dba3..f1769b21f 100644 --- a/setup.py +++ b/setup.py @@ -172,7 +172,7 @@ if __name__ == "__main__": setup( name="mlx", - version=get_version("0.23.2"), + version=get_version("0.24.0"), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", From 7b7e2352cdaa97b063889dd87d59b1a6f4a8bed2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 20 Mar 2025 16:48:43 -0700 Subject: [PATCH 14/14] fix malloc or wait deadlock (#1976) --- docs/src/dev/extensions.rst | 6 ++--- examples/extensions/axpby/axpby.cpp | 8 +++--- mlx/allocator.cpp | 26 ++----------------- mlx/allocator.h | 8 ++---- mlx/backend/common/binary.h | 10 +++---- mlx/backend/common/common.cpp | 2 +- mlx/backend/common/compiled.cpp | 4 +-- mlx/backend/common/copy.h | 4 +-- mlx/backend/common/load.cpp | 2 +- mlx/backend/common/ternary.h | 6 ++--- mlx/backend/cpu/arg_reduce.cpp | 2 +- mlx/backend/cpu/conv.cpp | 8 +++--- mlx/backend/cpu/distributed.cpp | 6 ++--- mlx/backend/cpu/eigh.cpp | 7 +++-- mlx/backend/cpu/fft.cpp | 2 +- mlx/backend/cpu/indexing.cpp | 4 +-- mlx/backend/cpu/inverse.cpp | 4 +-- mlx/backend/cpu/luf.cpp | 7 +++-- mlx/backend/cpu/masked_mm.cpp | 4 +-- mlx/backend/cpu/matmul.cpp | 2 +- mlx/backend/cpu/primitives.cpp | 14 +++++----- mlx/backend/cpu/qrf.cpp | 14 +++++----- mlx/backend/cpu/quantized.cpp | 10 +++---- mlx/backend/cpu/reduce.cpp | 2 +- mlx/backend/cpu/scan.cpp | 2 +- mlx/backend/cpu/softmax.cpp | 2 +- mlx/backend/cpu/sort.cpp | 4 +-- mlx/backend/cpu/svd.cpp | 12 ++++----- mlx/backend/cpu/unary.h | 4 +-- 
mlx/backend/metal/allocator.cpp | 21 ++++++++------- mlx/backend/metal/allocator.h | 5 ++-- mlx/backend/metal/conv.cpp | 12 ++++----- mlx/backend/metal/copy.cpp | 2 +- mlx/backend/metal/custom_kernel.cpp | 2 +- mlx/backend/metal/fence.cpp | 2 +- mlx/backend/metal/fft.cpp | 11 ++++---- mlx/backend/metal/hadamard.cpp | 4 +-- mlx/backend/metal/indexing.cpp | 4 +-- mlx/backend/metal/matmul.cpp | 12 ++++----- mlx/backend/metal/metal.h | 17 +++++++----- mlx/backend/metal/normalization.cpp | 18 ++++++------- mlx/backend/metal/primitives.cpp | 14 +++++----- mlx/backend/metal/quantized.cpp | 10 +++---- mlx/backend/metal/reduce.cpp | 8 +++--- mlx/backend/metal/rope.cpp | 4 +-- .../metal/scaled_dot_product_attention.cpp | 15 +++++------ mlx/backend/metal/scan.cpp | 2 +- mlx/backend/metal/slicing.cpp | 2 +- mlx/backend/metal/softmax.cpp | 2 +- mlx/backend/metal/sort.cpp | 19 +++++++------- mlx/backend/metal/unary.cpp | 4 +-- mlx/backend/no_metal/metal.cpp | 5 +++- mlx/transforms.cpp | 9 ++++++- python/src/metal.cpp | 16 +++++------- python/tests/test_eval.py | 12 +++++++++ 55 files changed, 201 insertions(+), 217 deletions(-) diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 4272371cb..b8c3a4995 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`. float alpha_, float beta_, mx::Stream stream) { - // Allocate the output with `malloc_or_wait` which synchronously allocates - // memory, potentially waiting if the system is under memory pressure - out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); + out.set_data(mx::allocator::malloc(out.nbytes())); // Get the CPU command encoder and register input and output arrays auto& encoder = mx::cpu::get_command_encoder(stream); @@ -393,7 +391,7 @@ below. 
auto& d = metal::device(s.device); // Allocate output memory - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); // Resolve name of kernel std::ostringstream kname; diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index d3663e71e..291246617 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -72,9 +72,7 @@ void axpby_impl( float alpha_, float beta_, mx::Stream stream) { - // Allocate the output with `malloc_or_wait` which synchronously allocates - // memory, potentially waiting if the system is under memory pressure - out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); + out.set_data(mx::allocator::malloc(out.nbytes())); // Get the CPU command encoder and register input and output arrays auto& encoder = mx::cpu::get_command_encoder(stream); @@ -160,12 +158,12 @@ void Axpby::eval_gpu( // Allocate output memory with strides based on specialization if (contiguous_kernel) { out.set_data( - mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()), + mx::allocator::malloc(x.data_size() * out.itemsize()), x.data_size(), x.strides(), x.flags()); } else { - out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); + out.set_data(mx::allocator::malloc(out.nbytes())); } // Resolve name of kernel (corresponds to axpby.metal) diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index 8d7273a78..2d97a6db3 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -9,7 +9,7 @@ namespace mlx::core::allocator { Buffer malloc(size_t size) { - auto buffer = allocator().malloc(size, /* allow_swap */ true); + auto buffer = allocator().malloc(size); if (size && !buffer.ptr()) { std::ostringstream msg; msg << "[malloc] Unable to allocate " << size << " bytes."; @@ -22,7 +22,7 @@ void free(Buffer buffer) { allocator().free(buffer); } -Buffer CommonAllocator::malloc(size_t size, bool) { +Buffer CommonAllocator::malloc(size_t size) { void* ptr = std::malloc(size + sizeof(size_t)); if (ptr != nullptr) { *static_cast(ptr) = size; @@ -41,26 +41,4 @@ size_t CommonAllocator::size(Buffer buffer) const { return *static_cast(buffer.ptr()); } -Buffer malloc_or_wait(size_t size) { - auto buffer = allocator().malloc(size); - - while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) { - scheduler::wait_for_one(); - buffer = allocator().malloc(size); - } - - // Try swapping if needed - if (size && !buffer.ptr()) { - buffer = allocator().malloc(size, /* allow_swap = */ true); - } - - if (size && !buffer.ptr()) { - std::ostringstream msg; - msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; - throw std::runtime_error(msg.str()); - } - - return buffer; -} - } // namespace mlx::core::allocator diff --git a/mlx/allocator.h b/mlx/allocator.h index 2809c7f7f..d4e3e1d6e 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -32,14 +32,10 @@ Buffer malloc(size_t size); void free(Buffer buffer); -// Wait for running tasks to finish and free up memory -// if allocation fails -Buffer malloc_or_wait(size_t size); - class Allocator { /** Abstract base class for a memory allocator. */ public: - virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; + virtual Buffer malloc(size_t size) = 0; virtual void free(Buffer buffer) = 0; virtual size_t size(Buffer buffer) const = 0; @@ -56,7 +52,7 @@ Allocator& allocator(); class CommonAllocator : public Allocator { /** A general CPU allocator. 
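Each allocation reserves an extra size_t header that records the requested size, which is what backs the size(Buffer) query without any separate bookkeeping.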
*/ public: - virtual Buffer malloc(size_t size, bool allow_swap = false) override; + virtual Buffer malloc(size_t size) override; virtual void free(Buffer buffer) override; virtual size_t size(Buffer buffer) const override; diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 0c56b71a0..ac6e1891d 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -44,14 +44,14 @@ inline void set_binary_op_output_data( switch (bopt) { case BinaryOpType::ScalarScalar: out.set_data( - allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); + allocator::malloc(out.itemsize()), 1, a.strides(), a.flags()); break; case BinaryOpType::ScalarVector: if (b_donatable) { out.copy_shared_buffer(b); } else { out.set_data( - allocator::malloc_or_wait(b.data_size() * out.itemsize()), + allocator::malloc(b.data_size() * out.itemsize()), b.data_size(), b.strides(), b.flags()); @@ -62,7 +62,7 @@ inline void set_binary_op_output_data( out.copy_shared_buffer(a); } else { out.set_data( - allocator::malloc_or_wait(a.data_size() * out.itemsize()), + allocator::malloc(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); @@ -75,7 +75,7 @@ inline void set_binary_op_output_data( out.copy_shared_buffer(b); } else { out.set_data( - allocator::malloc_or_wait(a.data_size() * out.itemsize()), + allocator::malloc(a.data_size() * out.itemsize()), a.data_size(), a.strides(), a.flags()); @@ -88,7 +88,7 @@ inline void set_binary_op_output_data( b_donatable && b.flags().row_contiguous && b.size() == out.size()) { out.copy_shared_buffer(b); } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } break; } diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 3eb98d09a..57813e062 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector& inputs, array& out) { void NumberOfElements::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); double numel = 1; for (auto ax : axes_) { diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index dfa9e700a..f7b5598ab 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -188,7 +188,7 @@ void compiled_allocate_outputs( } for (; o < outputs.size(); ++o) { outputs[o].set_data( - allocator::malloc_or_wait(data_size * outputs[o].itemsize()), + allocator::malloc(data_size * outputs[o].itemsize()), data_size, strides, flags); @@ -211,7 +211,7 @@ void compiled_allocate_outputs( } } for (; o < outputs.size(); ++o) { - outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); + outputs[o].set_data(allocator::malloc(outputs[o].nbytes())); } } } diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index 5d8f7a58e..0c9f28c94 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { return true; } else { out.set_data( - allocator::malloc_or_wait(in.data_size() * out.itemsize()), + allocator::malloc(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); return false; } } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); return false; } } diff --git 
a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp index 3f194f1e2..ce41963de 100644 --- a/mlx/backend/common/load.cpp +++ b/mlx/backend/common/load.cpp @@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) { namespace mlx::core { void Load::eval_cpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto read_task = [out_ptr = out.data(), size = out.size(), itemsize = out.itemsize(), diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index ad6df2fc4..d98dd8d68 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -48,12 +48,12 @@ inline void set_ternary_op_output_data( switch (topt) { case TernaryOpType::ScalarScalarScalar: out.set_data( - allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags()); + allocator::malloc(out.itemsize()), 1, b.strides(), b.flags()); break; case TernaryOpType::VectorVectorVector: if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { out.set_data( - allocator::malloc_or_wait(out.itemsize() * b.data_size()), + allocator::malloc(out.itemsize() * b.data_size()), b.data_size(), b.strides(), b.flags()); @@ -64,7 +64,7 @@ inline void set_ternary_op_output_data( if (!((a.flags().row_contiguous && maybe_donate(a)) || (b.flags().row_contiguous && maybe_donate(b)) || (c.flags().row_contiguous && maybe_donate(c)))) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } break; } diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index c9bdc35b0..a8ba3efe2 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -68,7 +68,7 @@ void arg_reduce_dispatch( void ArgReduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index 8259e0597..d52f92f8b 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu( if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); - gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + gemm_out.set_data(allocator::malloc(gemm_out.nbytes())); temps.push_back(gemm_out); } @@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu( if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); - gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + gemm_out.set_data(allocator::malloc(gemm_out.nbytes())); temps.push_back(gemm_out); } @@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu( if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); - gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + gemm_out.set_data(allocator::malloc(gemm_out.nbytes())); temps.push_back(gemm_out); } @@ -1327,7 +1327,7 @@ void conv_3D_cpu( } // namespace void Convolution::eval_cpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& in = inputs[0]; auto& wt = inputs[1]; diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index 1d6d9de06..1afa027a8 100644 --- 
a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -30,7 +30,7 @@ void AllReduce::eval_cpu( if (in.is_donatable()) { out.copy_shared_buffer(in); } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } return in; } else { @@ -58,7 +58,7 @@ void AllGather::eval_cpu( assert(outputs.size() == 1); auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); - outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); distributed::detail::all_gather(group(), in, outputs[0], stream()); if (copied) { auto& enc = cpu::get_command_encoder(stream()); @@ -87,7 +87,7 @@ void Recv::eval_cpu( assert(inputs.size() == 0); assert(outputs.size() == 1); - outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); distributed::detail::recv(group(), outputs[0], src_, stream()); } diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp index 348a27199..b50f2c722 100644 --- a/mlx/backend/cpu/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -55,9 +55,8 @@ void eigh_impl( liwork = iwork; } - auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; - auto iwork_buf = - array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)}; + auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)}; + auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)}; for (size_t i = 0; i < size / (N * N); ++i) { syevd( &jobz, @@ -98,7 +97,7 @@ void Eigh::eval_cpu( ? outputs[1] : array(a.shape(), a.dtype(), nullptr, {}); - values.set_data(allocator::malloc_or_wait(values.nbytes())); + values.set_data(allocator::malloc(values.nbytes())); copy( a, diff --git a/mlx/backend/cpu/fft.cpp b/mlx/backend/cpu/fft.cpp index 90227575b..d9e5f8050 100644 --- a/mlx/backend/cpu/fft.cpp +++ b/mlx/backend/cpu/fft.cpp @@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector& inputs, array& out) { s *= out.itemsize(); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); std::vector shape; if (out.dtype() == float32) { diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 6a32dc1d4..70d6b3eb7 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -197,7 +197,7 @@ void dispatch_gather( } void Gather::eval_cpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& src = inputs[0]; std::vector inds; @@ -354,7 +354,7 @@ void dispatch_gather_axis( } void GatherAxis::eval_cpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& src = inputs[0]; auto& inds = inputs[1]; diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp index 9d9b71497..2e79addcb 100644 --- a/mlx/backend/cpu/inverse.cpp +++ b/mlx/backend/cpu/inverse.cpp @@ -11,7 +11,7 @@ namespace mlx::core { template void general_inv(T* inv, int N) { int info; - auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; + auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)}; // Compute LU factorization. 
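  // LAPACK getrf computes the factorization with partial pivoting
  // (A = P * L * U), overwriting the input with the L and U factors and
  // storing the pivot indices in ipiv; a positive info marks a singular
  // matrix.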
getrf( /* m = */ &N, @@ -49,7 +49,7 @@ void general_inv(T* inv, int N) { } const int lwork = workspace_size; - auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; + auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)}; // Compute inverse. getri( diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp index d85de146c..9ac9361a4 100644 --- a/mlx/backend/cpu/luf.cpp +++ b/mlx/backend/cpu/luf.cpp @@ -30,8 +30,7 @@ void luf_impl( auto strides = lu.strides(); strides[ndim - 1] = M; strides[ndim - 2] = 1; - lu.set_data( - allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags); + lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags); copy_inplace( a, lu, @@ -44,8 +43,8 @@ void luf_impl( stream); auto a_ptr = lu.data(); - pivots.set_data(allocator::malloc_or_wait(pivots.nbytes())); - row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes())); + pivots.set_data(allocator::malloc(pivots.nbytes())); + row_indices.set_data(allocator::malloc(row_indices.nbytes())); auto pivots_ptr = pivots.data(); auto row_indices_ptr = row_indices.data(); size_t num_matrices = a.size() / (M * N); diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp index 75b8a5b3d..0be7c79ce 100644 --- a/mlx/backend/cpu/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[BlockMaskedMM::eval] Currently only supports float32."); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; @@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[GatherMM::eval] Currently only supports float32."); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 082531894..8ae99ab2d 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -115,7 +115,7 @@ void matmul_general( } void Matmul::eval_cpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); if (inputs[0].shape(-1) == 0) { auto& encoder = cpu::get_command_encoder(stream()); encoder.set_output_array(out); diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index a561f1de8..1dfae8524 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -21,7 +21,7 @@ namespace mlx::core { void reshape(const array& in, array& out) { auto [copy_necessary, out_strides] = prepare_reshape(in, out); if (copy_necessary) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); copy_inplace(in, out, CopyType::General, out.primitive().stream()); } else { shared_buffer_reshape(in, out_strides, out); @@ -39,7 +39,7 @@ static std::pair compute_dynamic_offset( if (donate) { offset.copy_shared_buffer(indices); } else { - offset.set_data(allocator::malloc_or_wait(offset.itemsize())); + offset.set_data(allocator::malloc(offset.itemsize())); } auto& encoder = cpu::get_command_encoder(stream); @@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector& inputs, array& out) { void Arange::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); - 
out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); switch (out.dtype()) { case bool_: throw std::runtime_error("Bool type unsupported for arange."); @@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { } std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto strides = out.strides(); auto flags = out.flags(); @@ -276,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector& inputs, array& out) { size_t elems_per_key = out.size() / num_keys; size_t bytes_per_key = out.itemsize() * elems_per_key; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto kptr = inputs[0].data(); auto cptr = out.data(); @@ -335,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) { return; } auto& in = inputs[0]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto [in_offset, donated] = compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); copy_inplace( @@ -450,7 +450,7 @@ void View::eval_cpu(const std::vector& inputs, array& out) { } else { auto tmp = array( in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); + tmp.set_data(allocator::malloc(tmp.nbytes())); if (in.dtype() == bool_) { auto in_tmp = array(in.shape(), uint8, nullptr, {}); in_tmp.copy_shared_buffer(in); diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp index 8c6f94140..9e01d188b 100644 --- a/mlx/backend/cpu/qrf.cpp +++ b/mlx/backend/cpu/qrf.cpp @@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { auto strides = in.strides(); strides[in.ndim() - 2] = 1; strides[in.ndim() - 1] = M; - in.set_data( - allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags); + in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags); copy_inplace(a, in, CopyType::GeneralGeneral, stream); auto& encoder = cpu::get_command_encoder(stream); - q.set_data(allocator::malloc_or_wait(q.nbytes())); - r.set_data(allocator::malloc_or_wait(r.nbytes())); + q.set_data(allocator::malloc(q.nbytes())); + r.set_data(allocator::malloc(r.nbytes())); auto in_ptr = in.data(); auto r_ptr = r.data(); @@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { encoder.set_output_array(r); encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() { int num_reflectors = std::min(M, N); - auto tau = - allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors); + auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors); T optimal_work; int lwork = -1; @@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { // Update workspace size lwork = optimal_work; - auto work = allocator::malloc_or_wait(sizeof(T) * lwork); + auto work = allocator::malloc(sizeof(T) * lwork); // Loop over matrices for (int i = 0; i < num_matrices; ++i) { @@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) { &lwork, &info); lwork = optimal_work; - work = allocator::malloc_or_wait(sizeof(T) * lwork); + work = allocator::malloc(sizeof(T) * lwork); // Loop over matrices for (int i = 0; i < num_matrices; ++i) { diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 02bddab2f..f0ac9d57f 100644 --- 
a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { auto scales = ensure_row_contiguous(scales_pre); auto biases = ensure_row_contiguous(biases_pre); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.add_temporaries(std::move(temps)); @@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { auto scales = ensure_row_contiguous_last_dims(scales_pre); auto biases = ensure_row_contiguous_last_dims(biases_pre); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.add_temporaries(std::move(temps)); @@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu( auto [w, copied] = ensure_row_contiguous(inputs[0]); auto& out = outputs[0]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& scales = outputs[1]; auto& biases = outputs[2]; - scales.set_data(allocator::malloc_or_wait(scales.nbytes())); - biases.set_data(allocator::malloc_or_wait(biases.nbytes())); + scales.set_data(allocator::malloc(scales.nbytes())); + biases.set_data(allocator::malloc(biases.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); if (copied) { encoder.add_temporary(w); diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index 3f0c3b2ae..ce25feb11 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -433,7 +433,7 @@ void reduce_dispatch_min_max( void Reduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 205ae414d..1a44ebd39 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -244,7 +244,7 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { in = arr_copy; encoder.add_temporary(arr_copy); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); encoder.set_input_array(in); encoder.set_output_array(out); diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp index 591bad020..78e4a3e68 100644 --- a/mlx/backend/cpu/softmax.cpp +++ b/mlx/backend/cpu/softmax.cpp @@ -129,7 +129,7 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc_or_wait(x.data_size() * x.itemsize()), + allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index 4439df61b..b00e301b8 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; // Allocate output - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); @@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; // Allocate 
output - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 86745e545..24d93f8e5 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -50,9 +50,9 @@ void svd_impl( array& s = outputs[1]; array& vt = outputs[2]; - u.set_data(allocator::malloc_or_wait(u.nbytes())); - s.set_data(allocator::malloc_or_wait(s.nbytes())); - vt.set_data(allocator::malloc_or_wait(vt.nbytes())); + u.set_data(allocator::malloc(u.nbytes())); + s.set_data(allocator::malloc(s.nbytes())); + vt.set_data(allocator::malloc(vt.nbytes())); encoder.set_output_array(u); encoder.set_output_array(s); @@ -64,7 +64,7 @@ void svd_impl( } else { array& s = outputs[0]; - s.set_data(allocator::malloc_or_wait(s.nbytes())); + s.set_data(allocator::malloc(s.nbytes())); encoder.set_output_array(s); @@ -91,7 +91,7 @@ void svd_impl( // Will contain the indices of eigenvectors that failed to converge (not // used here but required by lapack). - auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)}; + auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)}; static const int lwork_query = -1; @@ -132,7 +132,7 @@ void svd_impl( } const int lwork = workspace_dimension; - auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; + auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)}; // Loop over matrices. for (int i = 0; i < num_matrices; i++) { diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index a12bbbf00..fa539541c 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) { } else { auto size = in.data_size(); out.set_data( - allocator::malloc_or_wait(size * out.itemsize()), + allocator::malloc(size * out.itemsize()), size, in.strides(), in.flags()); } } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } } diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 8f5b28226..d7b84a165 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -192,16 +192,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) { return limit; }; -size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) { +size_t MetalAllocator::set_memory_limit(size_t limit) { std::unique_lock lk(mutex_); std::swap(limit, block_limit_); - relaxed_ = relaxed; gc_limit_ = std::min( block_limit_, static_cast(0.95 * device_->recommendedMaxWorkingSetSize())); return limit; }; +size_t MetalAllocator::get_memory_limit() { + return block_limit_; +} + size_t MetalAllocator::set_wired_limit(size_t limit) { std::unique_lock lk(mutex_); std::swap(limit, wired_limit_); @@ -209,7 +212,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) { return limit; }; -Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { +Buffer MetalAllocator::malloc(size_t size) { // Metal doesn't like empty buffers if (size == 0) { return Buffer{nullptr}; @@ -236,11 +239,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { if (!buf) { size_t mem_required = get_active_memory() + get_cache_memory() + size; - // If there is too much memory pressure, fail (likely causes a wait). 
- if (!(allow_swap && relaxed_) && mem_required >= block_limit_) { - return Buffer{nullptr}; - } - auto pool = metal::new_scoped_memory_pool(); // If we have a lot of memory pressure or are over the maximum cache size, @@ -328,8 +326,11 @@ MetalAllocator& allocator() { size_t set_cache_limit(size_t limit) { return allocator().set_cache_limit(limit); } -size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { - return allocator().set_memory_limit(limit, relaxed); +size_t set_memory_limit(size_t limit) { + return allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return allocator().get_memory_limit(); } size_t set_wired_limit(size_t limit) { if (limit > diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index df301f55e..8b77ff6c1 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -56,7 +56,7 @@ class BufferCache { class MetalAllocator : public allocator::Allocator { /** Allocator for Metal GPUs. */ public: - virtual Buffer malloc(size_t size, bool allow_swap = false) override; + virtual Buffer malloc(size_t size) override; virtual void free(Buffer buffer) override; virtual size_t size(Buffer buffer) const override; size_t get_active_memory() { @@ -73,7 +73,8 @@ class MetalAllocator : public allocator::Allocator { return buffer_cache_.cache_size(); }; size_t set_cache_limit(size_t limit); - size_t set_memory_limit(size_t limit, bool relaxed); + size_t set_memory_limit(size_t limit); + size_t get_memory_limit(); size_t set_wired_limit(size_t limit); void clear_cache(); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 3e42f7d2f..c4803a380 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu( Shape unfolded_shape{implicit_M, implicit_K}; array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); - in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); + in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); // Prepare unfolding kernel std::ostringstream kname; @@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu( // Prepare unfolding array Shape unfolded_shape{implicit_M, implicit_K * groups}; array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); - in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); + in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); // Prepare unfolding kernel std::ostringstream kname; @@ -613,7 +613,7 @@ void winograd_conv_2D_gpu( // Do filter transform Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O}; array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {}); - filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes())); + filt_wg.set_data(allocator::malloc(filt_wg.nbytes())); copies_w.push_back(filt_wg); { int bc = 32; @@ -640,7 +640,7 @@ void winograd_conv_2D_gpu( // Do input transform Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C}; array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {}); - inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes())); + inp_wg.set_data(allocator::malloc(inp_wg.nbytes())); copies_w.push_back(inp_wg); { int bc = 32; @@ -667,7 +667,7 @@ void winograd_conv_2D_gpu( // Do batched gemm Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O}; array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {}); - out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes())); + out_wg.set_data(allocator::malloc(out_wg.nbytes())); copies_w.push_back(out_wg); { std::vector empty_copies; 
@@ -855,7 +855,7 @@ void conv_3D_gpu( } // namespace void Convolution::eval_gpu(const std::vector& inputs, array& out) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 1750a6908..3399201de 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); bool large = out.data_size() > UINT32_MAX; auto& d = metal::device(s.device); std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index e775f6798..8a672289a 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -19,7 +19,7 @@ void CustomKernel::eval_gpu( copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } } diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index f5502231a..e784d34ae 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -20,7 +20,7 @@ struct FenceImpl { auto p = metal::new_scoped_memory_pool(); fence = static_cast(d->newSharedEvent()); } else { - auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr(); + auto buf = allocator::malloc(sizeof(uint32_t)).ptr(); fence = static_cast(buf); cpu_value()[0] = 0; } diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 52131b4a8..95678279e 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -281,7 +281,7 @@ std::tuple compute_raders_constants( } array b_q_fft({rader_n - 1}, complex64, nullptr, {}); - b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes())); + b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes())); auto b_q_fft_ptr = reinterpret_cast*>(b_q_fft.data()); std::ptrdiff_t item_size = b_q_fft.itemsize(); @@ -327,11 +327,11 @@ std::pair compute_bluestein_constants(int n, int bluestein_n) { } array w_k({n}, complex64, nullptr, {}); - w_k.set_data(allocator::malloc_or_wait(w_k.nbytes())); + w_k.set_data(allocator::malloc(w_k.nbytes())); std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data()); array w_q({bluestein_n}, complex64, nullptr, {}); - w_q.set_data(allocator::malloc_or_wait(w_q.nbytes())); + w_q.set_data(allocator::malloc(w_q.nbytes())); auto w_q_ptr = reinterpret_cast*>(w_q.data()); @@ -551,8 +551,7 @@ void fft_op( flags.row_contiguous = is_row_contiguous; flags.contiguous = data_size == x_copy.size(); - x_copy.set_data( - allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags); + x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags); copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s); copies.push_back(x_copy); return x_copy; @@ -583,7 +582,7 @@ void fft_op( // TODO: allow donation here if (!inplace) { out.set_data( - allocator::malloc_or_wait(out.nbytes()), + allocator::malloc(out.nbytes()), out_data_size, out_strides, in_contiguous.flags()); diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 7b711e28b..a7dfc5f17 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -84,7 +84,7 
@@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { if (in_contiguous.is_donatable()) { out.copy_shared_buffer(in_contiguous); } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } int n, m; @@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { // Upload 2: // y = h12 @ tmp array temp(in.shape(), in.dtype(), nullptr, {}); - temp.set_data(allocator::malloc_or_wait(temp.nbytes())); + temp.set_data(allocator::malloc(temp.nbytes())); copies.push_back(temp); launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 13696842c..d2a263051 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error(msg.str()); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } @@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { auto& src = inputs[0]; auto& idx = inputs[1]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 4d3cf21ee..3f736505f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -382,7 +382,7 @@ void steel_matmul( int split_k_partition_size = gemm_k_iterations * bk; array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); + C_split.set_data(allocator::malloc(C_split.nbytes())); copies.push_back(C_split); bool mn_aligned = M % bm == 0 && N % bn == 0; @@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { return; } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep @@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { int split_k_partition_size = gemm_k_iterations * bk; array C_split({split_k_partitions, M, N}, float32, nullptr, {}); - C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); + C_split.set_data(allocator::malloc(C_split.nbytes())); copies.push_back(C_split); bool mn_aligned = M % bm == 0 && N % bn == 0; @@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { return; } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep @@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { return; } - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); ///////////////////////////////////////////////////////////////////////////// // Init checks and prep diff --git 
a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index d5c64f79d..82151c538 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -38,17 +38,20 @@ void reset_peak_memory(); size_t get_cache_memory(); /* Set the memory limit. - * Calls to malloc will wait on scheduled tasks if the limit is exceeded. If - * there are no more scheduled tasks an error will be raised if relaxed - * is false or memory will be allocated (including the potential for - * swap) if relaxed is true. + * The memory limit is a guideline for the maximum amount of memory to use + * during graph evaluation. If the memory limit is exceeded and there is no + * more RAM (including swap when available) allocations will result in an + * exception. * - * The memory limit defaults to 1.5 times the maximum recommended working set - * size reported by the device. + * When metal is available the memory limit defaults to 1.5 times the maximum + * recommended working set size reported by the device. * * Returns the previous memory limit. * */ -size_t set_memory_limit(size_t limit, bool relaxed = true); +size_t set_memory_limit(size_t limit); + +/* Get the current memory limit. */ +size_t get_memory_limit(); /* Set the free cache limit. * If using more than the given limit, free memory will be reclaimed diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index b9a82b876..c1d993d2a 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -29,7 +29,7 @@ void RMSNorm::eval_gpu( out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc_or_wait(x.data_size() * x.itemsize()), + allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); @@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu( gx.copy_shared_buffer(g); g_in_gx = true; } else { - gx.set_data(allocator::malloc_or_wait(gx.nbytes())); + gx.set_data(allocator::malloc(gx.nbytes())); } if (g_copied && !g_in_gx) { d.add_temporary(g, s.index); @@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu( if (!g_in_gx && donate_g) { gw_temp.copy_shared_buffer(g); } else { - gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); d.add_temporary(gw_temp, s.index); } } - gw.set_data(allocator::malloc_or_wait(gw.nbytes())); + gw.set_data(allocator::malloc(gw.nbytes())); const int simd_size = 32; const int n_reads = RMS_N_READS; @@ -226,7 +226,7 @@ void LayerNorm::eval_gpu( out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc_or_wait(x.data_size() * x.itemsize()), + allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); @@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu( gx.copy_shared_buffer(g); g_in_gx = true; } else { - gx.set_data(allocator::malloc_or_wait(gx.nbytes())); + gx.set_data(allocator::malloc(gx.nbytes())); } if (g_copied && !g_in_gx) { d.add_temporary(g, s.index); @@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu( if (!g_in_gx && donate_g) { gw_temp.copy_shared_buffer(g); } else { - gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); + gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); d.add_temporary(gw_temp, s.index); } } - gw.set_data(allocator::malloc_or_wait(gw.nbytes())); - gb.set_data(allocator::malloc_or_wait(gb.nbytes())); + gw.set_data(allocator::malloc(gw.nbytes())); + gb.set_data(allocator::malloc(gb.nbytes())); // Finish with the gradient for b in case we had a b auto& compute_encoder = d.get_command_encoder(s.index); diff 
--git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 76f7e3946..67576f03f 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -28,7 +28,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { void reshape(const array& in, array& out, Stream s) { auto [copy_necessary, out_strides] = prepare_reshape(in, out); if (copy_necessary) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); copy_gpu_inplace( in, out, @@ -58,7 +58,7 @@ static array compute_dynamic_offset( if (donate) { offset.copy_shared_buffer(indices); } else { - offset.set_data(allocator::malloc_or_wait(offset.itemsize())); + offset.set_data(allocator::malloc(offset.itemsize())); } d.add_temporary(offset, s.index); @@ -100,7 +100,7 @@ static array compute_dynamic_offset( void Arange::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } @@ -161,7 +161,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); std::string op_name; @@ -333,7 +333,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { size_t elems_per_key = out.size() / num_keys; size_t bytes_per_key = out.itemsize() * elems_per_key; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); if (out.size() == 0) { return; } @@ -397,7 +397,7 @@ void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; auto& start = inputs[1]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto s = stream(); auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s); copy_gpu_inplace( @@ -554,7 +554,7 @@ void View::eval_gpu(const std::vector& inputs, array& out) { in, strides, in.flags(), in.data_size() * ibytes / obytes); } else { auto tmp = array(in.shape(), in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); + tmp.set_data(allocator::malloc(tmp.nbytes())); copy_gpu_inplace(in, tmp, CopyType::General, stream()); auto flags = out.flags(); diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index d6fddd058..cc32797eb 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -224,7 +224,7 @@ void qvm_split_k( auto temp_shape = out.shape(); temp_shape.insert(temp_shape.end() - 2, split_k); array intermediate(temp_shape, x.dtype(), nullptr, {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); std::ostringstream kname; @@ -277,7 +277,7 @@ void qmm_op( int bits, bool gather, const Stream& s) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); MTL::Size group_dims; MTL::Size grid_dims; @@ -394,7 +394,7 @@ void fast::AffineQuantize::eval_gpu( std::vector& outputs) { auto& w_pre = inputs[0]; auto& out = outputs[0]; - out.set_data(allocator::malloc_or_wait(out.nbytes())); +
out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -425,8 +425,8 @@ void fast::AffineQuantize::eval_gpu( } else { auto& scales = outputs[1]; auto& biases = outputs[2]; - scales.set_data(allocator::malloc_or_wait(scales.nbytes())); - biases.set_data(allocator::malloc_or_wait(biases.nbytes())); + scales.set_data(allocator::malloc(scales.nbytes())); + biases.set_data(allocator::malloc(biases.nbytes())); compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(scales, 2); compute_encoder.set_output_array(biases, 3); diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 36e1266e6..c5650bdd7 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -347,7 +347,7 @@ void all_reduce_dispatch( // Allocate an intermediate tensor to hold results if needed array intermediate({n_rows}, out_type, nullptr, {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); // 1st pass @@ -641,7 +641,7 @@ void strided_reduce_longcolumn( intermediate_shape.insert( intermediate_shape.end(), out.shape().begin(), out.shape().end()); array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); // Prepare the arguments for the kernel @@ -812,7 +812,7 @@ void strided_reduce_2pass( intermediate_shape.insert( intermediate_shape.end(), out.shape().begin(), out.shape().end()); array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); // Prepare the arguments for the kernel @@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Minimum of 4 bytes since we use size 4 structs for all reduce // and metal will complain o/w size_t min_bytes = std::max(out.nbytes(), 4ul); - out.set_data(allocator::malloc_or_wait(min_bytes)); + out.set_data(allocator::malloc(min_bytes)); std::string op_name; switch (reduce_type_) { case Reduce::And: diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index ce3384c54..060758333 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -43,14 +43,14 @@ void RoPE::eval_gpu( donated = true; out.copy_shared_buffer(in); } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } strides[0] = mat_size; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; } else if (dispatch_ndim == 3) { // Handle non-contiguous 3D inputs - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c5b544852..f64d057ce 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -248,9 +248,9 @@ void sdpa_vector_2pass( intermediate_shape.pop_back(); array sums(intermediate_shape, float32, nullptr, {}); array
maxs(std::move(intermediate_shape), float32, nullptr, {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); - sums.set_data(allocator::malloc_or_wait(sums.nbytes())); - maxs.set_data(allocator::malloc_or_wait(maxs.nbytes())); + intermediate.set_data(allocator::malloc(intermediate.nbytes())); + sums.set_data(allocator::malloc(sums.nbytes())); + maxs.set_data(allocator::malloc(maxs.nbytes())); d.add_temporary(intermediate, s.index); d.add_temporary(sums, s.index); d.add_temporary(maxs, s.index); @@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu( o.copy_shared_buffer(q); } else { if (o.shape(2) == 1) { - o.set_data(allocator::malloc_or_wait(o.nbytes())); + o.set_data(allocator::malloc(o.nbytes())); } else { auto strides = o.strides(); strides[2] = o.shape(1) * o.shape(3); @@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu( auto flags = q.flags(); flags.row_contiguous = q.shape(1) == 1; o.set_data( - allocator::malloc_or_wait(o.nbytes()), - o.size(), - std::move(strides), - flags); + allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags); } } @@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu( }; o.set_data( - allocator::malloc_or_wait(o.nbytes()), + allocator::malloc(o.nbytes()), data_size, {str_oB, str_oH, str_oL, str_oD}, flags); diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index dd123662f..c7e0087b7 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -24,7 +24,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { out.copy_shared_buffer(in); } else { out.set_data( - allocator::malloc_or_wait(in.data_size() * out.itemsize()), + allocator::malloc(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index d34e1a747..6ab08a108 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -29,7 +29,7 @@ void concatenate_gpu( } std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto strides = out.strides(); auto flags = out.flags(); diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 854f31f4b..b089188b8 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -33,7 +33,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { out.copy_shared_buffer(x); } else { out.set_data( - allocator::malloc_or_wait(x.data_size() * x.itemsize()), + allocator::malloc(x.data_size() * x.itemsize()), x.data_size(), x.strides(), x.flags()); diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index ea16fbf89..543dfd180 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -150,12 +150,11 @@ void multi_block_sort( array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {}); // Do allocations - dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes())); - dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes())); - dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes())); - dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes())); - block_partitions.set_data( - allocator::malloc_or_wait(block_partitions.nbytes())); + dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes())); + dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes())); + dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes())); +
dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes())); + block_partitions.set_data(allocator::malloc(block_partitions.nbytes())); std::vector copies = { dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions}; @@ -319,7 +318,7 @@ void gpu_merge_sort( void ArgSort::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -331,7 +330,7 @@ void ArgSort::eval_gpu(const std::vector& inputs, array& out) { void Sort::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -344,7 +343,7 @@ void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { // We direct arg partition to sort for now assert(inputs.size() == 1); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -357,7 +356,7 @@ void Partition::eval_gpu(const std::vector& inputs, array& out) { // We direct partition to sort for now assert(inputs.size() == 1); - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 741c7d70a..be43c41c2 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -97,13 +97,13 @@ void unary_op_gpu( out.copy_shared_buffer(in); } else { out.set_data( - allocator::malloc_or_wait(in.data_size() * out.itemsize()), + allocator::malloc(in.data_size() * out.itemsize()), in.data_size(), in.strides(), in.flags()); } } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(allocator::malloc(out.nbytes())); } unary_op_gpu_inplace(inputs, out, op, s); } diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 359e63b5d..03c68c734 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -42,7 +42,10 @@ void reset_peak_memory() {} size_t get_cache_memory() { return 0; } -size_t set_memory_limit(size_t, bool) { +size_t set_memory_limit(size_t) { + return 0; +} +size_t get_memory_limit() { return 0; } size_t set_cache_limit(size_t) { diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 958899bec..105a0fa28 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -218,7 +218,9 @@ array eval_impl(std::vector outputs, bool async) { cpu::eval(arr); } - if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) { + if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS || + (metal::get_active_memory() > metal::get_memory_limit() && + scheduler::n_active_tasks() > 0)) { // Commit any open streams for (auto& [_, e] : events) { if (e.stream().device == Device::gpu) { @@ -226,6 +228,11 @@ array eval_impl(std::vector outputs, bool async) { } } scheduler::wait_for_one(); + // TODO memory api should be moved out of metal + while (metal::get_active_memory() > metal::get_memory_limit() && + scheduler::n_active_tasks() > 0) { + scheduler::wait_for_one(); + } } auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) { diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 5ef5691fb..fef856dd9 100644 --- a/python/src/metal.cpp +++
b/python/src/metal.cpp @@ -57,23 +57,19 @@ void init_metal(nb::module_& m) { "set_memory_limit", &mx::metal::set_memory_limit, "limit"_a, - nb::kw_only(), - "relaxed"_a = true, R"pbdoc( Set the memory limit. - Memory allocations will wait on scheduled tasks to complete if the limit - is exceeded. If there are no more scheduled tasks an error will be raised - if ``relaxed`` is ``False``. Otherwise memory will be allocated - (including the potential for swap) if ``relaxed`` is ``True``. + The memory limit is a guideline for the maximum amount of memory to use + during graph evaluation. If the memory limit is exceeded and there is no + more RAM (including swap when available) allocations will result in an + exception. - The memory limit defaults to 1.5 times the maximum recommended working set - size reported by the device. + When metal is available the memory limit defaults to 1.5 times the + maximum recommended working set size reported by the device. Args: limit (int): Memory limit in bytes. - relaxed (bool, optional): If `False`` an error is raised if the limit - is exceeded. Default: ``True`` Returns: int: The previous memory limit in bytes. diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 510402b06..ebcf64c7a 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -185,6 +185,18 @@ class TestEval(mlx_tests.MLXTestCase): x = mx.abs(x, stream=s2) mx.eval(x) + s1 = mx.default_stream(mx.gpu) + s2 = mx.new_stream(mx.gpu) + old_limit = mx.metal.set_memory_limit(1000) + + x = mx.ones((512, 512), stream=s2) + for _ in range(80): + x = mx.abs(x, stream=s1) + y = mx.abs(x, stream=s2) + z = mx.abs(y, stream=s2) + mx.eval(z) + mx.metal.set_memory_limit(old_limit) + if __name__ == "__main__": unittest.main()
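
For reference, a minimal sketch of how the revised memory-limit API in this series is used from Python. The `set_memory_limit` binding and `get_active_memory` are existing `mlx.core.metal` functions; the Python-level `get_memory_limit` name is an assumption here, mirroring the new C++ declaration, since only the C++ header change is visible in this patch.

    import mlx.core as mx

    # Cap active memory at roughly 2 GB. The call returns the previous limit
    # so it can be restored later; note there is no `relaxed` flag anymore.
    old_limit = mx.metal.set_memory_limit(2 * 1024**3)

    x = mx.random.normal((4096, 4096))
    mx.eval((x @ x).sum())

    print(mx.metal.get_active_memory())  # bytes currently held by MLX
    print(mx.metal.get_memory_limit())   # assumed binding for the new getter

    mx.metal.set_memory_limit(old_limit)  # restore the previous limit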
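The switch from `allocator::malloc_or_wait` to `allocator::malloc` across the Metal backend means an allocation no longer blocks waiting for earlier tasks; per the updated docstring, it raises once RAM (and swap, when available) is exhausted. A hedged sketch of what that looks like from user code; the exact exception type is not shown in this patch, so a broad `except` is used.

    import mlx.core as mx

    try:
        # Deliberately request far more memory than any device has
        # (~4 TB of float32), so the allocation should fail at eval time.
        huge = mx.zeros((1 << 20, 1 << 20))
        mx.eval(huge)
    except Exception as e:  # exact exception type is an assumption
        print("allocation failed:", e)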
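The `eval_impl` change above also makes evaluation drain in-flight tasks whenever active memory exceeds the limit, not only when `MAX_ACTIVE_TASKS` is hit. Below is a small multi-stream sketch in the spirit of the new `test_eval.py` case; the array size, loop count, and deliberately tiny 1000-byte limit are illustrative values, not a reproduction of the test.

    import mlx.core as mx

    if mx.metal.is_available():
        s1 = mx.default_stream(mx.gpu)
        s2 = mx.new_stream(mx.gpu)

        # A tiny limit forces eval to wait on outstanding work instead of
        # letting active memory grow across the two streams.
        old_limit = mx.metal.set_memory_limit(1000)

        x = mx.ones((512, 512), stream=s2)
        for _ in range(80):
            x = mx.abs(x, stream=s1)
            y = mx.abs(x, stream=s2)
            z = mx.abs(y, stream=s2)
            mx.eval(z)

        mx.metal.set_memory_limit(old_limit)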