mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Start the segmented_mm op and CPU primitive
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -52,6 +53,53 @@ inline void mask_matrix(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void segmented_mm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
const int32_t* segments,
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides,
|
||||
size_t num_segments,
|
||||
const Shape& segments_shape,
|
||||
const Strides& segments_strides) {
|
||||
int ndim = a_shape.size();
|
||||
Shape a_copy = a_shape;
|
||||
Shape b_copy = b_shape;
|
||||
int32_t M = a_copy[ndim - 2];
|
||||
int32_t N = b_copy[ndim - 1];
|
||||
int32_t k_start = 0;
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
int32_t k_end = segments[elem_to_loc(i, segments_shape, segments_strides)];
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
b_copy[ndim - 2] = k_end - k_start;
|
||||
matmul<T>(
|
||||
a + k_start * a_strides[ndim - 1],
|
||||
b + k_start * b_strides[ndim - 2],
|
||||
out + i * M * N,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
1.0,
|
||||
0.0,
|
||||
1,
|
||||
a_copy,
|
||||
a_strides,
|
||||
b_copy,
|
||||
b_strides);
|
||||
k_start = k_end;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -437,4 +485,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto check_transpose = [&s, &encoder](const array& x) {
|
||||
auto stx = x.strides()[x.ndim() - 2];
|
||||
auto sty = x.strides()[x.ndim() - 1];
|
||||
if (stx == x.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, x);
|
||||
} else if (stx == 1 && sty == x.shape(-2)) {
|
||||
return std::make_tuple(true, sty, x);
|
||||
} else {
|
||||
array xc(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, xc, CopyType::General, s);
|
||||
encoder.add_temporary(xc);
|
||||
int64_t stx = x.shape(-1);
|
||||
return std::make_tuple(false, stx, xc);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
|
||||
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
|
||||
auto& segments = inputs[2];
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(segments);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
segments = array::unsafe_weak_copy(segments),
|
||||
out_ptr = out.data<void>(),
|
||||
a_transposed = a_transposed,
|
||||
b_transposed = b_transposed,
|
||||
lda = lda,
|
||||
ldb = ldb]() {
|
||||
switch (a.dtype()) {
|
||||
case float64:
|
||||
segmented_mm<double>(
|
||||
a.data<double>(),
|
||||
b.data<double>(),
|
||||
segments.data<int32_t>(),
|
||||
static_cast<double*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float32:
|
||||
segmented_mm<float>(
|
||||
a.data<float>(),
|
||||
b.data<float>(),
|
||||
segments.data<int32_t>(),
|
||||
static_cast<float*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float16:
|
||||
segmented_mm<float16_t>(
|
||||
a.data<float16_t>(),
|
||||
b.data<float16_t>(),
|
||||
segments.data<int32_t>(),
|
||||
static_cast<float16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case bfloat16:
|
||||
segmented_mm<bfloat16_t>(
|
||||
a.data<bfloat16_t>(),
|
||||
b.data<bfloat16_t>(),
|
||||
segments.data<int32_t>(),
|
||||
static_cast<bfloat16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size(),
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"Segmented mm supports only real float types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1864,4 +1864,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
|
||||
}
|
||||
|
||||
void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::invalid_argument("NYI");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
34
mlx/ops.cpp
34
mlx/ops.cpp
@@ -4649,6 +4649,40 @@ array gather_mm(
|
||||
return axes.empty() ? out : squeeze(out, axes, s);
|
||||
}
|
||||
|
||||
array segmented_mm(
|
||||
array a,
|
||||
array b,
|
||||
array segments,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (a.ndim() != 2 || b.ndim() != 2) {
|
||||
throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
|
||||
}
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b);
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[segmented_mm] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype()
|
||||
<< " were provided which results in " << out_type
|
||||
<< ", which is not a real floating point type.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
a = astype(a, out_type, s);
|
||||
b = astype(b, out_type, s);
|
||||
|
||||
Shape out_shape = segments.shape();
|
||||
out_shape.push_back(a.shape(0));
|
||||
out_shape.push_back(b.shape(1));
|
||||
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<SegmentedMM>(to_stream(s)),
|
||||
{std::move(a), std::move(b), std::move(segments)});
|
||||
}
|
||||
|
||||
array diagonal(
|
||||
const array& a,
|
||||
int offset /* = 0 */,
|
||||
|
||||
@@ -1406,6 +1406,12 @@ array gather_mm(
|
||||
bool sorted_indices = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Compute a matrix product but segment the inner dimension and write the
|
||||
* result separately for each segment.
|
||||
*/
|
||||
array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
|
||||
|
||||
/** Extract a diagonal or construct a diagonal array */
|
||||
array diagonal(
|
||||
const array& a,
|
||||
|
||||
@@ -526,6 +526,16 @@ class GatherMM : public UnaryPrimitive {
|
||||
bool right_sorted_;
|
||||
};
|
||||
|
||||
class SegmentedMM : public UnaryPrimitive {
|
||||
public:
|
||||
explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_PRINT(SegmentedMM)
|
||||
};
|
||||
|
||||
class BroadcastAxes : public UnaryPrimitive {
|
||||
public:
|
||||
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
|
||||
|
||||
@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"segmented_mm",
|
||||
&mx::segmented_mm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"segments"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform a matrix multiplication but segment the inner dimension and
|
||||
save the result for each segment separately.
|
||||
|
||||
Args:
|
||||
a (array): Input array of shape ``MxK``.
|
||||
b (array): Input array of shape ``KxN``.
|
||||
segments (array): The offsets into the inner dimension for each segment.
|
||||
|
||||
Returns:
|
||||
array: The result per segment of shape ``MxN``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tensordot",
|
||||
[](const mx::array& a,
|
||||
|
||||
Reference in New Issue
Block a user