mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
MoE backward improvements (#2335)
This commit is contained in:
committed by
GitHub
parent
a4fcc893cd
commit
4a9b29a875
@@ -1406,6 +1406,12 @@ array gather_mm(
|
||||
bool sorted_indices = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Compute a matrix product but segment the inner dimension and write the
|
||||
* result separately for each segment.
|
||||
*/
|
||||
array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
|
||||
|
||||
/** Extract a diagonal or construct a diagonal array */
|
||||
array diagonal(
|
||||
const array& a,
|
||||
|
||||
Reference in New Issue
Block a user