mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
4 Commits
bd0622c4d9
...
3d4174cd37
| Author | SHA1 | Date | |
|---|---|---|---|
| | 3d4174cd37 | | |
| | bda1534a44 | | |
| | b28577289e | | |
| | 2d0f452aae | | |
@@ -81,6 +81,7 @@ inline void segmented_mm(
|
||||
uint32_t k_end =
|
||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||
if (k_end <= k_start) {
|
||||
std::fill_n(out + i * M * N, M * N, T(0));
|
||||
continue;
|
||||
}
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
|
||||
@@ -109,6 +109,70 @@ std::tuple<array, array, array, int> vmap_ternary_op(
|
||||
return {a, b, c, to_ax};
|
||||
}
|
||||
|
||||
// Calculate the gradient wrt to the weights of the following calculation
|
||||
//
|
||||
// y = gather_mm(x, w.T, lhs_indices, rhs_indices, sorted)
|
||||
//
|
||||
// Note the transpose above. This function returns the gradient for w.T so if w
|
||||
// was used instead then one needs to transpose the returned gradient.
|
||||
//
|
||||
// We define it as a separate function to reuse it for gather_mm and
|
||||
// gather_qmm.
|
||||
array gather_mm_grad(
|
||||
const array& x,
|
||||
const array& dy,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
bool sorted,
|
||||
Shape batch_shape,
|
||||
const Stream& s) {
|
||||
int M = x.shape(-2);
|
||||
int K = x.shape(-1);
|
||||
int N = dy.shape(-1);
|
||||
int num_segments = std::accumulate(
|
||||
batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int>());
|
||||
batch_shape.push_back(N);
|
||||
batch_shape.push_back(K);
|
||||
|
||||
// If the indices are sorted then it means that we can do the whole gradient
|
||||
// computation via a segmented matmul. We just need to calculate the segments
|
||||
// using the indices.
|
||||
if (sorted) {
|
||||
auto segments = zeros({num_segments}, uint32, s);
|
||||
segments = scatter_add_axis(segments, rhs_indices, array(M, uint32), 0, s);
|
||||
segments = cumsum(segments, 0, false, true, s);
|
||||
segments = concatenate({array({0}, {1}, uint32), segments}, 0, s);
|
||||
segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, s);
|
||||
|
||||
return reshape(
|
||||
segmented_mm(
|
||||
swapaxes(flatten(dy, 0, -2, s), 0, 1, s),
|
||||
flatten(x, 0, -2, s),
|
||||
segments,
|
||||
s),
|
||||
std::move(batch_shape),
|
||||
s);
|
||||
}
|
||||
|
||||
// Otherwise we need to gather matmul the dy and then scatter add it to the
|
||||
// correct locations.
|
||||
else {
|
||||
// TODO: If the lhs indices wasn't provided, this is always a sorted matmul
|
||||
// so we should add that check.
|
||||
auto dw = gather_mm(
|
||||
swapaxes(dy, -1, -2, s), x, std::nullopt, lhs_indices, false, s);
|
||||
return reshape(
|
||||
scatter_add(
|
||||
zeros({num_segments, N, K}, dw.dtype(), s),
|
||||
rhs_indices,
|
||||
expand_dims(dw, -3, s),
|
||||
0,
|
||||
s),
|
||||
std::move(batch_shape),
|
||||
s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<array> Primitive::jvp(
|
||||
@@ -3181,7 +3245,6 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
vjps.push_back(sum(*dsb, -1, false, stream()));
|
||||
} else {
|
||||
// scales
|
||||
auto s = stream();
|
||||
auto wq = dequantize(
|
||||
primals[1],
|
||||
ones_like(primals[2], stream()),
|
||||
@@ -3253,34 +3316,42 @@ std::vector<array> GatherQMM::vjp(
|
||||
auto& lhs_indices = primals[4];
|
||||
auto& rhs_indices = primals[5];
|
||||
|
||||
int M = cotan.shape(-2);
|
||||
int N = cotan.shape(-1);
|
||||
int K = x.shape(-1);
|
||||
|
||||
bool sorted = left_sorted_ || right_sorted_;
|
||||
bool no_broadcast = rhs_indices.size() * M * K == x.size();
|
||||
std::optional<array> dsb = std::nullopt;
|
||||
|
||||
for (auto arg : argnums) {
|
||||
// gradient wrt to x
|
||||
if (arg == 0) {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(
|
||||
gather_qmm(
|
||||
cotan,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
sorted,
|
||||
stream()),
|
||||
-3,
|
||||
stream()),
|
||||
0,
|
||||
stream()),
|
||||
x.shape(),
|
||||
stream()));
|
||||
auto g = gather_qmm(
|
||||
cotan,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
!transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
sorted,
|
||||
stream());
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(x, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(g, -3, stream()),
|
||||
0,
|
||||
stream()),
|
||||
x.shape(),
|
||||
stream()));
|
||||
}
|
||||
}
|
||||
|
||||
// gradient wrt to the indices is undefined
|
||||
@@ -3290,9 +3361,45 @@ std::vector<array> GatherQMM::vjp(
|
||||
}
|
||||
|
||||
// gradient wrt to w_q, scales or biases
|
||||
else {
|
||||
else if (arg == 1) {
|
||||
throw std::runtime_error(
|
||||
"GatherQMM::vjp no gradient wrt the quantized matrix yet.");
|
||||
"GatherQMM::vjp no gradient wrt the quantized weights.");
|
||||
} else {
|
||||
if (!dsb) {
|
||||
auto shape = w.shape();
|
||||
shape.pop_back();
|
||||
shape.pop_back();
|
||||
dsb = unflatten(
|
||||
gather_mm_grad(
|
||||
x,
|
||||
cotan,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
std::move(shape),
|
||||
stream()),
|
||||
-1,
|
||||
{-1, group_size_},
|
||||
stream());
|
||||
}
|
||||
if (arg == 3) {
|
||||
vjps.push_back(sum(*dsb, -1, false, stream()));
|
||||
} else {
|
||||
vjps.push_back(
|
||||
sum(multiply(
|
||||
*dsb,
|
||||
dequantize(
|
||||
w,
|
||||
ones_like(scales, stream()),
|
||||
zeros_like(biases, stream()),
|
||||
group_size_,
|
||||
bits_,
|
||||
stream()),
|
||||
stream()),
|
||||
-1,
|
||||
false,
|
||||
stream()));
|
||||
}
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
@@ -5064,6 +5171,8 @@ std::vector<array> GatherMM::vjp(
|
||||
std::vector<array> vjps;
|
||||
auto& cotan = cotangents[0];
|
||||
|
||||
auto& a = primals[0];
|
||||
auto& b = primals[1];
|
||||
auto& lhs_indices = primals[2];
|
||||
auto& rhs_indices = primals[3];
|
||||
|
||||
@@ -5076,64 +5185,42 @@ std::vector<array> GatherMM::vjp(
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
auto bt = swapaxes(primals[1], -1, -2, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g =
|
||||
gather_mm(cotan, bt, std::nullopt, rhs_indices, sorted, stream());
|
||||
auto g = gather_mm(
|
||||
cotan,
|
||||
swapaxes(b, -1, -2, stream()),
|
||||
std::nullopt,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
stream());
|
||||
if (sorted && no_broadcast) {
|
||||
vjps.push_back(g);
|
||||
} else {
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto base = zeros_like(primals[0], stream());
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
vjps.push_back(reshape(
|
||||
scatter_add(
|
||||
flatten(zeros_like(a, stream()), 0, -3, stream()),
|
||||
lhs_indices,
|
||||
expand_dims(g, -3, stream()),
|
||||
0,
|
||||
stream()),
|
||||
a.shape(),
|
||||
stream()));
|
||||
}
|
||||
|
||||
} else if (arg == 1) {
|
||||
if (sorted) {
|
||||
// Make the segments based on the rhs_indices
|
||||
int num_segments = primals[1].size() / K / N;
|
||||
auto segments = zeros({num_segments}, uint32, stream());
|
||||
segments = scatter_add_axis(
|
||||
segments, rhs_indices, array(M, uint32), 0, stream());
|
||||
segments = cumsum(segments, 0, false, true, stream());
|
||||
segments =
|
||||
concatenate({array({0}, {1}, uint32), segments}, 0, stream());
|
||||
segments = as_strided(segments, {num_segments, 2}, {1, 1}, 0, stream());
|
||||
|
||||
// Reshape and transpose the inputs such that they are a big segmented
|
||||
// matmul.
|
||||
auto a = reshape(primals[0], {-1, K}, stream());
|
||||
auto c = swapaxes(reshape(cotan, {-1, N}, stream()), 0, 1, stream());
|
||||
|
||||
// Calculate the gradient.
|
||||
// Since the gather mm is often used as x @ w.T we will calculate the
|
||||
// gradient as c @ a and transpose it before returning it which should
|
||||
// save a copy in that case.
|
||||
auto g = segmented_mm(c, a, segments, stream());
|
||||
g = swapaxes(g, 1, 2, stream());
|
||||
|
||||
vjps.push_back(reshape(g, primals[1].shape(), stream()));
|
||||
} else {
|
||||
// (M X K).T * M X N -> K X N
|
||||
auto base = zeros_like(primals[1], stream());
|
||||
auto at = swapaxes(primals[0], -1, -2, stream());
|
||||
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, K, N}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (K, N)
|
||||
auto g =
|
||||
gather_mm(at, cotan, lhs_indices, std::nullopt, sorted, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
}
|
||||
auto shape = b.shape();
|
||||
shape.pop_back();
|
||||
shape.pop_back();
|
||||
vjps.push_back(swapaxes(
|
||||
gather_mm_grad(
|
||||
a,
|
||||
cotan,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
sorted,
|
||||
std::move(shape),
|
||||
stream()),
|
||||
-1,
|
||||
-2,
|
||||
stream()));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[GatherMM] Cannot calculate VJP with respect to indices.");
|
||||
|
||||
@@ -8,6 +8,9 @@ cuda_skip = {
|
||||
# Gather matmul NYI
|
||||
"TestBlas.test_gather_matmul",
|
||||
"TestBlas.test_gather_matmul_grad",
|
||||
"TestBlas.test_gather_mm_sorted",
|
||||
# Segmented matmul NYI
|
||||
"TestBlas.test_segmented_mm",
|
||||
# Scan NYI
|
||||
"TestArray.test_api",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
|
||||
@@ -1207,38 +1207,38 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
(10, 10, 1000),
|
||||
(1000, 1000, 1000),
|
||||
]
|
||||
segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
|
||||
all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
|
||||
|
||||
for M, N, K in shapes:
|
||||
for s in segments:
|
||||
for s in all_segments:
|
||||
segments = []
|
||||
for i in range(len(s) - 1):
|
||||
segments.append([s[i], s[i + 1]])
|
||||
segments = mx.array(segments)
|
||||
segments = mx.maximum(K - 1, segments.astype(mx.uint32))
|
||||
segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
|
||||
a = mx.random.normal((M, K))
|
||||
b = mx.random.normal((K, N))
|
||||
c1 = segmented_mm_ref(a, b, segments)
|
||||
c2 = mx.segmented_mm(a, b, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2))
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
a = mx.random.normal((K, M))
|
||||
b = mx.random.normal((K, N))
|
||||
c1 = segmented_mm_ref(a.T, b, segments)
|
||||
c2 = mx.segmented_mm(a.T, b, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2))
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
a = mx.random.normal((M, K))
|
||||
b = mx.random.normal((N, K))
|
||||
c1 = segmented_mm_ref(a, b.T, segments)
|
||||
c2 = mx.segmented_mm(a, b.T, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2))
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
a = mx.random.normal((K, M))
|
||||
b = mx.random.normal((N, K))
|
||||
c1 = segmented_mm_ref(a.T, b.T, segments)
|
||||
c2 = mx.segmented_mm(a.T, b.T, segments)
|
||||
self.assertTrue(mx.allclose(c1, c2))
|
||||
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.ones((2, 10, 10))
|
||||
|
||||
Reference in New Issue
Block a user