mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add vmap for SVD and inverse (#849)
This commit is contained in:
parent
53e6a9367c
commit
105d236889
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/copy.h"
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
#ifdef ACCELERATE_NEW_LAPACK
|
#ifdef ACCELERATE_NEW_LAPACK
|
||||||
@ -92,4 +93,12 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
|||||||
inverse_impl(inputs[0], output);
|
inverse_impl(inputs[0], output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||||
|
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||||
|
return {{linalg::inv(a, stream())}, {ax}};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/copy.h"
|
#include "mlx/backend/common/copy.h"
|
||||||
#include "mlx/backend/common/lapack_helper.h"
|
#include "mlx/backend/common/lapack_helper.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -144,4 +145,12 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
|||||||
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||||
|
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||||
|
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1127,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
return {{equal(a, b, stream())}, axes};
|
return {{equal(a, b, stream())}, {to_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Equal::vjp(
|
std::vector<array> Equal::vjp(
|
||||||
@ -1468,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
return {{greater(a, b, stream())}, axes};
|
return {{greater(a, b, stream())}, {to_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Greater::vjp(
|
std::vector<array> Greater::vjp(
|
||||||
@ -1495,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
return {{greater_equal(a, b, stream())}, axes};
|
return {{greater_equal(a, b, stream())}, {to_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> GreaterEqual::vjp(
|
std::vector<array> GreaterEqual::vjp(
|
||||||
@ -1522,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
return {{less(a, b, stream())}, axes};
|
return {{less(a, b, stream())}, {to_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Less::vjp(
|
std::vector<array> Less::vjp(
|
||||||
@ -1549,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
return {{less_equal(a, b, stream())}, axes};
|
return {{less_equal(a, b, stream())}, {to_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> LessEqual::vjp(
|
std::vector<array> LessEqual::vjp(
|
||||||
|
@ -1929,6 +1929,7 @@ class SVD : public Primitive {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(SVD)
|
DEFINE_PRINT(SVD)
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -1943,6 +1944,7 @@ class Inverse : public UnaryPrimitive {
|
|||||||
void eval_cpu(const std::vector<array>& inputs, array& output) override;
|
void eval_cpu(const std::vector<array>& inputs, array& output) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& output) override;
|
void eval_gpu(const std::vector<array>& inputs, array& output) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Inverse)
|
DEFINE_PRINT(Inverse)
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -655,6 +655,7 @@ std::vector<array> vmap_replace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
|
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
|
||||||
|
|
||||||
// For each primitive's outputs add its id, the vout id and the vax
|
// For each primitive's outputs add its id, the vout id and the vax
|
||||||
auto outputs = a.outputs();
|
auto outputs = a.outputs();
|
||||||
for (int i = 0; i < v_outputs.size(); ++i) {
|
for (int i = 0; i < v_outputs.size(); ++i) {
|
||||||
|
@ -314,6 +314,64 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
|
expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
|
||||||
self.assertTrue(mx.allclose(out, expected))
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
def test_vmap_svd(self):
    """Vmap the CPU SVD over axis 0 and axis 1 of a (3, 4, 2) array.

    For each mapped axis the output shapes are checked, and every slice is
    rebuilt from its factors and compared against the matching input slice.
    """
    a = mx.random.uniform(shape=(3, 4, 2))
    dim0, dim1, dim2 = a.shape

    def cpu_svd(x):
        return mx.linalg.svd(x, stream=mx.cpu)

    # Mapping over the leading axis, which the primitive already supports natively.
    Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
    self.assertEqual(Us.shape, (dim0, dim1, dim1))
    self.assertEqual(Ss.shape, (dim0, dim2))
    self.assertEqual(Vts.shape, (dim0, dim2, dim2))

    for idx in range(dim0):
        expected = a[idx]
        U, S, Vt = Us[idx], Ss[idx], Vts[idx]
        # Only the first len(S) columns of U take part in the product.
        rebuilt = U[:, : len(S)] @ mx.diag(S) @ Vt
        self.assertTrue(mx.allclose(rebuilt, expected, rtol=1e-5, atol=1e-7))

    # Mapping over the middle axis.
    Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
    self.assertEqual(Us.shape, (dim1, dim0, dim0))
    self.assertEqual(Ss.shape, (dim1, dim2))
    self.assertEqual(Vts.shape, (dim1, dim2, dim2))

    for idx in range(dim1):
        expected = a[:, idx, :]
        U, S, Vt = Us[idx], Ss[idx], Vts[idx]
        rebuilt = U[:, : len(S)] @ mx.diag(S) @ Vt
        self.assertTrue(mx.allclose(rebuilt, expected, rtol=1e-5, atol=1e-7))
|
||||||
|
|
||||||
|
def test_vmap_inverse(self):
    """Vmap the CPU matrix inverse over axis 0 and axis 1.

    Each slice multiplied by its computed inverse must be close to the
    identity. The second input has non-square trailing dimensions, so it
    must raise ValueError when inverted without vmap.
    """

    def cpu_inv(x):
        return mx.linalg.inv(x, stream=mx.cpu)

    a = mx.random.uniform(shape=(3, 4, 4))

    # Mapping over the leading axis, which the primitive already supports natively.
    invs = mx.vmap(cpu_inv, in_axes=(0,))(a)

    identity = mx.eye(a.shape[1])
    for idx in range(a.shape[0]):
        product = a[idx] @ invs[idx]
        self.assertTrue(mx.allclose(product, identity, rtol=0, atol=1e-5))

    a = mx.random.uniform(shape=(4, 3, 4))

    # Taken as a plain batch, the trailing (3, 4) matrices are not square.
    with self.assertRaises(ValueError):
        mx.eval(cpu_inv(a))

    # Mapping over the middle axis leaves (4, 4) matrices to invert.
    invs = mx.vmap(cpu_inv, in_axes=(1,))(a)

    identity = mx.eye(a.shape[0])
    for idx in range(a.shape[1]):
        product = a[:, idx, :] @ invs[idx]
        self.assertTrue(mx.allclose(product, identity, rtol=0, atol=1e-5))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -413,3 +413,35 @@ TEST_CASE("test vmap gather") {
|
|||||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test vmap SVD") {
  // Function under vmap: SVD of the single input, run on the CPU device.
  auto cpu_svd = [](std::vector<array> inputs) {
    return linalg::svd(inputs.at(0), Device::cpu);
  };

  // The values 0..23 arranged as a 3 x 4 x 2 float32 array.
  auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
  const auto dim0 = a.shape(0);
  const auto dim1 = a.shape(1);
  const auto dim2 = a.shape(2);

  // Map over axis 1: the mapped dimension is expected first in every output.
  {
    auto results = vmap(cpu_svd, /* in_axes = */ {1})({a});
    const auto& U = results.at(0);
    const auto& S = results.at(1);
    const auto& Vt = results.at(2);

    CHECK_EQ(U.shape(), std::vector<int>{dim1, dim0, dim0});
    CHECK_EQ(S.shape(), std::vector<int>{dim1, dim2});
    CHECK_EQ(Vt.shape(), std::vector<int>{dim1, dim2, dim2});
  }

  // Map over axis 2.
  {
    auto results = vmap(cpu_svd, /* in_axes = */ {2})({a});
    const auto& U = results.at(0);
    const auto& S = results.at(1);
    const auto& Vt = results.at(2);

    CHECK_EQ(U.shape(), std::vector<int>{dim2, dim0, dim0});
    CHECK_EQ(S.shape(), std::vector<int>{dim2, dim0});
    CHECK_EQ(Vt.shape(), std::vector<int>{dim2, dim1, dim1});
  }
}
|
||||||
|
Loading…
Reference in New Issue
Block a user