diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp
index ff22329de..6fcf25b15 100644
--- a/mlx/backend/cpu/masked_mm.cpp
+++ b/mlx/backend/cpu/masked_mm.cpp
@@ -77,7 +77,10 @@ inline void segmented_mm(
   int32_t N = b_copy[ndim - 1];
   int32_t k_start = 0;
   for (int i = 0; i < num_segments; i++) {
-    int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
+    int32_t k_start =
+        segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
+    int32_t k_end =
+        segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
     a_copy[ndim - 1] = k_end - k_start;
     b_copy[ndim - 2] = k_end - k_start;
     matmul(
@@ -96,7 +99,6 @@ inline void segmented_mm(
         a_strides,
         b_copy,
         b_strides);
-    k_start = k_end;
   }
 }
 
@@ -537,7 +539,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
@@ -555,7 +557,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
@@ -573,7 +575,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
@@ -591,7 +593,7 @@ void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
           a.strides(),
           b.shape(),
           b.strides(),
-          segments.size(),
+          segments.size() / 2,
           segments.shape(),
           segments.strides());
       break;
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 62a6f9caf..10d697635 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -1864,8 +1864,138 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
 }
 
+void segmented_mm(
+    const array& a_,
+    const array& b_,
+    const array& segments_,
+    array& out,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s) {
+  // Copy if needed
+  std::vector<array> copies;
+  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
+  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
+  auto segments = ensure_row_contiguous(segments_, d, s);
+  d.add_temporaries(std::move(copies), s.index);
+
+  // Determine dispatch kernel
+  int bm = 64, bn = 64, bk = 16;
+  int wm = 2, wn = 2;
+  size_t batch_size_out = out.size() / M / N;
+
+  char devc = d.get_architecture().back();
+  GEMM_TPARAM_MACRO(devc)
+
+  const bool align_M = (M % bm) == 0;
+  const bool align_N = (N % bn) == 0;
+
+  // Define the kernel name
+  std::string base_name;
+  base_name.reserve(128);
+  concatenate(
+      base_name,
+      "steel_segmented_mm_",
+      transpose_a ? 't' : 'n',
+      transpose_b ? 't' : 'n',
+      "_",
+      type_to_name(a),
+      "_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn);
+
+  metal::MTLFCList func_consts = {
+      {&align_M, MTL::DataType::DataTypeBool, 200},
+      {&align_N, MTL::DataType::DataTypeBool, 201},
+  };
+
+  // And the kernel hash that includes the function constants
+  std::string hash_name;
+  hash_name.reserve(128);
+  concatenate(
+      hash_name,
+      base_name,
+      "_align_M_",
+      align_M ? 't' : 'n',
+      "_align_N_",
+      align_N ? 't' : 'n');
+
+  // Get and set the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = get_steel_gemm_gather_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn,
+      false);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Prepare the matmul params
+  steel::GEMMParams params{
+      /* const int M = */ M,
+      /* const int N = */ N,
+      /* const int K = */ K,
+      /* const int lda = */ static_cast<int>(lda),
+      /* const int ldb = */ static_cast<int>(ldb),
+      /* const int ldd = */ N,
+      /* const int tiles_n = */ (N + bn - 1) / bn,
+      /* const int tiles_m = */ (M + bm - 1) / bm,
+      /* const int64_t batch_stride_a = */ 0,
+      /* const int64_t batch_stride_b = */ 0,
+      /* const int64_t batch_stride_d = */ M * N,
+      /* const int swizzle_log = */ 0,
+      /* const int gemm_k_iterations_aligned = */ 0,
+      /* const int batch_ndim = */ 0};
+
+  // Prepare the grid
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims =
+      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
+
+  // Launch kernel
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(segments, 2);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder.set_bytes(params, 4);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
 void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
-  throw std::invalid_argument("NYI");
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& segments = inputs[2];
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  // Extract shapes from inputs.
+  int M = a.shape(-2);
+  int N = b.shape(-1);
+  int K = a.shape(-1);
+
+  segmented_mm(a, b, segments, out, M, N, K, d, s);
 }
 
 } // namespace mlx::core
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index d5f73106d..e35241c03 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4658,6 +4658,13 @@ array segmented_mm(
     throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
   }
 
+  if (segments.ndim() < 1 || segments.shape().back() != 2) {
+    std::ostringstream msg;
+    msg << "[segmented_mm] The segments should have shape (..., 2) but "
+        << segments.shape() << " was provided.";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Type promotion
   auto out_type = result_type(a, b);
   if (!issubdtype(out_type, floating)) {
@@ -4673,6 +4680,7 @@ array segmented_mm(
   b = astype(b, out_type, s);
 
   Shape out_shape = segments.shape();
+  out_shape.pop_back();
   out_shape.push_back(a.shape(0));
   out_shape.push_back(b.shape(1));