mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
vmap matmul and admm (#836)
This commit is contained in:
@@ -110,7 +110,11 @@ std::vector<array> Primitive::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&) {
|
||||
throw std::invalid_argument("Primitive's jvp not implemented.");
|
||||
std::ostringstream msg;
|
||||
msg << "[Primitive::jvp] Not implemented for ";
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
|
||||
std::vector<array> Primitive::vjp(
|
||||
@@ -118,13 +122,21 @@ std::vector<array> Primitive::vjp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
throw std::invalid_argument("Primitive's vjp not implemented.");
|
||||
std::ostringstream msg;
|
||||
msg << "[Primitive::vip] Not implemented for ";
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&) {
|
||||
throw std::invalid_argument("Primitive's vmap not implemented.");
|
||||
std::ostringstream msg;
|
||||
msg << "[Primitive::vmap] Not implemented for ";
|
||||
print(msg);
|
||||
msg << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
|
||||
std::vector<std::vector<int>> Primitive::output_shapes(
|
||||
@@ -235,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
|
||||
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto maybe_move_ax = [this](auto& arr, auto ax) {
|
||||
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
|
||||
};
|
||||
auto a = maybe_move_ax(inputs[0], axes[0]);
|
||||
auto b = maybe_move_ax(inputs[1], axes[1]);
|
||||
auto c = maybe_move_ax(inputs[2], axes[2]);
|
||||
return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
|
||||
}
|
||||
|
||||
bool Arange::is_equivalent(const Primitive& other) const {
|
||||
const Arange& a_other = static_cast<const Arange&>(other);
|
||||
return (
|
||||
@@ -1772,6 +1796,17 @@ std::vector<array> Matmul::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto maybe_move_ax = [this](auto& arr, auto ax) {
|
||||
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
|
||||
};
|
||||
auto a = maybe_move_ax(inputs[0], axes[0]);
|
||||
auto b = maybe_move_ax(inputs[1], axes[1]);
|
||||
return {{matmul(a, b, stream())}, {0}};
|
||||
}
|
||||
|
||||
std::vector<array> Maximum::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(AddMM)
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
@@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Matmul)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
};
|
||||
|
Reference in New Issue
Block a user