Add vmap for SVD and inverse (#849)

This commit is contained in:
nicolov 2024-03-21 21:18:27 +01:00 committed by GitHub
parent 53e6a9367c
commit 105d236889
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 116 additions and 5 deletions

View File

@ -2,6 +2,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK #ifdef ACCELERATE_NEW_LAPACK
@ -92,4 +93,12 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
inverse_impl(inputs[0], output); inverse_impl(inputs[0], output);
} }
// Batching (vmap) rule for the Inverse primitive.
//
// inputs[0] is the array to invert and axes[0] is the axis it is mapped over.
// Returns the result of linalg::inv together with its output axis: 0 when
// axes[0] is non-negative, otherwise -1 (presumably the "not batched" marker;
// confirm against the vmap transform).
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  const int mapped_axis = axes[0];
  const int out_axis = mapped_axis < 0 ? -1 : 0;
  // Only a strictly positive axis has to be relocated to the front; for axis 0
  // or a negative axis the input is forwarded unchanged.
  auto operand = mapped_axis <= 0
      ? inputs[0]
      : moveaxis(inputs[0], mapped_axis, 0, stream());
  return {{linalg::inv(operand, stream())}, {out_axis}};
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -3,6 +3,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h" #include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@ -144,4 +145,12 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]); svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
} }
// Batching (vmap) rule for the SVD primitive.
//
// inputs[0] is the array to decompose and axes[0] is the axis it is mapped
// over. Returns the outputs of linalg::svd and one output axis per result
// (three entries): 0 when axes[0] is non-negative, otherwise -1 (presumably
// the "not batched" marker; confirm against the vmap transform).
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  const int mapped_axis = axes[0];
  const int out_axis = mapped_axis < 0 ? -1 : 0;
  // Only a strictly positive axis has to be relocated to the front; for axis 0
  // or a negative axis the input is forwarded unchanged.
  auto operand = mapped_axis <= 0
      ? inputs[0]
      : moveaxis(inputs[0], mapped_axis, 0, stream());
  return {{linalg::svd(operand, stream())}, {out_axis, out_axis, out_axis}};
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -1127,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{equal(a, b, stream())}, axes}; return {{equal(a, b, stream())}, {to_ax}};
} }
std::vector<array> Equal::vjp( std::vector<array> Equal::vjp(
@ -1468,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{greater(a, b, stream())}, axes}; return {{greater(a, b, stream())}, {to_ax}};
} }
std::vector<array> Greater::vjp( std::vector<array> Greater::vjp(
@ -1495,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{greater_equal(a, b, stream())}, axes}; return {{greater_equal(a, b, stream())}, {to_ax}};
} }
std::vector<array> GreaterEqual::vjp( std::vector<array> GreaterEqual::vjp(
@ -1522,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{less(a, b, stream())}, axes}; return {{less(a, b, stream())}, {to_ax}};
} }
std::vector<array> Less::vjp( std::vector<array> Less::vjp(
@ -1549,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{less_equal(a, b, stream())}, axes}; return {{less_equal(a, b, stream())}, {to_ax}};
} }
std::vector<array> LessEqual::vjp( std::vector<array> LessEqual::vjp(

View File

@ -1929,6 +1929,7 @@ class SVD : public Primitive {
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
DEFINE_VMAP()
DEFINE_PRINT(SVD) DEFINE_PRINT(SVD)
private: private:
@ -1943,6 +1944,7 @@ class Inverse : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& output) override; void eval_cpu(const std::vector<array>& inputs, array& output) override;
void eval_gpu(const std::vector<array>& inputs, array& output) override; void eval_gpu(const std::vector<array>& inputs, array& output) override;
DEFINE_VMAP()
DEFINE_PRINT(Inverse) DEFINE_PRINT(Inverse)
private: private:

View File

@ -655,6 +655,7 @@ std::vector<array> vmap_replace(
} }
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes); auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
// For each primitive's outputs add its id, the vout id and the vax // For each primitive's outputs add its id, the vout id and the vax
auto outputs = a.outputs(); auto outputs = a.outputs();
for (int i = 0; i < v_outputs.size(); ++i) { for (int i = 0; i < v_outputs.size(); ++i) {

View File

@ -314,6 +314,64 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0)) expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
self.assertTrue(mx.allclose(out, expected)) self.assertTrue(mx.allclose(out, expected))
def test_vmap_svd(self):
    """vmap of the CPU SVD: check output shapes and that U @ diag(S) @ Vt
    reproduces each mapped matrix, for in_axes 0 and 1."""
    a = mx.random.uniform(shape=(3, 4, 2))
    d0, d1, d2 = a.shape

    def cpu_svd(x):
        return mx.linalg.svd(x, stream=mx.cpu)

    def reconstructs(U, S, Vt, M):
        # Only the first len(S) columns of U take part in the product.
        approx = U[:, : len(S)] @ mx.diag(S) @ Vt
        return mx.allclose(approx, M, rtol=1e-5, atol=1e-7)

    # Vmap over the first axis (this is already supported natively by the primitive).
    Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
    self.assertEqual(Us.shape, (d0, d1, d1))
    self.assertEqual(Ss.shape, (d0, d2))
    self.assertEqual(Vts.shape, (d0, d2, d2))
    for i in range(d0):
        self.assertTrue(reconstructs(Us[i], Ss[i], Vts[i], a[i]))

    # Vmap over the second axis: each mapped matrix is a[:, i, :].
    Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
    self.assertEqual(Us.shape, (d1, d0, d0))
    self.assertEqual(Ss.shape, (d1, d2))
    self.assertEqual(Vts.shape, (d1, d2, d2))
    for i in range(d1):
        self.assertTrue(reconstructs(Us[i], Ss[i], Vts[i], a[:, i, :]))
def test_vmap_inverse(self):
    """vmap of the CPU inverse: every mapped matrix times its inverse must be
    close to the identity, for in_axes 0 and 1."""

    def cpu_inv(x):
        return mx.linalg.inv(x, stream=mx.cpu)

    def is_identity(product, n):
        return mx.allclose(product, mx.eye(n), rtol=0, atol=1e-5)

    # Vmap over the first axis (this is already supported natively by the primitive).
    a = mx.random.uniform(shape=(3, 4, 4))
    invs = mx.vmap(cpu_inv, in_axes=(0,))(a)
    for i in range(a.shape[0]):
        self.assertTrue(is_identity(a[i] @ invs[i], a.shape[1]))

    a = mx.random.uniform(shape=(4, 3, 4))

    # Without vmapping, each input matrix is not square.
    with self.assertRaises(ValueError):
        mx.eval(cpu_inv(a))

    # Vmap over the second axis: each mapped matrix a[:, i, :] is 4x4.
    invs = mx.vmap(cpu_inv, in_axes=(1,))(a)
    for i in range(a.shape[1]):
        self.assertTrue(is_identity(a[:, i, :] @ invs[i], a.shape[0]))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -413,3 +413,35 @@ TEST_CASE("test vmap gather") {
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2}); CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
} }
} }
// Shape checks for vmapping linalg::svd (run on Device::cpu) over a
// non-leading axis of a {3, 4, 2} float32 input.
TEST_CASE("test vmap SVD") {
  auto svd_on_cpu = [](std::vector<array> inputs) {
    return linalg::svd(inputs.at(0), Device::cpu);
  };

  auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
  const auto d0 = a.shape(0);
  const auto d1 = a.shape(1);
  const auto d2 = a.shape(2);

  // vmap over the second axis.
  {
    auto results = vmap(svd_on_cpu, /* in_axes = */ {1})({a});
    const auto& U = results.at(0);
    const auto& S = results.at(1);
    const auto& Vt = results.at(2);

    const std::vector<int> expected_u{d1, d0, d0};
    const std::vector<int> expected_s{d1, d2};
    const std::vector<int> expected_vt{d1, d2, d2};
    CHECK_EQ(U.shape(), expected_u);
    CHECK_EQ(S.shape(), expected_s);
    CHECK_EQ(Vt.shape(), expected_vt);
  }

  // vmap over the third axis.
  {
    auto results = vmap(svd_on_cpu, /* in_axes = */ {2})({a});
    const auto& U = results.at(0);
    const auto& S = results.at(1);
    const auto& Vt = results.at(2);

    const std::vector<int> expected_u{d2, d0, d0};
    const std::vector<int> expected_s{d2, d0};
    const std::vector<int> expected_vt{d2, d1, d1};
    CHECK_EQ(U.shape(), expected_u);
    CHECK_EQ(S.shape(), expected_s);
    CHECK_EQ(Vt.shape(), expected_vt);
  }
}