Improve the gradient of gather_qmm as well

This commit is contained in:
Angelos Katharopoulos
2025-07-04 20:19:44 -07:00
parent b28577289e
commit bda1534a44

View File

@@ -3253,34 +3253,41 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5];
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = x.shape(-1);
bool sorted = left_sorted_ || right_sorted_;
bool no_broadcast = rhs_indices.size() * M * K == x.size();
for (auto arg : argnums) {
// gradient wrt x
if (arg == 0) {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(
gather_qmm(
cotan,
w,
scales,
biases,
std::nullopt,
rhs_indices,
!transpose_,
group_size_,
bits_,
sorted,
stream()),
-3,
stream()),
0,
stream()),
x.shape(),
stream()));
auto g = gather_qmm(
cotan,
w,
scales,
biases,
std::nullopt,
rhs_indices,
!transpose_,
group_size_,
bits_,
sorted,
stream());
if (sorted && no_broadcast) {
vjps.push_back(g);
} else {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(g, -3, stream()),
0,
stream()),
x.shape(),
stream()));
}
}
// gradient wrt the indices is undefined
@@ -5064,6 +5071,8 @@ std::vector<array> GatherMM::vjp(
std::vector<array> vjps;
auto& cotan = cotangents[0];
auto& a = primals[0];
auto& b = primals[1];
auto& lhs_indices = primals[2];
auto& rhs_indices = primals[3];
@@ -5076,23 +5085,26 @@ std::vector<array> GatherMM::vjp(
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
auto bt = swapaxes(primals[1], -1, -2, stream());
// g : (out_batch_shape) + (M, K)
auto g =
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
auto g = gather_mm(
cotan,
swapaxes(b, -1, -2, stream()),
std::nullopt,
rhs_indices,
sorted,
stream());
if (sorted && no_broadcast) {
vjps.push_back(g);
} else {
g = expand_dims(g, -3, stream());
auto base = zeros_like(primals[0], stream());
auto base_shape = base.shape();
base = reshape(base, {-1, M, K}, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(a, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(g, -3, stream()),
0,
stream()),
a.shape(),
stream()));
}
} else if (arg == 1) {
if (sorted) {
// Make the segments based on the rhs_indices