diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp
index 0be7c79ce..ff22329de 100644
--- a/mlx/backend/cpu/masked_mm.cpp
+++ b/mlx/backend/cpu/masked_mm.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
+#include "mlx/backend/cpu/gemm.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"
 
@@ -52,6 +53,53 @@ inline void mask_matrix(
   }
 }
 
+template <typename T>
+inline void segmented_mm(
+    const T* a,
+    const T* b,
+    const int32_t* segments,
+    T* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides,
+    size_t num_segments,
+    const Shape& segments_shape,
+    const Strides& segments_strides) {
+  int ndim = a_shape.size();
+  Shape a_copy = a_shape;
+  Shape b_copy = b_shape;
+  int32_t M = a_copy[ndim - 2];
+  int32_t N = b_copy[ndim - 1];
+  int32_t k_start = 0;
+  for (int i = 0; i < num_segments; i++) {
+    int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
+    a_copy[ndim - 1] = k_end - k_start;
+    b_copy[ndim - 2] = k_end - k_start;
+    matmul<T>(
+        a + k_start * a_strides[ndim - 1],
+        b + k_start * b_strides[ndim - 2],
+        out + i * M * N,
+        a_transposed,
+        b_transposed,
+        lda,
+        ldb,
+        N,
+        1.0,
+        0.0,
+        1,
+        a_copy,
+        a_strides,
+        b_copy,
+        b_strides);
+    k_start = k_end;
+  }
+}
+
 } // namespace
 
 void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -437,4 +485,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.add_temporaries(std::move(temps));
 }
 
+void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& s = stream();
+  auto& encoder = cpu::get_command_encoder(stream());
+  auto check_transpose = [&s, &encoder](const array& x) {
+    auto stx = x.strides()[x.ndim() - 2];
+    auto sty = x.strides()[x.ndim() - 1];
+    if (stx == x.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, x);
+    } else if (stx == 1 && sty == x.shape(-2)) {
+      return std::make_tuple(true, sty, x);
+    } else {
+      array xc(x.shape(), x.dtype(), nullptr, {});
+      copy(x, xc, CopyType::General, s);
+      encoder.add_temporary(xc);
+      int64_t stx = x.shape(-1);
+      return std::make_tuple(false, stx, xc);
+    }
+  };
+
+  auto [a_transposed, lda, a] = check_transpose(inputs[0]);
+  auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
+  auto& segments = inputs[2];
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(segments);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    segments = array::unsafe_weak_copy(segments),
+                    out_ptr = out.data<void>(),
+                    a_transposed = a_transposed,
+                    b_transposed = b_transposed,
+                    lda = lda,
+                    ldb = ldb]() {
+    switch (a.dtype()) {
+      case float64:
+        segmented_mm<double>(
+            a.data<double>(),
+            b.data<double>(),
+            segments.data<int32_t>(),
+            static_cast<double*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size(),
+            segments.shape(),
+            segments.strides());
+        break;
+      case float32:
+        segmented_mm<float>(
+            a.data<float>(),
+            b.data<float>(),
+            segments.data<int32_t>(),
+            static_cast<float*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size(),
+            segments.shape(),
+            segments.strides());
+        break;
+      case float16:
+        segmented_mm<float16_t>(
+            a.data<float16_t>(),
+            b.data<float16_t>(),
+            segments.data<int32_t>(),
+            static_cast<float16_t*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size(),
+            segments.shape(),
+            segments.strides());
+        break;
+      case bfloat16:
+        segmented_mm<bfloat16_t>(
+            a.data<bfloat16_t>(),
+            b.data<bfloat16_t>(),
+            segments.data<int32_t>(),
+            static_cast<bfloat16_t*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size(),
+            segments.shape(),
+            segments.strides());
+        break;
+      default:
+        throw std::invalid_argument(
+            "Segmented mm supports only real float types.");
+    }
+  });
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index be7f3e2f8..62a6f9caf 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -1864,4 +1864,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
 }
 
+void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  throw std::invalid_argument("NYI");
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 2b861428f..d5f73106d 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4649,6 +4649,40 @@ array gather_mm(
   return axes.empty() ? out : squeeze(out, axes, s);
 }
 
+array segmented_mm(
+    array a,
+    array b,
+    array segments,
+    StreamOrDevice s /* = {} */) {
+  if (a.ndim() != 2 || b.ndim() != 2) {
+    throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
+  }
+
+  // Type promotion
+  auto out_type = result_type(a, b);
+  if (!issubdtype(out_type, floating)) {
+    std::ostringstream msg;
+    msg << "[segmented_mm] Only real floating point types are supported but "
+        << a.dtype() << " and " << b.dtype()
+        << " were provided which results in " << out_type
+        << ", which is not a real floating point type.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  a = astype(a, out_type, s);
+  b = astype(b, out_type, s);
+
+  Shape out_shape = segments.shape();
+  out_shape.push_back(a.shape(0));
+  out_shape.push_back(b.shape(1));
+
+  return array(
+      std::move(out_shape),
+      out_type,
+      std::make_shared<SegmentedMM>(to_stream(s)),
+      {std::move(a), std::move(b), std::move(segments)});
+}
+
 array diagonal(
     const array& a,
     int offset /* = 0 */,
diff --git a/mlx/ops.h b/mlx/ops.h
index af3cdb5bd..596d6d287 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1406,6 +1406,12 @@ array gather_mm(
     bool sorted_indices = false,
     StreamOrDevice s = {});
 
+/**
+ * Compute a matrix product but segment the inner dimension and write the
+ * result separately for each segment.
+ */
+array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
+
 /** Extract a diagonal or construct a diagonal array */
 array diagonal(
     const array& a,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 4b18430ca..f4f157298 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive {
   bool right_sorted_;
 };
 
+class SegmentedMM : public UnaryPrimitive {
+ public:
+  explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_PRINT(SegmentedMM)
+};
+
 class BroadcastAxes : public UnaryPrimitive {
  public:
   explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index a1e77d681..d047f64cb 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
       array: The result of the multiplication of ``x`` with ``w``
       after gathering using ``lhs_indices`` and ``rhs_indices``.
   )pbdoc");
+  m.def(
+      "segmented_mm",
+      &mx::segmented_mm,
+      nb::arg(),
+      nb::arg(),
+      "segments"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Perform a matrix multiplication but segment the inner dimension and
+        save the result for each segment separately.
+
+        Args:
+            a (array): Input array of shape ``MxK``.
+            b (array): Input array of shape ``KxN``.
+            segments (array): The offsets into the inner dimension for each segment.
+
+        Returns:
+            array: The result per segment of shape ``MxN``.
+      )pbdoc");
   m.def(
       "tensordot",
       [](const mx::array& a,
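
For reference, a minimal usage sketch of the new Python binding. It assumes the semantics implemented by the CPU kernel in this diff: each entry of `segments` is read as the exclusive end offset of a segment along the shared `K` dimension, with the first segment starting at 0 (`k_start` chains from the previous `k_end`), and the output stacks one `(M, N)` product per segment. Since `SegmentedMM::eval_gpu` still throws NYI, the example pins the op to the CPU stream. The concrete shapes and values are illustrative only.

```python
import mlx.core as mx

# a: (M, K) = (4, 10), b: (K, N) = (10, 6)
a = mx.random.normal((4, 10))
b = mx.random.normal((10, 6))

# Per the CPU kernel above, each entry is the exclusive end offset of a
# segment along K, so this describes the segments [0, 3), [3, 7), [7, 10).
segments = mx.array([3, 7, 10], dtype=mx.int32)

# Metal is still NYI in this diff, so run on the CPU stream.
out = mx.segmented_mm(a, b, segments, stream=mx.cpu)
print(out.shape)  # (3, 4, 6): one (M, N) product per segment

# Reference check: out[i] == a[:, start:end] @ b[start:end, :]
ref = mx.stack([a[:, 0:3] @ b[0:3], a[:, 3:7] @ b[3:7], a[:, 7:10] @ b[7:10]])
print(mx.allclose(out, ref, atol=1e-5))
```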