diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
index 78ef0f212..522e9653f 100644
--- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
+++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
@@ -65,6 +65,12 @@ template
         batch_ndim];
+    if (has_operand_mask) {
+      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
+      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
+    }
   }
 
   // Adjust for batch
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 1019f6737..fb72b8bcd 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -1122,7 +1122,38 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
 
-  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
+  bool has_op_mask = inputs.size() > 3;
+  auto& out_mask = inputs[2];
+
+  std::vector<int> batch_shape{1};
+  int A_batch_str = 0;
+  int B_batch_str = 0;
+
+  std::vector<size_t> batch_strides;
+
+  if (out.ndim() > 2) {
+    auto get_batch_dims = [](const auto& v) {
+      return decltype(v){v.begin(), v.end() - 2};
+    };
+
+    std::vector<int> bshape{out.shape().begin(), out.shape().end() - 2};
+    std::vector<std::vector<size_t>> bstrides;
+
+    for (auto& arr : inputs) {
+      bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
+    }
+
+    auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
+    batch_shape = bshape_c;
+    A_batch_str = int(bstrides_c[0].back());
+    B_batch_str = int(bstrides_c[1].back());
+
+    for (auto& bstr : bstrides_c) {
+      batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
+    }
+  } else {
+    batch_strides = std::vector<size_t>(inputs.size(), 0);
+  }
 
   auto batch_size_out = out.size() / (M * N);
   int matrix_stride_out = M * N;
@@ -1142,7 +1173,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       << "_wm" << wm << "_wn" << wn << "_MN_"
       << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
       << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
-      << (inputs.size() > 3 ? "T" : "N");
+      << (has_op_mask ? "T" : "N");
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1166,8 +1197,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const int ldd = */ N,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const int batch_stride_a = */ int(A_batch_stride.back()),
-      /* const int batch_stride_b = */ int(B_batch_stride.back()),
+      /* const int batch_stride_a = */ A_batch_str,
+      /* const int batch_stride_b = */ B_batch_str,
       /* const int batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
@@ -1181,42 +1212,21 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size group_dims = MTL::Size(32, wn, wm);
   MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
 
-  std::vector<size_t> batch_strides = A_batch_stride;
-  batch_strides.insert(
-      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
-
   std::vector<size_t> mask_strides;
-
-  auto& out_mask = inputs[2];
   mask_strides.push_back(*(out_mask.strides().end() - 1));
   mask_strides.push_back(*(out_mask.strides().end() - 2));
 
-  batch_strides.insert(
-      batch_strides.end(),
-      out_mask.strides().begin(),
-      out_mask.strides().end() - 2);
-
-  if (inputs.size() > 3) {
+  if (has_op_mask) {
     auto& lhs_mask = inputs[3];
     mask_strides.push_back(*(lhs_mask.strides().end() - 1));
     mask_strides.push_back(*(lhs_mask.strides().end() - 2));
 
-    batch_strides.insert(
-        batch_strides.end(),
-        lhs_mask.strides().begin(),
-        lhs_mask.strides().end() - 2);
-
     compute_encoder.set_input_array(lhs_mask, 11);
 
     auto& rhs_mask = inputs[4];
     mask_strides.push_back(*(rhs_mask.strides().end() - 1));
     mask_strides.push_back(*(rhs_mask.strides().end() - 2));
 
-    batch_strides.insert(
-        batch_strides.end(),
-        rhs_mask.strides().begin(),
-        rhs_mask.strides().end() - 2);
-
     compute_encoder.set_input_array(rhs_mask, 12);
   }
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index ef435e51e..9f24d294d 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -717,7 +717,16 @@ class TestBlas(mlx_tests.MLXTestCase):
             out = out * out_mask
             return out
 
-        def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
+        def test_shape(
+            M,
+            N,
+            K,
+            block_size,
+            transpose=False,
+            np_dtype=np.float32,
+            batch_A=(),
+            batch_B=(),
+        ):
             with self.subTest(
                 M=M,
                 N=N,
@@ -725,28 +734,32 @@ class TestBlas(mlx_tests.MLXTestCase):
                 block_size=block_size,
                 np_dtype=np_dtype,
                 transpose=transpose,
+                batch_A=batch_A,
+                batch_B=batch_B,
             ):
                 tm = (M + block_size - 1) // block_size
                 tn = (N + block_size - 1) // block_size
                 tk = (K + block_size - 1) // block_size
 
-                a_np = np.random.normal(size=(M, K)).astype(np_dtype)
-                b_np = np.random.normal(size=(K, N)).astype(np_dtype)
+                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
+                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
 
-                a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
-                b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
-                out_np_mask = np.random.normal(size=(tm, tn)) < 0.0
+                batch_out = np.broadcast_shapes(batch_A, batch_B)
+
+                a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
+                b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
+                out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0
 
                 a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
                     mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
                 )
 
                 if transpose:
-                    b_np = np.random.normal(size=(N, K)).astype(np_dtype)
+                    b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
                     b_mx = mx.array(b_np)
 
-                    b_np = b_np.T
-                    b_mx = b_mx.T
+                    b_np = np.swapaxes(b_np, -2, -1)
+                    b_mx = mx.swapaxes(b_mx, -2, -1)
 
                 out_np = np_block_masked_mm(
                     a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
@@ -779,6 +792,9 @@ class TestBlas(mlx_tests.MLXTestCase):
             test_shape(M, N, K, block_size, transpose=False)
             test_shape(M, N, K, block_size, transpose=True)
 
+        # Test broadcasting
+        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
+
         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)
         b_np = np.random.normal(size=(64,)).astype(np.float32)
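
Usage note: with this change, block_masked_mm broadcasts the leading batch dimensions of the operands and their block masks, as exercised by the new broadcasting test case above. Below is a minimal sketch of the intended usage; the shapes, block size, and variable names are illustrative only, and it assumes the existing mx.block_masked_mm keyword arguments (block_size, mask_out, mask_lhs, mask_rhs).

    import mlx.core as mx
    import numpy as np

    M = N = K = 64
    block_size = 32
    tm, tn, tk = M // block_size, N // block_size, K // block_size

    # Operand batch shapes broadcast: (1, 2) and (2, 2) -> output batch (2, 2).
    a = mx.array(np.random.normal(size=(1, 2, M, K)).astype(np.float32))
    b = mx.array(np.random.normal(size=(2, 2, K, N)).astype(np.float32))

    # Boolean block masks; each mask carries the batch shape of its operand,
    # and the output mask carries the broadcast batch shape.
    a_mask = mx.array(np.random.normal(size=(1, 2, tm, tk)) < 0.0)
    b_mask = mx.array(np.random.normal(size=(2, 2, tk, tn)) < 0.0)
    out_mask = mx.array(np.random.normal(size=(2, 2, tm, tn)) < 0.0)

    out = mx.block_masked_mm(
        a, b, block_size=block_size,
        mask_out=out_mask, mask_lhs=a_mask, mask_rhs=b_mask,
    )
    print(out.shape)  # (2, 2, 64, 64)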