mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
MoE backward improvements (#2335)
This commit is contained in:

committed by
GitHub

parent
a4fcc893cd
commit
4a9b29a875
@@ -4321,6 +4321,28 @@ void init_ops(nb::module_& m) {
|
||||
array: The result of the multiplication of ``x`` with ``w``
|
||||
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"segmented_mm",
|
||||
&mx::segmented_mm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"segments"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform a matrix multiplication but segment the inner dimension and
|
||||
save the result for each segment separately.
|
||||
|
||||
Args:
|
||||
a (array): Input array of shape ``MxK``.
|
||||
b (array): Input array of shape ``KxN``.
|
||||
segments (array): The offsets into the inner dimension for each segment.
|
||||
|
||||
Returns:
|
||||
array: The result per segment of shape ``MxN``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tensordot",
|
||||
[](const mx::array& a,
|
||||
|
Reference in New Issue
Block a user