Gather mm new kernel and small refactoring (#2040)

This commit is contained in:
Angelos Katharopoulos
2025-04-14 16:37:36 -07:00
committed by GitHub
parent e9e268336b
commit 99eefd2ec0
23 changed files with 1260 additions and 378 deletions

View File

@@ -4895,6 +4895,8 @@ std::vector<array> GatherMM::vjp(
int N = cotan.shape(-1);
int K = primals[0].shape(-1);
bool sorted = left_sorted_ || right_sorted_;
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
@@ -4905,7 +4907,8 @@ std::vector<array> GatherMM::vjp(
base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K)
auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
auto g =
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
@@ -4920,7 +4923,8 @@ std::vector<array> GatherMM::vjp(
base = reshape(base, {-1, K, N}, stream());
// g : (out_batch_shape) + (K, N)
auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
auto g =
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
@@ -4933,6 +4937,12 @@ std::vector<array> GatherMM::vjp(
return vjps;
}
bool GatherMM::is_equivalent(const Primitive& other) const {
const GatherMM& g_other = static_cast<const GatherMM&>(other);
return left_sorted_ == g_other.left_sorted_ &&
right_sorted_ == g_other.right_sorted_;
}
bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
return (block_size_ == a_other.block_size_);