Improve the gradient of gather_qmm as well

This commit is contained in:
Angelos Katharopoulos
2025-07-04 20:19:44 -07:00
parent b28577289e
commit bda1534a44

View File

@@ -3253,17 +3253,17 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4]; auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5]; auto& rhs_indices = primals[5];
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = x.shape(-1);
bool sorted = left_sorted_ || right_sorted_; bool sorted = left_sorted_ || right_sorted_;
bool no_broadcast = rhs_indices.size() * M * K == x.size();
for (auto arg : argnums) { for (auto arg : argnums) {
// gradient wrt to x // gradient wrt to x
if (arg == 0) { if (arg == 0) {
vjps.push_back(reshape( auto g = gather_qmm(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(
gather_qmm(
cotan, cotan,
w, w,
scales, scales,
@@ -3274,14 +3274,21 @@ std::vector<array> GatherQMM::vjp(
group_size_, group_size_,
bits_, bits_,
sorted, sorted,
stream()), stream());
-3, if (sorted && no_broadcast) {
stream()), vjps.push_back(g);
} else {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(g, -3, stream()),
0, 0,
stream()), stream()),
x.shape(), x.shape(),
stream())); stream()));
} }
}
// gradient wrt to the indices is undefined // gradient wrt to the indices is undefined
else if (arg > 3) { else if (arg > 3) {
@@ -5064,6 +5071,8 @@ std::vector<array> GatherMM::vjp(
std::vector<array> vjps; std::vector<array> vjps;
auto& cotan = cotangents[0]; auto& cotan = cotangents[0];
auto& a = primals[0];
auto& b = primals[1];
auto& lhs_indices = primals[2]; auto& lhs_indices = primals[2];
auto& rhs_indices = primals[3]; auto& rhs_indices = primals[3];
@@ -5076,23 +5085,26 @@ std::vector<array> GatherMM::vjp(
for (auto arg : argnums) { for (auto arg : argnums) {
if (arg == 0) { if (arg == 0) {
// M X N * (K X N).T -> M X K auto g = gather_mm(
auto bt = swapaxes(primals[1], -1, -2, stream()); cotan,
swapaxes(b, -1, -2, stream()),
// g : (out_batch_shape) + (M, K) std::nullopt,
auto g = rhs_indices,
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); sorted,
stream());
if (sorted && no_broadcast) { if (sorted && no_broadcast) {
vjps.push_back(g); vjps.push_back(g);
} else { } else {
g = expand_dims(g, -3, stream()); vjps.push_back(reshape(
auto base = zeros_like(primals[0], stream()); scatter_add(
auto base_shape = base.shape(); flatten(zeros_like(a, stream()), 0, -3, stream()),
base = reshape(base, {-1, M, K}, stream()); lhs_indices,
auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); expand_dims(g, -3, stream()),
vjps.push_back(reshape(gacc, base_shape, stream())); 0,
stream()),
a.shape(),
stream()));
} }
} else if (arg == 1) { } else if (arg == 1) {
if (sorted) { if (sorted) {
// Make the segments based on the rhs_indices // Make the segments based on the rhs_indices