Split multi output (#461)

* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
This commit is contained in:
Angelos Katharopoulos 2024-01-16 13:33:55 -08:00 committed by GitHub
parent 4e290d282f
commit d8fabaa12b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 202 additions and 5 deletions

View File

@ -158,7 +158,26 @@ array::ArrayDesc::ArrayDesc(
}
}
// Creates an iterator over the leading axis of `arr`, positioned at `idx`.
// Throws std::invalid_argument when `arr` is 0-dimensional. When the leading
// axis holds between 1 and 10 entries, the entries are produced up front with
// a single split and each piece has its leading axis squeezed away; otherwise
// `splits` is left empty.
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
    : arr(arr), idx(idx) {
  if (arr.ndim() == 0) {
    throw std::invalid_argument("Cannot iterate over 0-d array.");
  }

  const auto leading = arr.shape(0);
  if (leading <= 0 || leading > 10) {
    return;
  }

  // Iterate using split
  splits = split(arr, leading);
  for (size_t k = 0; k < splits.size(); ++k) {
    splits[k] = squeeze(splits[k], 0);
  }
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
if (idx >= 0 && idx < splits.size()) {
return splits[idx];
}
auto start = std::vector<int>(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();

View File

@ -127,11 +127,7 @@ class array {
using value_type = const array;
using reference = value_type;
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
explicit ArrayIterator(const array& arr, int idx = 0);
reference operator*() const;
@ -155,6 +151,7 @@ class array {
private:
const array& arr;
int idx;
std::vector<array> splits;
};
ArrayIterator begin() const {

View File

@ -60,6 +60,7 @@ DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)

View File

@ -88,6 +88,7 @@ DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)

View File

@ -588,6 +588,58 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@ -727,6 +727,12 @@ void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
// GPU entry point for Split: forwards directly to the private eval(), passing
// inputs and outputs through unchanged.
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
// GPU entry point for Square: delegates to unary_op with the string "square"
// (presumably the name of the kernel to launch — confirm against unary_op).
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}

View File

@ -80,6 +80,7 @@ NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(StopGradient)

View File

@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
@ -573,6 +574,29 @@ std::vector<array> split(
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (indices.empty()) {
return {a};
}
if (indices.size() < 10 &&
std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
indices[0] > 0 && indices.back() < a.shape(ax)) {
std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
shapes[0][ax] = indices[0];
for (int i = 1; i < indices.size(); i++) {
shapes[i][ax] = indices[i] - indices[i - 1];
}
shapes.back()[ax] = a.shape(ax) - indices.back();
return array::make_arrays(
shapes,
dtypes,
std::make_shared<Split>(to_stream(s), indices, ax),
{a});
}
std::vector<array> res;
auto out_shape = a.shape();
auto start_indices = std::vector<int>(a.ndim(), 0);

View File

@ -2493,6 +2493,32 @@ bool Sort::is_equivalent(const Primitive& other) const {
return axis_ == r_other.axis_;
}
// Vectorized Split: re-applies split to the batched input along the shifted
// split axis and returns the resulting pieces together with one batch axis
// per piece.
//
// inputs: the single (batched) input of the primitive.
// axes:   the batch axis of that input; only axes[0] is read.
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // A batch axis located at or before axis_ pushes the split axis one
  // position to the right in the batched input.
  const int batched_axis = axis_ + (axes[0] <= axis_);
  auto outputs = split(inputs[0], indices_, batched_axis, stream());

  // Splitting along another axis leaves the batch axis where it was, so every
  // piece is batched along axes[0]. Return one entry per output; previously
  // the one-element `axes` was returned even though there are
  // indices_.size() + 1 outputs.
  std::vector<int> out_axes(outputs.size(), axes[0]);
  return {std::move(outputs), std::move(out_axes)};
}
// Vector-Jacobian product of Split: concatenates the per-output cotangents
// along axis_ (the axis the input was split on) and returns the result as the
// single input cotangent. primals and argnums are not consulted.
std::vector<array> Split::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
return {concatenate(cotangents, axis_, stream())};
}
// Jacobian-vector product of Split: applies the same split (indices_, axis_)
// to the input tangent tangents[0], yielding one tangent per output. primals
// and argnums are not consulted.
std::vector<array> Split::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
return split(tangents[0], indices_, axis_, stream());
}
// Two Split primitives are equivalent when they split along the same axis at
// the same indices.
// NOTE(review): the static_cast is unchecked — presumably the caller has
// already established that `other` is a Split; confirm at the call site.
bool Split::is_equivalent(const Primitive& other) const {
const Split& s_other = static_cast<const Split&>(other);
return axis_ == s_other.axis_ && indices_ == s_other.indices_;
}
std::vector<array> Square::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -1421,6 +1421,28 @@ class Sort : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
// Multi-output primitive that splits its input along `axis` at the given
// `indices`. It derives from Primitive directly and its eval functions write
// into a vector of outputs.
class Split : public Primitive {
public:
// stream:  stream the primitive is bound to (forwarded to Primitive).
// indices: split positions along `axis`.
// axis:    axis of the input that is split.
explicit Split(Stream stream, const std::vector<int>& indices, int axis)
: Primitive(stream), indices_(indices), axis_(axis){};
// Backend entry points; each receives the inputs and fills all outputs.
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
// Declarations for vmap, the jvp/vjp pair and (presumably) print support are
// generated by these macros; Split::vmap/vjp/jvp are defined out of line.
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Split)
// True when `other` splits along the same axis at the same indices.
bool is_equivalent(const Primitive& other) const override;
private:
// Implementation that eval_gpu forwards to.
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
// Split positions along axis_.
std::vector<int> indices_;
// Axis of the input that is split.
int axis_;
};
class Square : public UnaryPrimitive {
public:
explicit Square(Stream stream) : UnaryPrimitive(stream){};

View File

@ -339,6 +339,25 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
def test_split_against_slice(self):
    # The gradient through a 3-way split along the last axis must agree with
    # the gradient through the equivalent first-third / last-third slices.
    def loss_with_split(x):
        first, _, last = x.split(3, -1)
        return (first * last).sum()

    def loss_with_slice(x):
        third = x.shape[-1] // 3
        first = x[..., :third]
        last = x[..., -third:]
        return (first * last).sum()

    x = mx.random.uniform(shape=(100, 300))
    mx.eval(x)

    grad_split = mx.grad(loss_with_split)(x)
    grad_slice = mx.grad(loss_with_slice)(x)
    self.assertTrue(mx.allclose(grad_split, grad_slice))
def test_vjp_types(self):
def fun(x):
return x

View File

@ -922,6 +922,35 @@ TEST_CASE("test concatenate grads") {
array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item<bool>());
}
TEST_CASE("test split grads") {
  // x = [0, 1, 2, 3, 4, 5]; the expected values below correspond to
  // split(x, 3) producing the pieces [0, 1], [2, 3] and [4, 5].
  array x = arange(6, float32);
  eval(x);

  // Every piece is used: the gradient of p0 * p1 + p2 is [p1, p0, 1].
  {
    auto fn = [](const array& a) {
      auto pieces = split(a, 3);
      return pieces[0] * pieces[1] + pieces[2];
    };
    auto grad = vjp(fn, {x}, {ones({2})}).second;
    CHECK_EQ(grad.size(), 6);
    const array expected({2.0f, 3.0f, 0.0f, 1.0f, 1.0f, 1.0f});
    CHECK(array_equal(grad, expected).item<bool>());
  }

  // The middle piece is unused, so its gradient must come back as zeros: the
  // gradient of p0 * p2 is [p2, 0, p0].
  {
    auto fn = [](const array& a) {
      auto pieces = split(a, 3);
      return pieces[0] * pieces[2];
    };
    auto grad = vjp(fn, {x}, {ones({2})}).second;
    CHECK_EQ(grad.size(), 6);
    const array expected({4.0f, 5.0f, 0.0f, 0.0f, 0.0f, 1.0f});
    CHECK(array_equal(grad, expected).item<bool>());
  }
}
TEST_CASE("test comparison grads") {
auto x = ones({3, 1});
auto y = zeros({1, 3});