mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
matmul jvps (#1772)
This commit is contained in:
parent
f288db8d34
commit
0c259961ac
@ -246,6 +246,36 @@ std::vector<array> AddMM::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> AddMM::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
std::vector<array> jvp;
|
||||
for (int i = 0; i < argnums.size(); ++i) {
|
||||
auto arg = argnums[i];
|
||||
if (arg == 0) {
|
||||
if (jvp.empty()) {
|
||||
jvp.push_back(matmul(tangents[i], primals[1], stream()));
|
||||
} else {
|
||||
jvp[0] = addmm(jvp[0], tangents[i], primals[1], 1.0f, 1.0f, stream());
|
||||
}
|
||||
} else if (arg == 1) {
|
||||
if (jvp.empty()) {
|
||||
jvp.push_back(matmul(primals[0], tangents[i], stream()));
|
||||
} else {
|
||||
jvp[0] = addmm(jvp[0], primals[0], tangents[i], 1.0f, 1.0f, stream());
|
||||
}
|
||||
} else {
|
||||
if (jvp.empty()) {
|
||||
jvp.push_back(tangents[i]);
|
||||
} else {
|
||||
jvp[0] = add(jvp[0], tangents[i], stream());
|
||||
}
|
||||
}
|
||||
}
|
||||
return jvp;
|
||||
}
|
||||
|
||||
bool AddMM::is_equivalent(const Primitive& other) const {
|
||||
const AddMM& a_other = static_cast<const AddMM&>(other);
|
||||
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
|
||||
@ -2439,6 +2469,26 @@ std::vector<array> Matmul::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> Matmul::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
std::vector<array> jvp;
|
||||
for (int i = 0; i < argnums.size(); ++i) {
|
||||
auto arg = argnums[i];
|
||||
if (arg == 0 && i == 0) {
|
||||
jvp.push_back(matmul(tangents[0], primals[1], stream()));
|
||||
} else if (arg == 0 && i == 1) {
|
||||
jvp[0] = addmm(jvp[0], tangents[1], primals[1], 1.0f, 1.0f, stream());
|
||||
} else if (i == 0) {
|
||||
jvp.push_back(matmul(primals[0], tangents[0], stream()));
|
||||
} else if (i == 1) {
|
||||
jvp[0] = addmm(jvp[0], primals[0], tangents[1], 1.0f, 1.0f, stream());
|
||||
}
|
||||
}
|
||||
return jvp;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@ -2833,7 +2883,7 @@ std::pair<std::vector<array>, std::vector<int>> Power::vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("QuantizedMatmul::vmap NYI");
|
||||
throw std::runtime_error("[QuantizedMatmul::vmap] NYI");
|
||||
}
|
||||
|
||||
std::vector<array> QuantizedMatmul::vjp(
|
||||
@ -2861,7 +2911,7 @@ std::vector<array> QuantizedMatmul::vjp(
|
||||
// gradient wrt to w_q, scales or biases
|
||||
else {
|
||||
throw std::runtime_error(
|
||||
"QuantizedMatmul::vjp no gradient wrt the quantized matrix yet.");
|
||||
"[QuantizedMatmul::vjp] no gradient wrt the quantized matrix yet.");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
@ -2871,7 +2921,19 @@ std::vector<array> QuantizedMatmul::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
throw std::runtime_error("QuantizedMatmul::jvp NYI");
|
||||
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||
throw std::runtime_error(
|
||||
"[QuantizedMatmul::jvp] No JVP wrt the quantized matrix yet.");
|
||||
}
|
||||
return {quantized_matmul(
|
||||
tangents[0],
|
||||
primals[1],
|
||||
primals[2],
|
||||
primals[3],
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
stream())};
|
||||
}
|
||||
|
||||
bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
|
||||
|
@ -193,12 +193,7 @@ class AddMM : public UnaryPrimitive {
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(AddMM)
|
||||
|
||||
@ -1459,12 +1454,7 @@ class Matmul : public UnaryPrimitive {
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Matmul)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
@ -634,6 +634,41 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(grads[0].dtype, mx.float32)
|
||||
self.assertEqual(grads[1].dtype, mx.float16)
|
||||
|
||||
def test_matmul_jvps(self):
|
||||
a = mx.random.uniform(shape=(4, 4))
|
||||
b = mx.random.uniform(shape=(4, 4))
|
||||
c = mx.random.uniform(shape=(4, 4))
|
||||
d = mx.random.uniform(shape=(4, 4))
|
||||
|
||||
_, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))
|
||||
self.assertTrue(mx.allclose(tangent[0], c @ b))
|
||||
|
||||
_, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))
|
||||
self.assertTrue(mx.allclose(tangent[0], a @ d))
|
||||
|
||||
_, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))
|
||||
self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))
|
||||
|
||||
x = mx.random.uniform(shape=(4, 4))
|
||||
y = mx.random.uniform(shape=(4, 4))
|
||||
z = mx.random.uniform(shape=(4, 4))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))
|
||||
_, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, z))
|
||||
_, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))
|
||||
_, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))
|
||||
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -34,8 +34,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[8, 32, 33, 64], # M
|
||||
[512, 1024], # N
|
||||
[512, 1024], # K
|
||||
[128, 256], # N
|
||||
[128, 256], # K
|
||||
[True, False], # transposed
|
||||
)
|
||||
for group_size, bits, M, N, K, transposed in tests:
|
||||
@ -86,6 +86,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(vjp_out[0], expected_out))
|
||||
|
||||
def test_qmm_jvp(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
|
||||
bits = 8
|
||||
group_size = 64
|
||||
M = 64
|
||||
N = 128
|
||||
K = 128
|
||||
|
||||
x = mx.random.normal(shape=(2, M, K), key=k1)
|
||||
x_tan = mx.ones(shape=(2, M, N))
|
||||
|
||||
transposes = [True, False]
|
||||
for transposed in transposes:
|
||||
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
|
||||
def fn(x):
|
||||
return mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transposed, group_size, bits
|
||||
)
|
||||
|
||||
_, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,))
|
||||
|
||||
expected_out = mx.quantized_matmul(
|
||||
x_tan, w_q, scales, biases, transposed, group_size, bits
|
||||
)
|
||||
self.assertTrue(mx.allclose(jvp_out[0], expected_out))
|
||||
|
||||
def test_qmm_shapes(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
@ -117,8 +147,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 3, 4, 6, 8], # bits
|
||||
[512, 1024, 67], # M
|
||||
[64, 128, 512, 1024], # N
|
||||
[256, 512, 67], # M
|
||||
[64, 128], # N
|
||||
[0, 1, 3, 8], # B
|
||||
)
|
||||
for group_size, bits, M, N, B in tests:
|
||||
@ -144,8 +174,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 3, 4, 6, 8], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024, 67], # N
|
||||
[128, 256], # M
|
||||
[128, 256, 67], # N
|
||||
[0, 1, 3, 8], # B
|
||||
)
|
||||
for group_size, bits, M, N, B in tests:
|
||||
|
Loading…
Reference in New Issue
Block a user