mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improve the gradient of gather_qmm as well
This commit is contained in:
@@ -3253,34 +3253,41 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
auto& lhs_indices = primals[4];
|
auto& lhs_indices = primals[4];
|
||||||
auto& rhs_indices = primals[5];
|
auto& rhs_indices = primals[5];
|
||||||
|
|
||||||
|
int M = cotan.shape(-2);
|
||||||
|
int N = cotan.shape(-1);
|
||||||
|
int K = x.shape(-1);
|
||||||
|
|
||||||
bool sorted = left_sorted_ || right_sorted_;
|
bool sorted = left_sorted_ || right_sorted_;
|
||||||
|
bool no_broadcast = rhs_indices.size() * M * K == x.size();
|
||||||
|
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
// gradient wrt to x
|
// gradient wrt to x
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
vjps.push_back(reshape(
|
auto g = gather_qmm(
|
||||||
scatter_add(
|
cotan,
|
||||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
w,
|
||||||
lhs_indices,
|
scales,
|
||||||
expand_dims(
|
biases,
|
||||||
gather_qmm(
|
std::nullopt,
|
||||||
cotan,
|
rhs_indices,
|
||||||
w,
|
!transpose_,
|
||||||
scales,
|
group_size_,
|
||||||
biases,
|
bits_,
|
||||||
std::nullopt,
|
sorted,
|
||||||
rhs_indices,
|
stream());
|
||||||
!transpose_,
|
if (sorted && no_broadcast) {
|
||||||
group_size_,
|
vjps.push_back(g);
|
||||||
bits_,
|
} else {
|
||||||
sorted,
|
vjps.push_back(reshape(
|
||||||
stream()),
|
scatter_add(
|
||||||
-3,
|
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||||
stream()),
|
lhs_indices,
|
||||||
0,
|
expand_dims(g, -3, stream()),
|
||||||
stream()),
|
0,
|
||||||
x.shape(),
|
stream()),
|
||||||
stream()));
|
x.shape(),
|
||||||
|
stream()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// gradient wrt to the indices is undefined
|
// gradient wrt to the indices is undefined
|
||||||
@@ -5064,6 +5071,8 @@ std::vector<array> GatherMM::vjp(
|
|||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
auto& cotan = cotangents[0];
|
auto& cotan = cotangents[0];
|
||||||
|
|
||||||
|
auto& a = primals[0];
|
||||||
|
auto& b = primals[1];
|
||||||
auto& lhs_indices = primals[2];
|
auto& lhs_indices = primals[2];
|
||||||
auto& rhs_indices = primals[3];
|
auto& rhs_indices = primals[3];
|
||||||
|
|
||||||
@@ -5076,23 +5085,26 @@ std::vector<array> GatherMM::vjp(
|
|||||||
|
|
||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
// M X N * (K X N).T -> M X K
|
auto g = gather_mm(
|
||||||
auto bt = swapaxes(primals[1], -1, -2, stream());
|
cotan,
|
||||||
|
swapaxes(b, -1, -2, stream()),
|
||||||
// g : (out_batch_shape) + (M, K)
|
std::nullopt,
|
||||||
auto g =
|
rhs_indices,
|
||||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
sorted,
|
||||||
|
stream());
|
||||||
if (sorted && no_broadcast) {
|
if (sorted && no_broadcast) {
|
||||||
vjps.push_back(g);
|
vjps.push_back(g);
|
||||||
} else {
|
} else {
|
||||||
g = expand_dims(g, -3, stream());
|
vjps.push_back(reshape(
|
||||||
auto base = zeros_like(primals[0], stream());
|
scatter_add(
|
||||||
auto base_shape = base.shape();
|
flatten(zeros_like(a, stream()), 0, -3, stream()),
|
||||||
base = reshape(base, {-1, M, K}, stream());
|
lhs_indices,
|
||||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
expand_dims(g, -3, stream()),
|
||||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
0,
|
||||||
|
stream()),
|
||||||
|
a.shape(),
|
||||||
|
stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (arg == 1) {
|
} else if (arg == 1) {
|
||||||
if (sorted) {
|
if (sorted) {
|
||||||
// Make the segments based on the rhs_indices
|
// Make the segments based on the rhs_indices
|
||||||
|
|||||||
Reference in New Issue
Block a user