mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Add vmap for SVD and inverse (#849)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
@@ -92,4 +93,12 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
inverse_impl(inputs[0], output);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||
return {{linalg::inv(a, stream())}, {ax}};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack_helper.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -144,4 +145,12 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1127,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{equal(a, b, stream())}, axes};
|
||||
return {{equal(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> Equal::vjp(
|
||||
@@ -1468,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{greater(a, b, stream())}, axes};
|
||||
return {{greater(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> Greater::vjp(
|
||||
@@ -1495,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{greater_equal(a, b, stream())}, axes};
|
||||
return {{greater_equal(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> GreaterEqual::vjp(
|
||||
@@ -1522,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{less(a, b, stream())}, axes};
|
||||
return {{less(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> Less::vjp(
|
||||
@@ -1549,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {{less_equal(a, b, stream())}, axes};
|
||||
return {{less_equal(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> LessEqual::vjp(
|
||||
|
@@ -1929,6 +1929,7 @@ class SVD : public Primitive {
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(SVD)
|
||||
|
||||
private:
|
||||
@@ -1943,6 +1944,7 @@ class Inverse : public UnaryPrimitive {
|
||||
void eval_cpu(const std::vector<array>& inputs, array& output) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& output) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Inverse)
|
||||
|
||||
private:
|
||||
|
@@ -655,6 +655,7 @@ std::vector<array> vmap_replace(
|
||||
}
|
||||
|
||||
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
|
||||
|
||||
// For each primitive's outputs add its id, the vout id and the vax
|
||||
auto outputs = a.outputs();
|
||||
for (int i = 0; i < v_outputs.size(); ++i) {
|
||||
|
Reference in New Issue
Block a user