MoE backward improvements (#2335)

This commit is contained in:
Angelos Katharopoulos
2025-07-07 17:59:53 -07:00
committed by GitHub
parent a4fcc893cd
commit 4a9b29a875
22 changed files with 1130 additions and 60 deletions

View File

@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
array: The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc");
m.def(
"segmented_mm",
&mx::segmented_mm,
nb::arg(),
nb::arg(),
"segments"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform a matrix multiplication but segment the inner dimension and
save the result for each segment separately.
Args:
a (array): Input array of shape ``MxK``.
b (array): Input array of shape ``KxN``.
segments (array): The offsets into the inner dimension for each segment.
Returns:
array: The result per segment of shape ``MxN``.
)pbdoc");
m.def(
"tensordot",
[](const mx::array& a,