vmap matmul and admm (#836)

This commit is contained in:
Awni Hannun
2024-03-14 14:38:22 -07:00
committed by GitHub
parent 63ab0ab580
commit 19ec023256
3 changed files with 79 additions and 3 deletions

View File

@@ -110,7 +110,11 @@ std::vector<array> Primitive::jvp(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's jvp not implemented.");
std::ostringstream msg;
msg << "[Primitive::jvp] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::vector<array> Primitive::vjp(
@@ -118,13 +122,21 @@ std::vector<array> Primitive::vjp(
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
throw std::invalid_argument("Primitive's vjp not implemented.");
std::ostringstream msg;
msg << "[Primitive::vip] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's vmap not implemented.");
std::ostringstream msg;
msg << "[Primitive::vmap] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::vector<std::vector<int>> Primitive::output_shapes(
@@ -235,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
}
std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
auto c = maybe_move_ax(inputs[2], axes[2]);
return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
}
bool Arange::is_equivalent(const Primitive& other) const {
const Arange& a_other = static_cast<const Arange&>(other);
return (
@@ -1772,6 +1796,17 @@ std::vector<array> Matmul::vjp(
return vjps;
}
std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
return {{matmul(a, b, stream())}, {0}};
}
std::vector<array> Maximum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(AddMM)
bool is_equivalent(const Primitive& other) const override;
@@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
};