mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Start the segmented_mm op and CPU primitive
This commit is contained in:
34
mlx/ops.cpp
34
mlx/ops.cpp
@@ -4649,6 +4649,40 @@ array gather_mm(
|
||||
return axes.empty() ? out : squeeze(out, axes, s);
|
||||
}
|
||||
|
||||
/**
 * Build the lazy graph node for a segmented matrix multiplication.
 *
 * `a` and `b` must both be 2-D and their contraction dimensions must agree
 * (`a.shape(1) == b.shape(0)`). Both are cast to their promoted result type,
 * which must be a floating point type. The returned array has shape
 * `segments.shape()` followed by `(a.shape(0), b.shape(1))` and is backed by
 * a `SegmentedMM` primitive on the stream resolved from `s`.
 *
 * The inputs are taken by value so they can be re-assigned after the cast
 * and then moved into the primitive's input list.
 *
 * @param a        Left operand, 2-D.
 * @param b        Right operand, 2-D.
 * @param segments Segment description forwarded unchanged to the primitive.
 * @param s        Stream or device on which to schedule the computation.
 * @return The output array produced by the `SegmentedMM` primitive.
 * @throws std::invalid_argument if either operand is not 2-D, if the
 *         contraction dimensions differ, or if the promoted type is not
 *         floating point.
 */
array segmented_mm(
    array a,
    array b,
    array segments,
    StreamOrDevice s /* = {} */) {
  if (a.ndim() != 2 || b.ndim() != 2) {
    throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
  }

  // The contraction dimensions must agree; otherwise the output shape
  // computed below would describe a product that cannot be evaluated.
  if (a.shape(1) != b.shape(0)) {
    std::ostringstream msg;
    msg << "[segmented_mm] Last dimension of first input (" << a.shape(1)
        << ") must match first dimension of second input (" << b.shape(0)
        << ").";
    throw std::invalid_argument(msg.str());
  }

  // Type promotion
  auto out_type = result_type(a, b);
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[segmented_mm] Only real floating point types are supported but "
        << a.dtype() << " and " << b.dtype()
        << " were provided which results in " << out_type
        << ", which is not a real floating point type.";
    throw std::invalid_argument(msg.str());
  }

  a = astype(a, out_type, s);
  b = astype(b, out_type, s);

  // Output shape: every axis of `segments`, then the (M, N) of one product.
  // NOTE(review): if segments are meant to be stored as (..., 2) start/end
  // pairs, the trailing axis would need to be dropped here -- confirm.
  Shape out_shape = segments.shape();
  out_shape.push_back(a.shape(0));
  out_shape.push_back(b.shape(1));

  return array(
      std::move(out_shape),
      out_type,
      std::make_shared<SegmentedMM>(to_stream(s)),
      {std::move(a), std::move(b), std::move(segments)});
}
|
||||
|
||||
array diagonal(
|
||||
const array& a,
|
||||
int offset /* = 0 */,
|
||||
|
||||
Reference in New Issue
Block a user