mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Gather mm new kernel and small refactoring (#2040)
This commit is contained in:
committed by
GitHub
parent
e9e268336b
commit
99eefd2ec0
@@ -4895,6 +4895,8 @@ std::vector<array> GatherMM::vjp(
|
||||
int N = cotan.shape(-1);
|
||||
int K = primals[0].shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
@@ -4905,7 +4907,8 @@ std::vector<array> GatherMM::vjp(
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g = gather_mm(cotan, bt, std::nullopt, rhs_indices, stream());
|
||||
auto g =
|
||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
|
||||
@@ -4920,7 +4923,8 @@ std::vector<array> GatherMM::vjp(
|
||||
base = reshape(base, {-1, K, N}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (K, N)
|
||||
auto g = gather_mm(at, cotan, lhs_indices, std::nullopt, stream());
|
||||
auto g =
|
||||
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||
|
||||
@@ -4933,6 +4937,12 @@ std::vector<array> GatherMM::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
bool GatherMM::is_equivalent(const Primitive& other) const {
|
||||
const GatherMM& g_other = static_cast<const GatherMM&>(other);
|
||||
return left_sorted_ == g_other.left_sorted_ &&
|
||||
right_sorted_ == g_other.right_sorted_;
|
||||
}
|
||||
|
||||
bool BlockMaskedMM::is_equivalent(const Primitive& other) const {
|
||||
const BlockMaskedMM& a_other = static_cast<const BlockMaskedMM&>(other);
|
||||
return (block_size_ == a_other.block_size_);
|
||||
|
||||
Reference in New Issue
Block a user