Start the segmented_mm op and CPU primitive

This commit is contained in:
Angelos Katharopoulos
2025-07-02 01:07:42 -07:00
parent e76e9b87f0
commit 6020ad6363
6 changed files with 241 additions and 0 deletions

View File

@@ -4649,6 +4649,40 @@ array gather_mm(
return axes.empty() ? out : squeeze(out, axes, s);
}
array segmented_mm(
array a,
array b,
array segments,
StreamOrDevice s /* = {} */) {
if (a.ndim() != 2 || b.ndim() != 2) {
throw std::invalid_argument("[segmented_mm] Batched matmul not supported");
}
// Type promotion
auto out_type = result_type(a, b);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[segmented_mm] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type
<< ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
Shape out_shape = segments.shape();
out_shape.push_back(a.shape(0));
out_shape.push_back(b.shape(1));
return array(
std::move(out_shape),
out_type,
std::make_shared<SegmentedMM>(to_stream(s)),
{std::move(a), std::move(b), std::move(segments)});
}
array diagonal(
const array& a,
int offset /* = 0 */,