Start the segmented_mm op and CPU primitive

Angelos Katharopoulos
2025-07-02 01:07:42 -07:00
parent e76e9b87f0
commit 6020ad6363
6 changed files with 241 additions and 0 deletions
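
The docstring added to init_ops in this commit says that segmented_mm splits the inner dimension K into segments and stores one result per segment. A minimal pure-Python sketch of that behaviour follows, for illustration only. It assumes that each entry of segments is a (start, end) pair of offsets into K and that the output holds one MxN block per segment; the docstring does not spell out either detail. segmented_mm_reference is a hypothetical helper name, not part of MLX.

```python
def segmented_mm_reference(a, b, segments):
    """Reference sketch of a segmented matmul on nested lists.

    Assumption (not confirmed by the commit): each segment is a
    (start, end) pair into the inner dimension K, and the result
    contains one M x N block per segment.
    """
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError("inner dimensions of a and b do not match")
    n = len(b[0])

    out = []
    for start, end in segments:
        # Multiply only the slice a[:, start:end] with b[start:end, :].
        block = [
            [sum(a[i][p] * b[p][j] for p in range(start, end)) for j in range(n)]
            for i in range(m)
        ]
        out.append(block)
    return out


a = [[1, 2, 3, 4],
     [5, 6, 7, 8]]          # shape 2 x 4 (M x K)
b = [[1, 0],
     [0, 1],
     [1, 1],
     [2, 0]]                # shape 4 x 2 (K x N)
segments = [(0, 2), (2, 4)]  # two segments covering K = 4

per_segment = segmented_mm_reference(a, b, segments)

# Summing the per-segment blocks elementwise gives the full product.
full = [
    [sum(block[i][j] for block in per_segment) for j in range(2)]
    for i in range(2)
]
```

Under these assumptions, segments that tile the whole inner dimension produce blocks whose elementwise sum equals the ordinary product of a and b.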


@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
             array: The result of the multiplication of ``x`` with ``w``
               after gathering using ``lhs_indices`` and ``rhs_indices``.
       )pbdoc");
+  m.def(
+      "segmented_mm",
+      &mx::segmented_mm,
+      nb::arg(),
+      nb::arg(),
+      "segments"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Perform a matrix multiplication but segment the inner dimension and
+        save the result for each segment separately.
+
+        Args:
+          a (array): Input array of shape ``MxK``.
+          b (array): Input array of shape ``KxN``.
+          segments (array): The offsets into the inner dimension for each segment.
+
+        Returns:
+          array: The result per segment of shape ``MxN``.
+      )pbdoc");
   m.def(
       "tensordot",
       [](const mx::array& a,