Add gradient for the scales and biases in gather qmm

commit 3d4174cd37
parent bda1534a44
Author: Angelos Katharopoulos
Date:   2025-07-05 00:58:17 -07:00


@@ -109,6 +109,70 @@ std::tuple<array, array, array, int> vmap_ternary_op(
   return {a, b, c, to_ax};
 }
+
+// Calculate the gradient wrt the weights of the following calculation
+//
+//     y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)
+//
+// Note the transpose above. This function returns the gradient for w.T, so if
+// w was used instead then one needs to transpose the returned gradient.
+//
+// We define it as a separate function so that it can be reused by both
+// gather_mm and gather_qmm.
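+//
+// As a shape example (the values are illustrative): with x of shape
+// (..., M, K), dy of shape (..., M, N) and batch_shape (E,), the returned
+// gradient has shape (E, N, K), i.e. one (N, K) block per value that the
+// indices can take.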
+array gather_mm_grad(
+    const array& x,
+    const array& dy,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    bool sorted,
+    Shape batch_shape,
+    const Stream& s) {
+  int M = x.shape(-2);
+  int K = x.shape(-1);
+  int N = dy.shape(-1);
+  int num_segments = std::accumulate(
+      batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
+  batch_shape.push_back(N);
+  batch_shape.push_back(K);
+
+  // If the indices are sorted then we can do the whole gradient computation
+  // via a segmented matmul. We just need to calculate the segments using the
+  // indices.
+  if (sorted) {
+    auto segments = zeros({num_segments}, uint32, s);
+    segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);
+    segments = cumsum(segments, 0, false, true, s);
+    segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);
+    segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);
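+    // For example (illustrative values): with num_segments = 3, M = 4 and
+    // rhs_indices = [0, 0, 2], scatter_add_axis produces the row counts
+    // [8, 0, 4], the inclusive cumsum gives [8, 8, 12], prepending the zero
+    // gives [0, 8, 8, 12], and as_strided turns that into the (start, end)
+    // row ranges [[0, 8], [8, 8], [8, 12]] consumed by segmented_mm below.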
+
+    return reshape(
+        segmented_mm(
+            swapaxes(flatten(dy, 0, -2, s), 0, 1, s),
+            flatten(x, 0, -2, s),
+            segments,
+            s),
+        std::move(batch_shape),
+        s);
+  }
+
+  // Otherwise we need to gather-matmul dy and then scatter-add the result to
+  // the correct locations.
+  else {
+    // TODO: If the lhs indices weren't provided then this is always a sorted
+    // matmul, so we should add that check.
+    auto dw = gather_mm(
+        swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);
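+    // dw holds one (N, K) block per cotangent batch element; the scatter-add
+    // below accumulates the blocks that were routed to the same weight.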
+    return reshape(
+        scatter_add(
+            zeros({num_segments, N, K}, dw.dtype(), s),
+            rhs_indices,
+            expand_dims(dw, -3, s),
+            0,
+            s),
+        std::move(batch_shape),
+        s);
+  }
+}
 } // namespace

 std::vector<array> Primitive::jvp(
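For orientation, the convention the helper differentiates can be exercised from the public API. A minimal sketch (the shapes, the two-expert routing and the use of random::normal are illustrative assumptions, not part of this commit):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Four tokens of dimension 8 routed to two experts with (16, 8)
      // weights; the routing [0, 0, 1, 1] is already sorted by expert.
      auto x = random::normal({4, 1, 8});
      auto w = random::normal({2, 16, 8});
      auto indices = array({0, 0, 1, 1}, {4}, uint32);

      // y = gather_mm(x, w.T, ...) is the forward pass whose weight gradient
      // gather_mm_grad computes. Because w enters transposed, the helper
      // returns the gradient for w.T and GatherMM::vjp swaps the last two
      // axes afterwards.
      auto y = gather_mm(x, swapaxes(w, -1, -2), std::nullopt, indices, true);
      eval(y); // y: (4, 1, 16)
      return 0;
    }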
@@ -3181,7 +3245,6 @@ std::vector<array> QuantizedMatmul::vjp(
       vjps.push_back(sum(*dsb, -1, false, stream()));
     } else {
       // scales
-      auto s = stream();
       auto wq = dequantize(
           primals[1],
           ones_like(primals[2], stream()),
@@ -3259,6 +3322,7 @@ std::vector<array> GatherQMM::vjp(
   bool sorted = left_sorted_ || right_sorted_;
   bool no_broadcast = rhs_indices.size() * M * K == x.size();
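+  // Gradient wrt the dequantized weights, computed lazily once and shared by
+  // the scales (arg == 2) and biases (arg == 3) cases below.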
+  std::optional<array> dsb = std::nullopt;

   for (auto arg : argnums) {
     // gradient wrt x
@@ -3297,9 +3361,45 @@ std::vector<array> GatherQMM::vjp(
     }

     // gradient wrt w_q, scales or biases
-    else {
-      throw std::runtime_error(
-          "GatherQMM::vjp no gradient wrt the quantized matrix yet.");
+    else if (arg == 1) {
+      throw std::runtime_error(
+          "GatherQMM::vjp no gradient wrt the quantized weights.");
+    } else {
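+      // dequantize(w, scales, biases) computes scales * w_hat + biases per
+      // group, so the biases gradient is the per-group sum of the gradient
+      // wrt the dequantized weights, and the scales gradient is that same
+      // sum weighted by w_hat.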
+      if (!dsb) {
+        auto shape = w.shape();
+        shape.pop_back();
+        shape.pop_back();
+        dsb = unflatten(
+            gather_mm_grad(
+                x,
+                cotan,
+                lhs_indices,
+                rhs_indices,
+                sorted,
+                std::move(shape),
+                stream()),
+            -1,
+            {-1, group_size_},
+            stream());
+      }
+      if (arg == 3) {
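+        // biases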
+        vjps.push_back(sum(*dsb, -1, false, stream()));
+      } else {
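+        // scales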
+        vjps.push_back(
+            sum(multiply(
+                    *dsb,
+                    unflatten(
+                        dequantize(
+                            w,
+                            ones_like(scales, stream()),
+                            zeros_like(biases, stream()),
+                            group_size_,
+                            bits_,
+                            stream()),
+                        -1,
+                        {-1, group_size_},
+                        stream()),
+                    stream()),
+                -1,
+                false,
+                stream()));
+      }
     }
   }

   return vjps;
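To see why these two reductions are the right ones, recall that dequantization is affine per group. A small standalone sketch of that relation (using the public quantize/dequantize API; the sizes are illustrative):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Quantize a (16, 64) weight with group size 64 and 4 bits, then
      // dequantize with unit scales and zero biases to recover w_hat, the
      // per-element quantization levels.
      auto w = random::normal({16, 64});
      auto [w_q, scales, biases] = quantize(w, 64, 4);
      auto w_hat =
          dequantize(w_q, ones_like(scales), zeros_like(biases), 64, 4);
      eval(w_hat);

      // Since dequantize(w_q, scales, biases) is scales * w_hat + biases per
      // group, a gradient dw wrt the dequantized weights folds into
      //   d_biases = per-group sum of dw
      //   d_scales = per-group sum of dw * w_hat
      // which is exactly what the two branches above compute from dsb.
      return 0;
    }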
@@ -5106,46 +5206,21 @@ std::vector<array> GatherMM::vjp(
           stream()));
       }
     } else if (arg == 1) {
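+      // gather_mm_grad (added above) returns the gradient for b.T, so swap
+      // the last two axes before returning it.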
-      if (sorted) {
-        // Make the segments based on the rhs_indices
-        int num_segments = primals[1].size() / K / N;
-        auto segments = zeros({num_segments}, uint32, stream());
-        segments = scatter_add_axis(
-            segments, rhs_indices, array(M, uint32), 0, stream());
-        segments = cumsum(segments, 0, false, true, stream());
-        segments =
-            concatenate({array({0}, {1}, uint32), segments}, 0, stream());
-        segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());
-
-        // Reshape and transpose the inputs such that they form a big
-        // segmented matmul.
-        auto a = reshape(primals[0], {-1, K}, stream());
-        auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());
-
-        // Calculate the gradient.
-        // Since gather_mm is often used as x @ w.T we calculate the gradient
-        // as c @ a and transpose it before returning, which should save a
-        // copy in that case.
-        auto g = segmented_mm(c, a, segments, stream());
-        g = swapaxes(g, 1, 2, stream());
-        vjps.push_back(reshape(g, primals[1].shape(), stream()));
-      } else {
-        // (M x K).T @ (M x N) -> K x N
-        auto base = zeros_like(primals[1], stream());
-        auto at = swapaxes(primals[0], -1, -2, stream());
-        auto base_shape = base.shape();
-        base = reshape(base, {-1, K, N}, stream());
-        // g : (out_batch_shape) + (K, N)
-        auto g =
-            gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
-        g = expand_dims(g, -3, stream());
-        auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
-        vjps.push_back(reshape(gacc, base_shape, stream()));
-      }
+      auto shape = b.shape();
+      shape.pop_back();
+      shape.pop_back();
+      vjps.push_back(swapaxes(
+          gather_mm_grad(
+              a,
+              cotan,
+              lhs_indices,
+              rhs_indices,
+              sorted,
+              std::move(shape),
+              stream()),
+          -1,
+          -2,
+          stream()));
     } else {
       throw std::invalid_argument(
           "[GatherMM] Cannot calculate VJP with respect to indices.");