Gather mm new kernel and small refactoring (#2040)

This commit is contained in:
Angelos Katharopoulos
2025-04-14 16:37:36 -07:00
committed by GitHub
parent e9e268336b
commit 99eefd2ec0
23 changed files with 1260 additions and 378 deletions

View File

@@ -4499,6 +4499,7 @@ array gather_mm(
array b,
std::optional<array> lhs_indices_ /* = std::nullopt */,
std::optional<array> rhs_indices_ /* = std::nullopt */,
bool sorted_indices /* = false */,
StreamOrDevice s /* = {} */) {
// If no indices, fall back to full matmul
if (!lhs_indices_ && !rhs_indices_) {
@@ -4574,12 +4575,18 @@ array gather_mm(
out_shape.push_back(M);
out_shape.push_back(N);
// Caculate array
// Make the output array
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices});
std::make_shared<GatherMM>(
to_stream(s),
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{std::move(a),
std::move(b),
std::move(lhs_indices),
std::move(rhs_indices)});
// Remove the possibly inserted singleton dimensions
std::vector<int> axes;