Reduce vmap + some fixes (#601)

This commit is contained in:
Awni Hannun
2024-02-01 11:30:28 -08:00
committed by GitHub
parent 601c6d6aa8
commit e88e474fd1
5 changed files with 161 additions and 33 deletions

View File

@@ -17,8 +17,7 @@ namespace {
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
const std::vector<int>& axes,
const std::vector<int>& shape,
bool keepdims) {
const std::vector<int>& shape) {
std::set<int> axes_set;
auto ndim = shape.size();
for (auto ax : axes) {
@@ -38,7 +37,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
} else if (keepdims) {
} else {
out_shape.push_back(1);
}
}
@@ -1217,13 +1216,16 @@ array all(
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array all(
@@ -1248,13 +1250,16 @@ array any(
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
bool_,
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array any(
@@ -1279,14 +1284,17 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
return array(
auto out = array(
out_shape,
out_type,
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array sum(
@@ -1374,13 +1382,16 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array prod(
@@ -1408,13 +1419,16 @@ array max(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array max(
@@ -1442,13 +1456,16 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array min(
@@ -1477,14 +1494,17 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes] =
compute_reduce_shape({axis}, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_unique<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
@@ -1505,14 +1525,17 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes] =
compute_reduce_shape({axis}, a.shape(), keepdims);
return array(
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
out_shape,
uint32,
std::make_unique<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
/** Returns a sorted copy of the flattened array. */

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -361,6 +360,20 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;
}
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
int reduce_ax = axis_ + (axis_ >= axes[0]);
auto& in = inputs[0];
std::vector<array> out;
if (reduce_type_ == ArgReduce::ArgMin) {
out.push_back(argmin(in, reduce_ax, true, stream()));
} else {
out.push_back(argmax(in, reduce_ax, true, stream()));
}
return {out, axes};
}
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2153,7 +2166,36 @@ std::vector<array> Reduce::vjp(
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Reduce::vmap not yet implemented.");
auto ax = axes[0];
auto reduce_axes = axes_;
for (auto& rax : reduce_axes) {
if (rax >= ax) {
rax++;
}
}
auto& in = inputs[0];
std::vector<array> out;
switch (reduce_type_) {
case Reduce::And:
out.push_back(all(in, reduce_axes, true, stream()));
break;
case Reduce::Or:
out.push_back(any(in, reduce_axes, true, stream()));
break;
case Reduce::Sum:
out.push_back(sum(in, reduce_axes, true, stream()));
break;
case Reduce::Prod:
out.push_back(prod(in, reduce_axes, true, stream()));
break;
case Reduce::Min:
out.push_back(min(in, reduce_axes, true, stream()));
break;
case Reduce::Max:
out.push_back(max(in, reduce_axes, true, stream()));
break;
}
return {out, axes};
}
bool Reduce::is_equivalent(const Primitive& other) const {

View File

@@ -341,6 +341,7 @@ class ArgReduce : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;

View File

@@ -548,9 +548,8 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
"[vmap] The number of in axes must match the number of inputs.");
}
// Run the function on placeholder inputs
// to get the original graph
std::vector<array> s_inputs;
// Some error checking and get the vmap axis size
size_t vmap_ax_size;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
if (inputs[i].ndim() == 0) {
@@ -563,7 +562,26 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
<< inputs[i].ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
vmap_ax_size = inputs[i].shape(in_axes[i]);
}
}
// Check that all vmapped axes have the same size
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
std::ostringstream msg;
msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and "
<< vmap_ax_size << ".";
throw std::invalid_argument(msg.str());
}
}
}
// Run the function on placeholder inputs
// to get the original graph
std::vector<array> s_inputs;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
std::vector<int> shape = inputs[i].shape();
shape.erase(shape.begin() + in_axes[i]);
array in(shape, inputs[i].dtype(), nullptr, {});