Multi output primitives (#330)

* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun 2024-01-08 16:39:08 -08:00, committed by GitHub
parent f45f70f133
commit f099ebe535
26 changed files with 2313 additions and 1039 deletions
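
A minimal usage sketch of the divmod op this commit introduces, based on the Python binding and tests in the diff below (illustrative only; output formatting not shown):

    import mlx.core as mx

    a = mx.array([5, 6, 7])
    b = mx.array([2, 2, 2])

    # One fused primitive produces both outputs: the quotient and the remainder.
    quo, rem = mx.divmod(a, b)
    mx.eval(quo, rem)  # quo == [2, 3, 3], rem == [1, 0, 1]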

View File

@ -233,6 +233,20 @@ void time_gather_scatter() {
TIME(single_element_add);
}
void time_divmod() {
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
@ -246,4 +260,5 @@ int main() {
time_matmul();
time_reductions();
time_gather_scatter();
time_divmod();
}

View File

@ -36,6 +36,7 @@ Operations
cosh
dequantize
divide
divmod
equal
erf
erfinv

View File

@ -39,7 +39,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
@ -47,6 +47,23 @@ array::array(
std::move(primitive),
inputs)) {}
std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
}
for (int i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
}
return outputs;
}
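// Usage sketch (illustrative, mirroring the divmod op added later in this
// commit): a multi-output op constructs all of its outputs at once so they
// share one primitive and become each other's siblings, e.g.
//
//   return array::make_arrays(
//       {inputs[0].shape(), inputs[0].shape()},
//       {inputs[0].dtype(), inputs[0].dtype()},
//       std::make_unique<DivMod>(to_stream(s)),
//       inputs);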
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
@ -66,6 +83,8 @@ array::array(
void array::detach() {
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->primitive = nullptr;
}
@ -127,7 +146,7 @@ array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
@ -139,10 +158,6 @@ array::ArrayDesc::ArrayDesc(
}
}
// Needed because the Primitive type used in array.h is incomplete and the
// compiler needs to see the call to the destructor after the type is complete.
array::ArrayDesc::~ArrayDesc() = default;
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0);
auto end = arr.shape();

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
@ -174,7 +173,13 @@ class array {
array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@ -182,6 +187,11 @@ class array {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
}
/** A unique identifier for an array's primitive. */
std::uintptr_t primitive_id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
}
struct Data {
allocator::Buffer buffer;
deleter_t d;
@ -219,12 +229,32 @@ class array {
return array_desc_->inputs;
};
/** A non-const reference to the array's inputs so that they can be used to
* edit the graph. */
std::vector<array>& editable_inputs() {
std::vector<array>& inputs() {
return array_desc_->inputs;
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
}
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
auto idx = array_desc_->position;
std::vector<array> outputs;
outputs.reserve(siblings().size() + 1);
outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs;
};
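// Illustrative note: for a two-output primitive built with make_arrays,
// the first output stores siblings = {second} with position = 0 and the
// second stores siblings = {first} with position = 1, so outputs() returns
// {first, second} regardless of which of the two arrays it is called on.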
/** Detach the array from the graph. */
void detach();
@ -299,7 +329,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::unique_ptr<Primitive> primitive{nullptr};
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@ -321,16 +351,19 @@ class array {
Flags flags;
std::vector<array> inputs;
// An array to keep track of the siblings from a multi-output
// primitive.
std::vector<array> siblings;
// The array's position in the output list
uint32_t position{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
~ArrayDesc();
};
// The ArrayDesc contains the details of the materialized array including the

View File

@ -17,6 +17,12 @@
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
// Use the default implementation for the following primitives
@ -57,6 +63,7 @@ DEFAULT(Slice)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@ -6,6 +6,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@ -75,6 +76,61 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void DivMod::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
}
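// Worked example of the element-wise ops above (illustrative): for integral
// inputs, divmod(7, 2) produces (7 / 2, 7 % 2) == (3, 1); for floating point
// inputs, divmod(7.5, 2.0) produces (std::trunc(7.5 / 2.0), std::fmod(7.5, 2.0))
// == (3.0, 1.5).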
void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];

View File

@ -73,6 +73,12 @@ struct UseDefaultBinaryOp {
// Should we throw? This should normally never be called.
assert(false);
}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
template <typename T, typename U, typename Op>
@ -89,6 +95,18 @@ struct DefaultVectorScalar {
a++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *b;
while (size-- > 0) {
auto dst = op(*a, scalar);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
}
}
};
template <typename T, typename U, typename Op>
@ -105,6 +123,18 @@ struct DefaultScalarVector {
b++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *a;
while (size-- > 0) {
auto dst = op(scalar, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
b++;
}
}
};
template <typename T, typename U, typename Op>
@ -121,6 +151,18 @@ struct DefaultVectorVector {
b++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
while (size-- > 0) {
auto dst = op(*a, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op>

View File

@ -0,0 +1,536 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out_a.size(); ++i) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[i] = dst.first;
dst_b[i] = dst.second;
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst_a += stride;
dst_b += stride;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst_a += stride;
dst_b += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
switch (out_a.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
dst_a += stride;
dst_b += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
out_a.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out_a.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out_a.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Ops... ops) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, ops...);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, ops...);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, ops...);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, ops...);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, ops...);
break;
case int8:
binary_op<int8_t>(a, b, outputs, ops...);
break;
case int16:
binary_op<int16_t>(a, b, outputs, ops...);
break;
case int32:
binary_op<int32_t>(a, b, outputs, ops...);
break;
case int64:
binary_op<int64_t>(a, b, outputs, ops...);
break;
case float16:
binary_op<float16_t>(a, b, outputs, ops...);
break;
case float32:
binary_op<float>(a, b, outputs, ops...);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, ops...);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, ops...);
break;
}
}
} // namespace
} // namespace mlx::core

View File

@ -16,6 +16,12 @@
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
DEFAULT(Abs)
@ -89,6 +95,7 @@ DEFAULT(Subtract)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {

View File

@ -14,6 +14,7 @@ set(
"arange"
"arg_reduce"
"binary"
"binary_two"
"conv"
"copy"
"gemm"

View File

@ -0,0 +1,259 @@
// Copyright © 2023 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct FloorDivide {
template <typename T> T operator()(T x, T y) { return x / y; }
template <> float operator()(float x, float y) { return trunc(x / y); }
template <> half operator()(half x, half y) { return trunc(x / y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[index]);
d[index] = Op2()(a[0], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[0]);
d[index] = Op2()(a[index], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[index]);
d[index] = Op2()(a[index], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op1()(a[a_idx], b[b_idx]);
d[index] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
#define instantiate_binary_g(name, itype, otype, op1, op2) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
instantiate_binary_all(name, float32, float, float, op1, op2) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
instantiate_binary_float(name, op1, op2)
instantiate_binary_types(divmod, FloorDivide, Remainder)

View File

@ -4,7 +4,6 @@
#include <future>
#include <memory>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@ -54,7 +53,8 @@ std::function<void()> make_task(
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
@ -62,6 +62,9 @@ std::function<void()> make_task(
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
if (!arr.is_tracer()) {
arr.detach();
for (auto s : arr.siblings()) {
s.detach();
}
}
p->set_value();
scheduler::notify_task_completion(s);

View File

@ -19,6 +19,98 @@ namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
void binary_op(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
auto& out = outputs[0];
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::ostringstream kname;
switch (bopt) {
case ScalarScalar:
kname << "ss";
break;
case ScalarVector:
kname << "sv";
break;
case VectorScalar:
kname << "vs";
break;
case VectorVector:
kname << "vv";
break;
case General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, outputs[0], 2);
set_array_buffer(compute_encoder, outputs[1], 3);
if (bopt == General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void binary_op(
const std::vector<array>& inputs,
array& out,
@ -364,6 +456,12 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}

View File

@ -2,6 +2,12 @@
#include "mlx/primitives.h"
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no GPU implementation."); \
}
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no GPU implementation."); \
@ -81,5 +87,6 @@ NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU_MULTI(DivMod)
} // namespace mlx::core

View File

@ -12,13 +12,11 @@
namespace mlx::core {
using OptionalArrayRef = std::optional<std::reference_wrapper<const array>>;
struct ArrayNames {
struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names;
std::string get_name(const array& x) {
auto it = names.find(x.id());
std::string get_name(uintptr_t id) {
auto it = names.find(id);
if (it == names.end()) {
// Get the next name in the sequence
// [A, B, ..., Z, AA, AB, ...]
@ -29,45 +27,42 @@ struct ArrayNames {
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
names.insert({id, name});
return name;
}
return it->second;
}
std::string get_name(const array& x) {
return get_name(x.id());
}
};
void depth_first_traversal(
std::function<void(OptionalArrayRef, const array&, int)> callback,
std::function<void(array)> callback,
const std::vector<array>& outputs) {
std::function<void(OptionalArrayRef, const array&, int)> recurse;
std::function<void(const array&)> recurse;
std::unordered_set<std::uintptr_t> cache;
recurse = [&](OptionalArrayRef parent, const array& x, int input_index) {
recurse = [&](const array& x) {
auto id = x.id();
if (cache.find(id) != cache.end()) {
return;
}
cache.insert(id);
for (int i = 0; i < x.inputs().size(); i++) {
recurse(x, x.inputs()[i], i);
for (auto& s : x.siblings()) {
cache.insert(s.id());
}
callback(parent, x, input_index);
for (auto& in : x.inputs()) {
recurse(in);
}
callback(x);
};
for (auto x : outputs) {
recurse(std::nullopt, x, 0);
for (auto& o : outputs) {
recurse(o);
}
}
void depth_first_traversal(
std::function<void(const array&)> callback,
const std::vector<array>& outputs) {
depth_first_traversal(
[&callback](OptionalArrayRef p, const array& x, int input_index) {
callback(x);
},
outputs);
}
void print_graph(std::ostream& os, const std::vector<array>& outputs) {
std::vector<array> tape;
std::vector<array> inputs;
@ -82,15 +77,11 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
},
outputs);
ArrayNames namer;
auto print_arr = [&namer, &os](const array& a) {
os << namer.get_name(a);
os << " [" << a.shape() << ", " << a.dtype() << "]";
};
auto print_arrs = [&](const std::vector<array>& arrs) {
NodeNamer namer;
auto print_arrs = [&namer, &os](std::vector<array> arrs) {
for (auto& arr : arrs) {
print_arr(arr);
os << namer.get_name(arr);
os << " [" << arr.shape() << ", " << arr.dtype() << "]";
if (&arr != &arrs.back()) {
os << ", ";
}
@ -108,7 +99,7 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
os << " ";
print_arrs(arr.inputs());
os << " -> ";
print_arr(arr);
print_arrs(arr.outputs());
os << "\n";
}
}
@ -116,26 +107,47 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
os << "digraph {" << std::endl;
ArrayNames namer;
std::unordered_set<std::uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
std::unordered_set<std::uintptr_t> input_set;
NodeNamer namer;
depth_first_traversal(
[&namer, &os](auto parent, const array& x, int input_index) {
os << "{ ";
[&](const array& x) {
if (!x.has_primitive()) {
os << "rank=source; ";
input_set.insert(x.id());
os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl;
return;
}
if (!parent) {
os << "rank=sink; ";
}
os << namer.get_name(x);
// Node for primitive
if (x.has_primitive()) {
os << "{ ";
os << namer.get_name(x.primitive_id());
os << " [label =\"";
x.primitive().print(os);
os << "\"]";
os << "\", shape=rectangle]";
os << "; }" << std::endl;
// Arrows to primitive's inputs
for (auto& a : x.inputs()) {
os << namer.get_name(x.primitive_id()) << " -> "
<< namer.get_name(a) << std::endl;
}
}
os << "; }" << std::endl;
for (auto c : x.inputs()) {
os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl;
// Point outputs to their primitive
for (auto& a : x.outputs()) {
os << "{ ";
if (output_set.find(a.id()) != output_set.end()) {
os << "rank=sink; ";
}
os << namer.get_name(a);
os << "; }" << std::endl;
if (x.has_primitive()) {
os << namer.get_name(a) << " -> "
<< namer.get_name(x.primitive_id()) << std::endl;
}
}
},
outputs);

View File

@ -1737,6 +1737,21 @@ array operator%(const array& a, const array& b) {
return remainder(a, b);
}
std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_complex(dtype)) {
throw std::invalid_argument("[divmod] Complex type not supported.");
}
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
return array::make_arrays(
{inputs[0].shape(), inputs[0].shape()},
{inputs[0].dtype(), inputs[0].dtype()},
std::make_unique<DivMod>(to_stream(s)),
inputs);
}
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =

View File

@ -721,6 +721,10 @@ array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);
/** Compute the element-wise quotient and remainder. */
std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s = {});
/** Compute integer division. Equivalent to doing floor(a / x). */
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <future>
#include <map>
@ -26,6 +25,16 @@ namespace mlx::core {
int detail::InTracing::tracing_counter{0};
void simplify(const std::vector<array>& outputs) {
// Some notes about how this function works
//
// Step 1: Traverse the graph and build a tape. During the graph
// traversal we:
// - Build a map of inputs to their parents.
// - Record scalar inputs in a map in order to fuse them.
// Step 2: Process the tape. A node in the tape has inputs and outputs.
// - Scalar inputs are replaced with their canonical scalar
// - We check each input's output nodes. Every output node that matches
// the current node gets fused into the current node.
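//
// A concrete sketch (mirroring the tests in this commit): for
// b = exp(a) + exp(a), the two exp nodes have the same input and equivalent
// primitives, so one is fused into the other and b ends up with two
// identical inputs. For a multi-output primitive such as divmod called twice
// on the same inputs, the outputs are fused group-wise so the output order
// of the surviving primitive is preserved.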
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
@ -54,7 +63,7 @@ void simplify(const std::vector<array>& outputs) {
return std::make_pair(v, a.dtype().val);
};
// DFS the graph to log the parents
// DFS the graph to build the tape, and log parents and scalars
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
@ -63,9 +72,16 @@ void simplify(const std::vector<array>& outputs) {
for (int i = 0; i < a.inputs().size(); i++) {
auto& in = a.inputs()[i];
parents_map[in.id()].push_back({a, i});
for (auto& s : a.siblings()) {
parents_map[in.id()].push_back({s, i});
}
recurse(in);
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
tape.push(a);
if (is_scalar(a)) {
scalars.insert({get_scalar_rep(a), a});
@ -78,26 +94,33 @@ void simplify(const std::vector<array>& outputs) {
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
auto src_parents = parents_map.find(src.id());
if (src_parents == parents_map.end()) {
return;
}
auto& pairs = parents_map[dst.id()];
for (auto& parent : src_parents->second) {
parent.first.editable_inputs()[parent.second] = dst;
pairs.push_back(parent);
// Canonicalize the order of the primitive's outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Walk the graph
cache.clear();
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
return false;
}
if (a.primitive_id() == b.primitive_id()) {
return false;
}
const auto& pa = a.primitive();
const auto& pb = b.primitive();
if (typeid(pa) != typeid(pb)) {
@ -117,14 +140,11 @@ void simplify(const std::vector<array>& outputs) {
return pa.is_equivalent(pb);
};
// Walk the graph
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
if (cache.find(arr.id()) != cache.end()) {
continue;
}
// Check if we can fuse scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
@ -134,28 +154,35 @@ void simplify(const std::vector<array>& outputs) {
}
}
// Check if we can fuse the parents of this array
auto parents = parents_map.find(arr.id());
if (parents != parents_map.end()) {
std::vector<bool> mask(parents->second.size(), false);
auto N = parents->second.size();
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
// Helper to check if we can fuse the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
std::vector<bool> mask(N, false);
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
cache.insert(src.id());
fuse(dst, src);
mask[j] = true;
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
mask[j] = true;
}
}
}
}
};
maybe_fuse_parents(arr);
for (auto& s : arr.siblings()) {
maybe_fuse_parents(s);
}
}
}
@ -177,11 +204,14 @@ void eval(const std::vector<array>& outputs) {
// stream, we need to manage the dependency.
if (!in.is_evaled()) {
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.id(), std::shared_future<void>{}});
deps.insert({in.primitive_id(), std::shared_future<void>{}});
}
}
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) {
throw std::invalid_argument(
@ -191,17 +221,23 @@ void eval(const std::vector<array>& outputs) {
}
};
// We have to store the output primitive ids because the arrays are
// detached during eval and we need to use them for synchronization
// at the end of this function
std::vector<std::uintptr_t> output_primitive_ids;
for (auto& arr : outputs) {
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
recurse(arr);
// Insert a dependency for every output to synchronize
// with at the end.
if (!arr.is_evaled()) {
deps.insert({arr.id(), std::shared_future<void>{}});
if (!arr.is_evaled() && deps.find(arr.primitive_id()) == deps.end()) {
deps.insert({arr.primitive_id(), std::shared_future<void>{}});
output_primitive_ids.push_back(arr.primitive_id());
}
}
}
std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
@ -215,13 +251,14 @@ void eval(const std::vector<array>& outputs) {
auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps;
for (auto& in : arr.inputs()) {
if (auto it = deps.find(in.id()); it != deps.end()) {
if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
arr_deps.push_back(it->second);
}
}
std::shared_ptr<std::promise<void>> p{nullptr};
if (auto it = deps.find(arr.id()); it != deps.end()) {
if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
p = std::make_unique<std::promise<void>>();
ps.push_back(p);
it->second = p->get_future().share();
}
@ -234,15 +271,19 @@ void eval(const std::vector<array>& outputs) {
} else {
auto task = [arr,
stream,
arr_deps = std::move(arr_deps),
deps = std::move(arr_deps),
p = std::move(p)]() mutable {
for (auto& d : arr_deps) {
for (auto& d : deps) {
d.wait();
}
scheduler::notify_new_task(stream);
arr.primitive().eval_cpu(arr.inputs(), arr);
auto outputs = arr.outputs();
arr.primitive().eval_cpu(arr.inputs(), outputs);
if (!arr.is_tracer()) {
arr.detach();
for (auto s : arr.siblings()) {
s.detach();
}
}
if (p) {
p->set_value();
@ -252,10 +293,8 @@ void eval(const std::vector<array>& outputs) {
scheduler::enqueue(stream, std::move(task));
}
}
for (auto& arr : outputs) {
if (auto it = deps.find(arr.id()); it != deps.end()) {
it->second.wait();
}
for (auto id : output_primitive_ids) {
deps[id].wait();
}
}
@ -301,8 +340,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
output_cotan_pairs.emplace_back(i, cotan_index++);
}
// Topologically sort the compute graph, record outputs
// in the tape if a gradient is needed.
// Topologically sort the compute graph, add graph nodes
// to the tape which need a gradient.
std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> calc_grad;
for (auto& primal : primals_) {
@ -315,34 +354,41 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
std::function<void(array&)> recurse;
recurse = [&](auto& a) {
auto id = a.id();
a.set_tracer(false);
// Check if visited and add to cache if not
if (auto inserted = cache.insert(id); !inserted.second) {
if (auto inserted = cache.insert(a.id()); !inserted.second) {
return;
}
a.set_tracer(false);
for (auto s : a.siblings()) {
s.set_tracer(false);
cache.insert(s.id());
}
for (auto& input : a.editable_inputs()) {
for (auto& input : a.inputs()) {
recurse(input);
}
// Stop grad
if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
return;
if (a.has_primitive()) {
if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
return;
}
}
// Calculate gradient if any inputs require gradient
for (auto& input : a.inputs()) {
if (calc_grad.find(input.id()) != calc_grad.end()) {
tape.push_back(a);
calc_grad.insert(id);
calc_grad.insert(a.id());
for (auto& s : a.siblings()) {
calc_grad.insert(s.id());
}
break;
}
}
};
for (auto& out : outputs) {
for (auto out : outputs) {
recurse(out);
}
@ -363,14 +409,28 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
auto cotan_it = cotan_map.find(a.id());
if (cotan_it == cotan_map.end()) {
// Check if the array or any of its siblings has a cotangent;
// if not, we can skip this primitive
auto outputs = a.outputs();
bool has_cotans =
std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) {
return cotan_map.find(s.id()) != cotan_map.end();
});
if (!has_cotans) {
continue;
}
auto cotan = cotan_map.extract(cotan_it).mapped();
auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums);
auto s = a.primitive().stream();
std::vector<array> cotangents{};
for (auto& o : outputs) {
if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) {
cotangents.push_back(cotan_map.extract(cotan_it).mapped());
} else {
cotangents.push_back(zeros_like(o, s));
}
}
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums);
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();
@ -411,6 +471,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents) {
// Set the global tracing flag.
detail::InTracing in_tracing;
if (primals.size() != tangents.size()) {
throw std::invalid_argument(
"[jvp] Number of inputs does not match number of tangents.");
@ -422,9 +485,6 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
}
}
// Set the global tracing flag.
detail::InTracing in_tracing;
std::vector<array> primals_;
for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream()
@ -448,36 +508,44 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
std::function<void(array&)> recurse;
recurse = [&](auto& a) {
auto id = a.id();
a.set_tracer(false);
// Check if visited and add to cache if not
if (auto inserted = cache.insert(id); !inserted.second) {
if (auto inserted = cache.insert(a.id()); !inserted.second) {
return;
}
a.set_tracer(false);
for (auto s : a.siblings()) {
s.set_tracer(false);
cache.insert(s.id());
}
for (auto& input : a.editable_inputs()) {
for (auto input : a.inputs()) {
recurse(input);
}
// Stop grad
if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
return;
if (a.has_primitive()) {
if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
return;
}
}
// Calculate gradient if any inputs require gradient
for (auto& input : a.inputs()) {
if (calc_grad.find(input.id()) != calc_grad.end()) {
tape.push_back(a);
calc_grad.insert(id);
calc_grad.insert(a.id());
for (auto& s : a.siblings()) {
calc_grad.insert(s.id());
}
break;
}
}
};
for (auto& out : outputs) {
for (auto out : outputs) {
recurse(out);
}
std::unordered_map<std::uintptr_t, array> tan_map;
for (int i = 0; i < primals_.size(); ++i) {
tan_map.insert({primals_[i].id(), tangents[i]});
@ -494,8 +562,11 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
}
}
auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums);
tan_map.insert({a.id(), jvp});
auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
auto outputs = a.outputs();
for (int i = 0; i < jvps.size(); ++i) {
tan_map.insert({outputs[i].id(), jvps[i]});
}
}
std::vector<array> jvps;
@ -578,8 +649,8 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs,
const std::vector<int>& in_axes) {
// Set the global tracing flag
InTracing in_tracing;
// Set the global tracing flag.
detail::InTracing in_tracing;
if (in_axes.size() != inputs.size()) {
throw std::invalid_argument(
@ -627,48 +698,56 @@ std::vector<array> vmap_replace(
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
std::unordered_set<std::uintptr_t> needs_vmap;
for (int i = 0; i < s_inputs.size(); ++i) {
if (in_axes[i] != -1) {
tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}});
needs_vmap.insert(s_inputs[i].id());
}
}
// Topologically sort the graph
std::unordered_set<std::uintptr_t> cache;
for (int i = 0; i < s_inputs.size(); ++i) {
auto in = s_inputs[i];
if (in_axes[i] != -1) {
tmap.insert({in.id(), {inputs[i], in_axes[i]}});
needs_vmap.insert(in.id());
in.set_tracer(false);
}
cache.insert(in.id());
}
// Topologically sort the graph
std::vector<array> tape;
std::function<void(const array&)> recurse;
recurse = [&](const array& a) {
// Stop at inputs to the vmap function
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
// Recurse on inputs
for (auto& input : a.inputs()) {
recurse(input);
}
cache.insert(id);
// If any input needs a vmap, then the outputs also need
// a vmap
for (auto& input : a.inputs()) {
if (needs_vmap.find(input.id()) != needs_vmap.end()) {
needs_vmap.insert(id);
tape.push_back(a);
tape.back().set_tracer(false);
needs_vmap.insert(a.id());
for (auto s : a.siblings()) {
needs_vmap.insert(s.id());
s.set_tracer(false);
}
break;
}
}
};
for (auto& out : s_outputs) {
recurse(out);
if (out.has_primitive()) {
recurse(out);
}
}
// Transform each primitive in the graph with
@ -686,16 +765,19 @@ std::vector<array> vmap_replace(
v_axes.push_back(-1);
}
}
auto out_and_axis = a.primitive().vmap(v_inputs, v_axes);
tmap.insert({a.id(), out_and_axis});
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
// For each of the primitive's outputs, map its id to the vmapped output and axis
auto outputs = a.outputs();
for (int i = 0; i < v_outputs.size(); ++i) {
tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
}
}
// Populate the outputs and make sure all the output axes are
// in the right place
std::vector<array> outputs;
for (int i = 0; i < s_outputs.size(); ++i) {
auto map_it = tmap.find(s_outputs[i].id());
if (map_it != tmap.end()) {
if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
auto& [out, vdim] = map_it->second;
if (vdim != out_axes[i]) {
if (out_axes[i] >= out.ndim()) {
@ -704,11 +786,7 @@ std::vector<array> vmap_replace(
<< out.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
std::vector<int> reorder(out.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + vdim);
reorder.insert(reorder.begin() + out_axes[i], vdim);
out = transpose(out, reorder);
out = moveaxis(out, vdim, out_axes[i]);
}
outputs.push_back(out);
} else {

View File

@ -303,6 +303,33 @@ void init_ops(py::module_& m) {
Returns:
array: The quotient ``a / b``.
)pbdoc");
m.def(
"divmod",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return divmod(a, b, s);
},
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> Tuple[array, array]
Element-wise quotient and remainder.
The function ``divmod(a, b)`` is equivalent to but faster than
``(a // b, a % b)``. The function uses numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
tuple(array, array): The quotient ``a // b`` and remainder ``a % b``.
)pbdoc");
m.def(
"floor_divide",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {

View File

@ -0,0 +1,37 @@
# Copyright © 2023 Apple Inc.
import io
import unittest
import mlx.core as mx
import mlx_tests
class TestGraph(mlx_tests.MLXTestCase):
def test_to_dot(self):
# Simply test that a few cases run.
# Nothing too specific about the graph format
# for now to keep it flexible
a = mx.array(1.0)
f = io.StringIO()
mx.export_to_dot(f, a)
f.seek(0)
self.assertTrue(len(f.read()) > 0)
b = mx.array(2.0)
c = a + b
f = io.StringIO()
mx.export_to_dot(f, c)
f.seek(0)
self.assertTrue(len(f.read()) > 0)
# Multi output case
c = mx.divmod(a, b)
f = io.StringIO()
mx.export_to_dot(f, *c)
f.seek(0)
self.assertTrue(len(f.read()) > 0)
if __name__ == "__main__":
unittest.main()

View File

@ -1314,7 +1314,7 @@ class TestOps(mlx_tests.MLXTestCase):
for axis in (None, 0, 1, 2):
c_npy = npop(a_npy, axis=axis)
c_mlx = mxop(a_mlx, axis=axis)
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
for op in ["cumsum", "cumprod", "cummax", "cummin"]:
c1 = mxop(a_mlx, axis=2)
@ -1597,6 +1597,28 @@ class TestOps(mlx_tests.MLXTestCase):
np.outer,
)
def test_divmod(self):
# A few sizes for the inputs with and without broadcasting
sizes = [
((1,), (1,)),
((1,), (10,)),
((10,), (1,)),
((3,), (3,)),
((2, 2, 2), (1, 2, 1)),
((2, 1, 2), (1, 2, 1)),
((2, 2, 2, 2), (2, 2, 2, 2)),
]
types = [np.uint16, np.uint32, np.int32, np.float16, np.float32]
for s1, s2 in sizes:
for t in types:
a_np = np.random.uniform(1, 100, size=s1).astype(t)
b_np = np.random.uniform(1, 100, size=s2).astype(t)
np_out = np.divmod(a_np, b_np)
mx_out = mx.divmod(mx.array(a_np), mx.array(b_np))
self.assertTrue(
np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
)
if __name__ == "__main__":
unittest.main()

View File

@ -7,19 +7,29 @@
using namespace mlx::core;
TEST_CASE("test simplify scalars") {
auto a = array({-1.0f, 2.0f});
auto b = maximum(a, array(0.0f));
auto c = maximum(-a, array(0.0f));
auto d = b + c;
simplify({d});
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
{
auto a = array(-1.0f);
auto b = array(-1.0f);
auto c = abs(a);
auto d = abs(b);
simplify({c, d});
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
}
{
auto a = array({-1.0f, 2.0f});
auto b = maximum(a, array(0.0f));
auto c = maximum(-a, array(0.0f));
auto d = b + c;
simplify({d});
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
}
}
TEST_CASE("test simplify") {
auto a = array({1.0f, 2.0f});
auto b = exp(a) + exp(a);
simplify(b);
eval(b);
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
}
@ -27,6 +37,44 @@ TEST_CASE("test no simplify") {
auto a = array({1.0f, 2.0f});
auto b = cos(a) + sin(a);
simplify(b);
eval(b);
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
}
TEST_CASE("test simplify multi output") {
{
auto a = array(1.0);
auto b = array(2.0);
auto c = divmod(a, b);
auto d = divmod(a, b);
auto e = c[0] + d[0];
auto f = c[1] + d[1];
simplify({e, f});
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
}
{
auto a = array(1.0);
auto b = array(1.0);
auto c = divmod(a, b);
simplify(c);
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
}
// Make sure the output order of multi-output primitives
// is respected in simplification
{
auto a = array(1.0);
auto b = array(2.0);
auto c = divmod(a, b);
auto d = divmod(a, b);
auto e = stack({c[0], c[1], d[0], d[1]});
simplify(e);
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
}
}

View File

@ -30,7 +30,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
CHECK(array_equal(
CHECK(allclose(
norm(x, 0, false),
array(
{std::sqrt(0 + 3 * 3 + 6 * 6),

View File

@ -2397,3 +2397,58 @@ TEST_CASE("inner") {
expected = array({7., 0., 0., 7.}, {2, 2});
CHECK(array_equal(z, expected).item<bool>());
}
TEST_CASE("test divmod") {
auto x = array({1, 2, 3});
auto y = array({1, 1, 1});
auto out = divmod(x, y);
CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());
CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());
x = array({5, 6, 7});
y = array({2, 2, 2});
out = divmod(x, y);
CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());
CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());
// Siblings should be gone after evaling the graph
CHECK(out[0].siblings().empty());
CHECK(out[1].siblings().empty());
x = array({5.0, 6.0, 7.0});
y = array({2.0, 2.0, 2.0});
out = divmod(x, y);
CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());
CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());
x = array({1.0}, complex64);
y = array({2.0}, complex64);
CHECK_THROWS(divmod(x, y));
// Check that we can eval on both outputs
x = array({1.0});
y = array({2.0});
out = divmod(x, y);
eval(out);
CHECK_EQ(out[0].item<float>(), 0.0);
CHECK_EQ(out[1].item<float>(), 1.0);
// Check nested in the graph
x = array({1.0});
y = array({2.0});
out = divmod(x, y);
auto z = out[0] + out[1];
CHECK_EQ(z.item<float>(), 1.0);
// Check that we can still eval when one output goes out of scope
std::vector<array> out_holder;
{ out_holder.push_back(divmod(x, y)[0]); }
eval(out_holder);
CHECK_EQ(out_holder[0].item<float>(), 0.0);
// Check that we can still eval when the other output goes out of scope
out_holder.clear();
{ out_holder.push_back(divmod(x, y)[1]); }
eval(out_holder);
CHECK_EQ(out_holder[0].item<float>(), 1.0);
}