mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improve the gradient of gather_qmm as well
This commit is contained in:
@@ -3253,34 +3253,41 @@ std::vector<array> GatherQMM::vjp(
|
||||
auto& lhs_indices = primals[4];
|
||||
auto& rhs_indices = primals[5];
|
||||
|
||||
int M = cotan.shape(-2);
|
||||
int N = cotan.shape(-1);
|
||||
int K = x.shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
bool no_broadcast = rhs_indices.size() * M * K == x.size();
|
||||
|
||||
for (auto arg : argnums) {
|
||||
// gradient wrt to x
|
||||
if (arg == 0) {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(
|
||||
gather_qmm(
|
||||
cotan,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
sorted,
|
||||
stream()),
|
||||
-3,
|
||||
stream()),
|
||||
0,
|
||||
stream()),
|
||||
x.shape(),
|
||||
stream()));
|
||||
auto g = gather_qmm(
|
||||
cotan,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
sorted,
|
||||
stream());
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(g, -3, stream()),
|
||||
0,
|
||||
stream()),
|
||||
x.shape(),
|
||||
stream()));
|
||||
}
|
||||
}
|
||||
|
||||
// gradient wrt to the indices is undefined
|
||||
@@ -5064,6 +5071,8 @@ std::vector<array> GatherMM::vjp(
|
||||
std::vector<array> vjps;
|
||||
auto& cotan = cotangents[0];
|
||||
|
||||
auto& a = primals[0];
|
||||
auto& b = primals[1];
|
||||
auto& lhs_indices = primals[2];
|
||||
auto& rhs_indices = primals[3];
|
||||
|
||||
@@ -5076,23 +5085,26 @@ std::vector<array> GatherMM::vjp(
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
auto bt = swapaxes(primals[1], -1, -2, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g =
|
||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||
auto g = gather_mm(
|
||||
cotan,
|
||||
swapaxes(b, -1, -2, stream()),
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
stream());
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto base = zeros_like(primals[0], stream());
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(a, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(g, -3, stream()),
|
||||
0,
|
||||
stream()),
|
||||
a.shape(),
|
||||
stream()));
|
||||
}
|
||||
|
||||
} else if (arg == 1) {
|
||||
if (sorted) {
|
||||
// Make the segments based on the rhs_indices
|
||||
|
||||
Reference in New Issue
Block a user