MoE backward improvements (#2335)

This commit is contained in:
Angelos Katharopoulos
2025-07-07 17:59:53 -07:00
committed by GitHub
parent a4fcc893cd
commit 4a9b29a875
22 changed files with 1130 additions and 60 deletions

View File

@@ -1406,6 +1406,12 @@ array gather_mm(
bool sorted_indices = false,
StreamOrDevice s = {});
/**
* Compute a matrix product but segment the inner dimension and write the
* result separately for each segment.
*/
array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,