mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
vmap matmul and addmm (#836)
This commit is contained in:
parent
63ab0ab580
commit
19ec023256
@ -110,7 +110,11 @@ std::vector<array> Primitive::jvp(
|
|||||||
const std::vector<array>&,
|
const std::vector<array>&,
|
||||||
const std::vector<array>&,
|
const std::vector<array>&,
|
||||||
const std::vector<int>&) {
|
const std::vector<int>&) {
|
||||||
throw std::invalid_argument("Primitive's jvp not implemented.");
|
std::ostringstream msg;
|
||||||
|
msg << "[Primitive::jvp] Not implemented for ";
|
||||||
|
print(msg);
|
||||||
|
msg << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
};
|
};
|
||||||
|
|
||||||
// Primitive::vjp -- base-class fallback, shown as rendered diff rows (old cell, then new cell).
// All parameters are unnamed and unused. The old row threw std::invalid_argument with a
// fixed message; the new rows build the message in a std::ostringstream, pass the stream
// to print() (presumably appending the primitive's name -- confirm against DEFINE_PRINT),
// terminate it with "." and throw std::invalid_argument with the assembled text.
// The bracketed tag names this function, in line with the "[Primitive::jvp]" and
// "[Primitive::vmap]" tags used by the sibling fallbacks.
std::vector<array> Primitive::vjp(
|
std::vector<array> Primitive::vjp(
|
||||||
@ -118,13 +122,21 @@ std::vector<array> Primitive::vjp(
|
|||||||
const std::vector<array>&,
|
const std::vector<array>&,
|
||||||
const std::vector<int>&,
|
const std::vector<int>&,
|
||||||
const std::vector<array>&) {
|
const std::vector<array>&) {
|
||||||
throw std::invalid_argument("Primitive's vjp not implemented.");
|
std::ostringstream msg;
|
||||||
|
msg << "[Primitive::vjp] Not implemented for ";
|
||||||
|
print(msg);
|
||||||
|
msg << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
};
|
};
|
||||||
|
|
||||||
// Primitive::vmap -- base-class fallback, shown as rendered diff rows (old cell, then new cell).
// Both parameters are unnamed and unused; the function never returns a value.
// The old row threw std::invalid_argument with a fixed message; the new rows build the
// message in a std::ostringstream, pass the stream to print() (presumably appending the
// primitive's name -- confirm against DEFINE_PRINT), terminate it with "." and throw
// std::invalid_argument with the assembled text.
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||||
const std::vector<array>&,
|
const std::vector<array>&,
|
||||||
const std::vector<int>&) {
|
const std::vector<int>&) {
|
||||||
throw std::invalid_argument("Primitive's vmap not implemented.");
|
std::ostringstream msg;
|
||||||
|
msg << "[Primitive::vmap] Not implemented for ";
|
||||||
|
print(msg);
|
||||||
|
msg << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<std::vector<int>> Primitive::output_shapes(
|
std::vector<std::vector<int>> Primitive::output_shapes(
|
||||||
@ -235,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
|
|||||||
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
|
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto maybe_move_ax = [this](auto& arr, auto ax) {
|
||||||
|
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
|
||||||
|
};
|
||||||
|
auto a = maybe_move_ax(inputs[0], axes[0]);
|
||||||
|
auto b = maybe_move_ax(inputs[1], axes[1]);
|
||||||
|
auto c = maybe_move_ax(inputs[2], axes[2]);
|
||||||
|
return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
|
||||||
|
}
|
||||||
|
|
||||||
bool Arange::is_equivalent(const Primitive& other) const {
|
bool Arange::is_equivalent(const Primitive& other) const {
|
||||||
const Arange& a_other = static_cast<const Arange&>(other);
|
const Arange& a_other = static_cast<const Arange&>(other);
|
||||||
return (
|
return (
|
||||||
@ -1772,6 +1796,17 @@ std::vector<array> Matmul::vjp(
|
|||||||
return vjps;
|
return vjps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto maybe_move_ax = [this](auto& arr, auto ax) {
|
||||||
|
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
|
||||||
|
};
|
||||||
|
auto a = maybe_move_ax(inputs[0], axes[0]);
|
||||||
|
auto b = maybe_move_ax(inputs[1], axes[1]);
|
||||||
|
return {{matmul(a, b, stream())}, {0}};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Maximum::vjp(
|
std::vector<array> Maximum::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
|
@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(AddMM)
|
DEFINE_PRINT(AddMM)
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Matmul)
|
DEFINE_PRINT(Matmul)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
};
|
};
|
||||||
|
@ -275,6 +275,45 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
|
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
|
||||||
|
|
||||||
|
def test_vmap_matmul(self):
    """vmap over matmul/addmm must agree with the explicitly batched calls."""

    def check(actual, reference):
        self.assertTrue(mx.allclose(actual, reference))

    # Case 1: left operand mapped over axis 0, right operand shared.
    lhs = mx.random.uniform(shape=(2, 3, 4))
    rhs = mx.random.uniform(shape=(4, 3))

    mapped = mx.vmap(mx.matmul, in_axes=(0, None))(lhs, rhs)
    check(mapped, lhs @ rhs)

    addend = mx.random.uniform(shape=(3,))
    mapped = mx.vmap(mx.addmm, in_axes=(None, 0, None))(addend, lhs, rhs)
    check(mapped, mx.addmm(addend, lhs, rhs))

    # Case 2: left operand mapped over axis 1, right operand shared.
    rhs = mx.random.uniform(shape=(4, 2))
    lhs_front = mx.moveaxis(lhs, 1, 0)

    # out_axes=(1,) puts the mapped axis back at position 1 of the result.
    mapped = mx.vmap(mx.matmul, in_axes=(1, None), out_axes=(1,))(lhs, rhs)
    check(mapped, mx.moveaxis(lhs_front @ rhs, 0, 1))

    addend = mx.random.uniform(shape=(2,))
    mapped = mx.vmap(mx.addmm, in_axes=(None, 1, None))(addend, lhs, rhs)
    check(mapped, mx.addmm(addend, mx.moveaxis(lhs, 1, 0), rhs))

    # Case 3: every operand mapped, each over a different axis.
    lhs = mx.random.uniform(shape=(2, 3, 4))
    rhs = mx.random.uniform(shape=(4, 2, 3))

    mapped = mx.vmap(mx.matmul, in_axes=(0, 1))(lhs, rhs)
    check(mapped, lhs @ mx.moveaxis(rhs, 1, 0))

    addend = mx.random.uniform(shape=(3, 3, 2))
    mapped = mx.vmap(mx.addmm, in_axes=(2, 0, 1))(addend, lhs, rhs)
    reference = mx.addmm(mx.moveaxis(addend, 2, 0), lhs, mx.moveaxis(rhs, 1, 0))
    check(mapped, reference)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user