Improve the gradient of gather_qmm as well

This commit is contained in:
Angelos Katharopoulos
2025-07-04 20:19:44 -07:00
parent b28577289e
commit bda1534a44

View File

@@ -3253,34 +3253,41 @@ std::vector<array> GatherQMM::vjp(
auto& lhs_indices = primals[4]; auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5]; auto& rhs_indices = primals[5];
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = x.shape(-1);
bool sorted = left_sorted_ || right_sorted_; bool sorted = left_sorted_ || right_sorted_;
bool no_broadcast = rhs_indices.size() * M * K == x.size();
for (auto arg : argnums) { for (auto arg : argnums) {
// gradient wrt to x // gradient wrt to x
if (arg == 0) { if (arg == 0) {
vjps.push_back(reshape( auto g = gather_qmm(
scatter_add( cotan,
flatten(zeros_like(x, stream()), 0, -3, stream()), w,
lhs_indices, scales,
expand_dims( biases,
gather_qmm( std::nullopt,
cotan, rhs_indices,
w, !transpose_,
scales, group_size_,
biases, bits_,
std::nullopt, sorted,
rhs_indices, stream());
!transpose_, if (sorted && no_broadcast) {
group_size_, vjps.push_back(g);
bits_, } else {
sorted, vjps.push_back(reshape(
stream()), scatter_add(
-3, flatten(zeros_like(x, stream()), 0, -3, stream()),
stream()), lhs_indices,
0, expand_dims(g, -3, stream()),
stream()), 0,
x.shape(), stream()),
stream())); x.shape(),
stream()));
}
} }
// gradient wrt to the indices is undefined // gradient wrt to the indices is undefined
@@ -5064,6 +5071,8 @@ std::vector<array> GatherMM::vjp(
std::vector<array> vjps; std::vector<array> vjps;
auto& cotan = cotangents[0]; auto& cotan = cotangents[0];
auto& a = primals[0];
auto& b = primals[1];
auto& lhs_indices = primals[2]; auto& lhs_indices = primals[2];
auto& rhs_indices = primals[3]; auto& rhs_indices = primals[3];
@@ -5076,23 +5085,26 @@ std::vector<array> GatherMM::vjp(
for (auto arg : argnums) { for (auto arg : argnums) {
if (arg == 0) { if (arg == 0) {
// M X N * (K X N).T -> M X K auto g = gather_mm(
auto bt = swapaxes(primals[1], -1, -2, stream()); cotan,
swapaxes(b, -1, -2, stream()),
// g : (out_batch_shape) + (M, K) std::nullopt,
auto g = rhs_indices,
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream()); sorted,
stream());
if (sorted && no_broadcast) { if (sorted && no_broadcast) {
vjps.push_back(g); vjps.push_back(g);
} else { } else {
g = expand_dims(g, -3, stream()); vjps.push_back(reshape(
auto base = zeros_like(primals[0], stream()); scatter_add(
auto base_shape = base.shape(); flatten(zeros_like(a, stream()), 0, -3, stream()),
base = reshape(base, {-1, M, K}, stream()); lhs_indices,
auto gacc = scatter_add(base, lhs_indices, g, 0, stream()); expand_dims(g, -3, stream()),
vjps.push_back(reshape(gacc, base_shape, stream())); 0,
stream()),
a.shape(),
stream()));
} }
} else if (arg == 1) { } else if (arg == 1) {
if (sorted) { if (sorted) {
// Make the segments based on the rhs_indices // Make the segments based on the rhs_indices