mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Split multi output (#461)
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
This commit is contained in:
parent
4e290d282f
commit
d8fabaa12b
@ -158,7 +158,26 @@ array::ArrayDesc::ArrayDesc(
|
||||
}
|
||||
}
|
||||
|
||||
// Construct an iterator over the leading axis of `arr`, positioned at `idx`.
// Scalars have no axis to iterate over, so 0-d inputs are rejected.
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
    : arr(arr), idx(idx) {
  if (arr.ndim() == 0) {
    throw std::invalid_argument("Cannot iterate over 0-d array.");
  }

  // For a small leading dimension, eagerly split the array into its slices
  // (one per index along axis 0) so dereferencing can return a piece that
  // was computed up front.
  auto n = arr.shape(0);
  if (n > 0 && n <= 10) {
    splits = split(arr, n);
    for (size_t i = 0; i < splits.size(); ++i) {
      splits[i] = squeeze(splits[i], 0);
    }
  }
}
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
if (idx >= 0 && idx < splits.size()) {
|
||||
return splits[idx];
|
||||
}
|
||||
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
auto shape = arr.shape();
|
||||
|
@ -127,11 +127,7 @@ class array {
|
||||
using value_type = const array;
|
||||
using reference = value_type;
|
||||
|
||||
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
|
||||
if (arr.ndim() == 0) {
|
||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||
}
|
||||
}
|
||||
explicit ArrayIterator(const array& arr, int idx = 0);
|
||||
|
||||
reference operator*() const;
|
||||
|
||||
@ -155,6 +151,7 @@ class array {
|
||||
private:
|
||||
const array& arr;
|
||||
int idx;
|
||||
std::vector<array> splits;
|
||||
};
|
||||
|
||||
ArrayIterator begin() const {
|
||||
|
@ -60,6 +60,7 @@ DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Transpose)
|
||||
|
@ -88,6 +88,7 @@ DEFAULT(Sinh)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
|
@ -588,6 +588,58 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
// CPU evaluation of Split.
//
// Splitting never copies data: each output becomes a view sharing the
// input's buffer and strides, differing only in its element offset along
// axis_. Only the contiguity flags and data size are recomputed per output.
void Split::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  // Recompute contiguity flags for a view of `shape` laid out over the
  // input's `strides`. Returns the updated flags together with the number
  // of distinct elements the view actually touches.
  auto compute_new_flags = [](const auto& shape,
                              const auto& strides,
                              size_t in_data_size,
                              auto flags) {
    size_t data_size = 1;
    // Expected strides if the view were column-contiguous (walking dims
    // forward via i) or row-contiguous (walking dims backward via ri).
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.row_contiguous = true;
    flags.col_contiguous = true;
    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
      // A size-1 dimension can never break contiguity.
      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
      f_stride *= shape[i];
      b_stride *= shape[ri];
      // Zero-stride (broadcast) dimensions contribute no new elements.
      if (strides[i] > 0) {
        data_size *= shape[i];
      }
    }

    if (data_size == 1) {
      // Broadcasted scalar array is contiguous.
      flags.contiguous = true;
    } else if (data_size == in_data_size) {
      // Means we sliced a broadcasted dimension so leave the "no holes" flag
      // alone.
    } else {
      // We sliced something. So either we are row or col contiguous or we
      // punched a hole.
      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
    }

    return std::pair<decltype(flags), size_t>{flags, data_size};
  };

  // Prepend 0 so that output i starts at split point indices[i] along axis_.
  std::vector<int> indices(1, 0);
  indices.insert(indices.end(), indices_.begin(), indices_.end());
  for (int i = 0; i < indices.size(); i++) {
    size_t offset = indices[i] * in.strides()[axis_];
    auto [new_flags, data_size] = compute_new_flags(
        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
    outputs[i].copy_shared_buffer(
        in, in.strides(), new_flags, data_size, offset);
  }
}
|
||||
|
||||
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -727,6 +727,12 @@ void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sinh");
|
||||
}
|
||||
|
||||
// GPU evaluation of Split. Forwards to the shared eval(), which only
// rewrites buffer metadata (offsets and flags) — no kernel launch needed.
void Split::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  eval(inputs, outputs);
}
|
||||
|
||||
// GPU evaluation of Square: dispatches the elementwise "square" unary kernel.
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "square");
}
|
||||
|
@ -80,6 +80,7 @@ NO_GPU(Sinh)
|
||||
NO_GPU(Slice)
|
||||
NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU_MULTI(Split)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(StopGradient)
|
||||
|
24
mlx/ops.cpp
24
mlx/ops.cpp
@ -1,4 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
@ -573,6 +574,29 @@ std::vector<array> split(
|
||||
<< " for array with shape " << a.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (indices.empty()) {
|
||||
return {a};
|
||||
}
|
||||
|
||||
if (indices.size() < 10 &&
|
||||
std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
|
||||
indices[0] > 0 && indices.back() < a.shape(ax)) {
|
||||
std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
|
||||
std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
|
||||
shapes[0][ax] = indices[0];
|
||||
for (int i = 1; i < indices.size(); i++) {
|
||||
shapes[i][ax] = indices[i] - indices[i - 1];
|
||||
}
|
||||
shapes.back()[ax] = a.shape(ax) - indices.back();
|
||||
|
||||
return array::make_arrays(
|
||||
shapes,
|
||||
dtypes,
|
||||
std::make_shared<Split>(to_stream(s), indices, ax),
|
||||
{a});
|
||||
}
|
||||
|
||||
std::vector<array> res;
|
||||
auto out_shape = a.shape();
|
||||
auto start_indices = std::vector<int>(a.ndim(), 0);
|
||||
|
@ -2493,6 +2493,32 @@ bool Sort::is_equivalent(const Primitive& other) const {
|
||||
return axis_ == r_other.axis_;
|
||||
}
|
||||
|
||||
// Vectorized (vmap) rule for Split: if the batch axis (axes[0]) sits at or
// before the split axis, the split axis shifts right by one in the batched
// array; otherwise it is unchanged.
//
// NOTE(review): `axes` is returned as-is (one entry) even though split
// yields indices_.size() + 1 outputs — presumably the vmap driver applies
// the single batch axis to every output; confirm against the caller.
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  return {
      {split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
}
|
||||
|
||||
std::vector<array> Split::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return {concatenate(cotangents, axis_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Split::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return split(tangents[0], indices_, axis_, stream());
|
||||
}
|
||||
|
||||
bool Split::is_equivalent(const Primitive& other) const {
|
||||
const Split& s_other = static_cast<const Split&>(other);
|
||||
return axis_ == s_other.axis_ && indices_ == s_other.indices_;
|
||||
}
|
||||
|
||||
std::vector<array> Square::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -1421,6 +1421,28 @@ class Sort : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
// Multi-output primitive that cuts its single input into
// indices_.size() + 1 pieces along axis_ at the given split indices.
class Split : public Primitive {
 public:
  explicit Split(Stream stream, const std::vector<int>& indices, int axis)
      : Primitive(stream), indices_(indices), axis_(axis){};

  // Multi-output evaluation: one entry in `outputs` per split piece.
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(Split)
  bool is_equivalent(const Primitive& other) const override;

 private:
  // Backend-agnostic implementation shared by eval_cpu and eval_gpu.
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);

  std::vector<int> indices_; // Split points along axis_.
  int axis_; // Axis along which the input is cut.
};
|
||||
|
||||
class Square : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Square(Stream stream) : UnaryPrimitive(stream){};
|
||||
|
@ -339,6 +339,25 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
||||
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
||||
|
||||
def test_split_against_slice(self):
    # Gradients computed through split must agree with gradients computed
    # through the equivalent slicing of the same array.
    def via_split(x):
        a, _, b = x.split(3, -1)
        return (a * b).sum()

    def via_slice(x):
        step = x.shape[-1] // 3
        return (x[..., :step] * x[..., -step:]).sum()

    x = mx.random.uniform(shape=(100, 300))
    mx.eval(x)

    grad_split = mx.grad(via_split)
    grad_slice = mx.grad(via_slice)

    self.assertTrue(mx.allclose(grad_split(x), grad_slice(x)))
|
||||
|
||||
def test_vjp_types(self):
|
||||
def fun(x):
|
||||
return x
|
||||
|
@ -922,6 +922,35 @@ TEST_CASE("test concatenate grads") {
|
||||
array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test split grads") {
  array x = arange(6, float32);
  eval(x);

  // vjp of parts[0] * parts[1] + parts[2] at x = [0..5]:
  // d/parts[0] = parts[1] = [2,3]; d/parts[1] = parts[0] = [0,1];
  // d/parts[2] = 1.
  {
    auto fun = [](const array& in) {
      auto pieces = split(in, 3);
      return pieces[0] * pieces[1] + pieces[2];
    };
    auto grads = vjp(fun, {x}, {ones({2})}).second;

    CHECK_EQ(grads.size(), 6);
    CHECK(array_equal(grads, array({2.0f, 3.0f, 0.0f, 1.0f, 1.0f, 1.0f}))
              .item<bool>());
  }

  // vjp of parts[0] * parts[2]: the middle piece receives zero gradient.
  {
    auto fun = [](const array& in) {
      auto pieces = split(in, 3);
      return pieces[0] * pieces[2];
    };
    auto grads = vjp(fun, {x}, {ones({2})}).second;

    CHECK_EQ(grads.size(), 6);
    CHECK(array_equal(grads, array({4.0f, 5.0f, 0.0f, 0.0f, 0.0f, 1.0f}))
              .item<bool>());
  }
}
|
||||
|
||||
TEST_CASE("test comparison grads") {
|
||||
auto x = ones({3, 1});
|
||||
auto y = zeros({1, 3});
|
||||
|
Loading…
Reference in New Issue
Block a user