vmap matmul and admm (#836)

This commit is contained in:
Awni Hannun 2024-03-14 14:38:22 -07:00 committed by GitHub
parent 63ab0ab580
commit 19ec023256
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 3 deletions

View File

@ -110,7 +110,11 @@ std::vector<array> Primitive::jvp(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's jvp not implemented.");
std::ostringstream msg;
msg << "[Primitive::jvp] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::vector<array> Primitive::vjp(
@ -118,13 +122,21 @@ std::vector<array> Primitive::vjp(
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
throw std::invalid_argument("Primitive's vjp not implemented.");
std::ostringstream msg;
msg << "[Primitive::vip] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's vmap not implemented.");
std::ostringstream msg;
msg << "[Primitive::vmap] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::vector<std::vector<int>> Primitive::output_shapes(
@ -235,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
}
std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
auto c = maybe_move_ax(inputs[2], axes[2]);
return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
}
bool Arange::is_equivalent(const Primitive& other) const {
const Arange& a_other = static_cast<const Arange&>(other);
return (
@ -1772,6 +1796,17 @@ std::vector<array> Matmul::vjp(
return vjps;
}
std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
return {{matmul(a, b, stream())}, {0}};
}
std::vector<array> Maximum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(AddMM)
bool is_equivalent(const Primitive& other) const override;
@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
};

View File

@ -275,6 +275,45 @@ class TestVmap(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
def test_vmap_matmul(self):
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 3))
# matmul
out = mx.vmap(mx.matmul, in_axes=(0, None))(a, b)
self.assertTrue(mx.allclose(out, a @ b))
# addmm
c = mx.random.uniform(shape=(3,))
out = mx.vmap(mx.addmm, in_axes=(None, 0, None))(c, a, b)
self.assertTrue(mx.allclose(out, mx.addmm(c, a, b)))
b = mx.random.uniform(shape=(4, 2))
# matmul
out = mx.vmap(mx.matmul, in_axes=(1, None), out_axes=(1,))(a, b)
expected = mx.moveaxis(mx.moveaxis(a, 1, 0) @ b, 0, 1)
self.assertTrue(mx.allclose(out, expected))
# addmm
c = mx.random.uniform(shape=(2,))
out = mx.vmap(mx.addmm, in_axes=(None, 1, None))(c, a, b)
self.assertTrue(mx.allclose(out, mx.addmm(c, mx.moveaxis(a, 1, 0), b)))
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 2, 3))
# matmul
out = mx.vmap(mx.matmul, in_axes=(0, 1))(a, b)
expected = a @ mx.moveaxis(b, 1, 0)
self.assertTrue(mx.allclose(out, expected))
# addmm
c = mx.random.uniform(shape=(3, 3, 2))
out = mx.vmap(mx.addmm, in_axes=(2, 0, 1))(c, a, b)
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
unittest.main()