Multi output primitives (#330)

* Multi-output primitives

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

parent: f45f70f133
commit: f099ebe535
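As context for the diff below: divmod is the first op built on the new multi-output machinery, where one primitive (DivMod) produces several sibling arrays. A minimal usage sketch (values illustrative; divmod, eval, and the initializer-list array constructor are all part of this commit):

  // Quotient and remainder from a single DivMod primitive.
  #include "mlx/mlx.h"
  using namespace mlx::core;

  int main() {
    array a({7.0f, 8.0f, 9.0f});
    array b({2.0f, 3.0f, 4.0f});
    auto out = divmod(a, b);  // out[0]: trunc(a / b), out[1]: fmod(a, b)
    eval(out);                // both outputs materialize together
  }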
@@ -233,6 +233,20 @@ void time_gather_scatter() {
   TIME(single_element_add);
 }
 
+void time_divmod() {
+  auto a = random::normal({1000});
+  auto b = random::normal({1000});
+  eval({a, b});
+
+  auto divmod_fused = [&a, &b]() { return divmod(a, b); };
+  TIME(divmod_fused);
+
+  auto divmod_separate = [&a, &b]() {
+    return std::vector<array>{floor_divide(a, b), remainder(a, b)};
+  };
+  TIME(divmod_separate);
+}
+
 int main() {
   std::cout << "Benchmarks for " << default_device() << std::endl;
   time_creation_ops();
@@ -246,4 +260,5 @@ int main() {
   time_matmul();
   time_reductions();
   time_gather_scatter();
+  time_divmod();
 }
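The benchmark times the fused call against the two separate ops; the results should agree element-wise. A quick sanity sketch (illustrative, reusing a and b from above; array_equal and item are existing MLX APIs):

  auto fused = divmod(a, b);
  auto separate = std::vector<array>{floor_divide(a, b), remainder(a, b)};
  assert(array_equal(fused[0], separate[0]).item<bool>());
  assert(array_equal(fused[1], separate[1]).item<bool>());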
@@ -36,6 +36,7 @@ Operations
    cosh
    dequantize
    divide
+   divmod
    equal
    erf
    erfinv
@@ -39,7 +39,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
 array::array(
     const std::vector<int>& shape,
     Dtype dtype,
-    std::unique_ptr<Primitive> primitive,
+    std::shared_ptr<Primitive> primitive,
     const std::vector<array>& inputs)
     : array_desc_(std::make_shared<ArrayDesc>(
           shape,
@@ -47,6 +47,23 @@ array::array(
           std::move(primitive),
           inputs)) {}
 
+std::vector<array> array::make_arrays(
+    const std::vector<std::vector<int>>& shapes,
+    const std::vector<Dtype>& dtypes,
+    std::shared_ptr<Primitive> primitive,
+    const std::vector<array>& inputs) {
+  std::vector<array> outputs;
+  for (int i = 0; i < shapes.size(); ++i) {
+    outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
+  }
+  for (int i = 0; i < outputs.size(); ++i) {
+    auto siblings = outputs;
+    siblings.erase(siblings.begin() + i);
+    outputs[i].set_siblings(std::move(siblings), i);
+  }
+  return outputs;
+}
+
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
           std::vector<int>{static_cast<int>(data.size())},
@@ -66,6 +83,8 @@ array::array(
 
 void array::detach() {
   array_desc_->inputs.clear();
+  array_desc_->siblings.clear();
+  array_desc_->position = 0;
   array_desc_->primitive = nullptr;
 }
 
@@ -127,7 +146,7 @@ array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
 array::ArrayDesc::ArrayDesc(
     const std::vector<int>& shape,
     Dtype dtype,
-    std::unique_ptr<Primitive> primitive,
+    std::shared_ptr<Primitive> primitive,
     const std::vector<array>& inputs)
     : shape(shape),
       dtype(dtype),
@@ -139,10 +158,6 @@ array::ArrayDesc::ArrayDesc(
   }
 }
 
-// Needed because the Primitive type used in array.h is incomplete and the
-// compiler needs to see the call to the destructor after the type is complete.
-array::ArrayDesc::~ArrayDesc() = default;
-
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
   auto start = std::vector<int>(arr.ndim(), 0);
   auto end = arr.shape();
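A sketch of what make_arrays (in what appears to be mlx/array.cpp above) produces: each output carries every other output as a sibling plus its own position, so any one of them can reconstruct the full, ordered output list. Shapes, dtypes, and the stream below are hypothetical:

  // Two outputs sharing one primitive and the same inputs.
  auto outs = array::make_arrays(
      {{4}, {4}},                       // shapes of the two outputs
      {float32, float32},               // dtypes
      std::make_shared<DivMod>(stream), // one primitive shared by both
      {a, b});                          // common inputs
  // outs[0].siblings() == {outs[1]}, position 0
  // outs[1].siblings() == {outs[0]}, position 1
  // outs[i].outputs() returns {outs[0], outs[1]} in primitive order.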
mlx/array.h (51 changed lines)
@@ -1,5 +1,4 @@
 // Copyright © 2023 Apple Inc.
-
 #pragma once
 #include <algorithm>
 #include <cstdint>
@@ -174,7 +173,13 @@ class array {
   array(
       const std::vector<int>& shape,
       Dtype dtype,
-      std::unique_ptr<Primitive> primitive,
+      std::shared_ptr<Primitive> primitive,
       const std::vector<array>& inputs);
 
+  static std::vector<array> make_arrays(
+      const std::vector<std::vector<int>>& shapes,
+      const std::vector<Dtype>& dtypes,
+      std::shared_ptr<Primitive> primitive,
+      const std::vector<array>& inputs);
+
   /** A unique identifier for an array. */
@@ -182,6 +187,11 @@ class array {
     return reinterpret_cast<std::uintptr_t>(array_desc_.get());
   }
 
+  /** A unique identifier for an array's primitive. */
+  std::uintptr_t primitive_id() const {
+    return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
+  }
+
   struct Data {
     allocator::Buffer buffer;
     deleter_t d;
@@ -219,12 +229,32 @@ class array {
     return array_desc_->inputs;
   };
 
   /** A non-const reference to the array's inputs so that they can be used to
    * edit the graph. */
-  std::vector<array>& editable_inputs() {
+  std::vector<array>& inputs() {
     return array_desc_->inputs;
   }
 
+  /** The array's siblings. */
+  const std::vector<array>& siblings() const {
+    return array_desc_->siblings;
+  };
+
+  void set_siblings(std::vector<array> siblings, uint16_t position) {
+    array_desc_->siblings = std::move(siblings);
+    array_desc_->position = position;
+  }
+
+  /** The outputs of the array's primitive (i.e. this array and
+   *  its siblings) in the order the primitive expects. */
+  std::vector<array> outputs() const {
+    auto idx = array_desc_->position;
+    std::vector<array> outputs;
+    outputs.reserve(siblings().size() + 1);
+    outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
+    outputs.push_back(*this);
+    outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
+    return outputs;
+  };
+
   /** Detach the array from the graph. */
   void detach();
 
@@ -299,7 +329,7 @@ class array {
     std::vector<size_t> strides;
     size_t size;
     Dtype dtype;
-    std::unique_ptr<Primitive> primitive{nullptr};
+    std::shared_ptr<Primitive> primitive{nullptr};
 
     // Indicates an array is being used in a graph transform
     // and should not be detached from the graph
@@ -321,16 +351,19 @@ class array {
     Flags flags;
 
     std::vector<array> inputs;
+    // An array to keep track of the siblings from a multi-output
+    // primitive.
+    std::vector<array> siblings;
+    // The array's position in the output list
+    uint32_t position{0};
 
     explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
 
     explicit ArrayDesc(
         const std::vector<int>& shape,
         Dtype dtype,
-        std::unique_ptr<Primitive> primitive,
+        std::shared_ptr<Primitive> primitive,
         const std::vector<array>& inputs);
 
     ~ArrayDesc();
   };
 
   // The ArrayDesc contains the details of the materialized array including the
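To see why outputs() splices the holder back in at its recorded position: each sibling list omits the holder itself, so reinserting at position restores the canonical order. An illustrative trace for a three-output primitive:

  // For outputs {o0, o1, o2} of one primitive, o1.siblings() == {o0, o2}
  // and o1 sits at position 1, so o1.outputs() rebuilds {o0, o1, o2}:
  //   insert siblings [0, idx)  -> {o0}
  //   push self                 -> {o0, o1}
  //   insert siblings [idx, N)  -> {o0, o1, o2}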
@@ -17,6 +17,12 @@
     primitive::eval(inputs, out);                                      \
   }
 
+#define DEFAULT_MULTI(primitive)                                        \
+  void primitive::eval_cpu(                                             \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {  \
+    primitive::eval(inputs, outputs);                                   \
+  }
+
 namespace mlx::core {
 
 // Use the default implementation for the following primitives
@@ -57,6 +63,7 @@ DEFAULT(Slice)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
 DEFAULT(Transpose)
+DEFAULT_MULTI(DivMod)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
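For reference, a sketch of what DEFAULT_MULTI(DivMod) expands to (whitespace adjusted); it simply forwards the GPU-less backend to the common eval:

  void DivMod::eval_cpu(
      const std::vector<array>& inputs, std::vector<array>& outputs) {
    DivMod::eval(inputs, outputs);
  }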
@@ -6,6 +6,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/binary_two.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
@@ -75,6 +76,61 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
   binary(a, b, out, [](auto x, auto y) { return x + y; });
 }
 
+void DivMod::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto integral_op = [](auto x, auto y) {
+    return std::make_pair(x / y, x % y);
+  };
+  auto float_op = [](auto x, auto y) {
+    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
+  };
+  switch (outputs[0].dtype()) {
+    case bool_:
+      binary_op<bool>(a, b, outputs, integral_op);
+      break;
+    case uint8:
+      binary_op<uint8_t>(a, b, outputs, integral_op);
+      break;
+    case uint16:
+      binary_op<uint16_t>(a, b, outputs, integral_op);
+      break;
+    case uint32:
+      binary_op<uint32_t>(a, b, outputs, integral_op);
+      break;
+    case uint64:
+      binary_op<uint64_t>(a, b, outputs, integral_op);
+      break;
+    case int8:
+      binary_op<int8_t>(a, b, outputs, integral_op);
+      break;
+    case int16:
+      binary_op<int16_t>(a, b, outputs, integral_op);
+      break;
+    case int32:
+      binary_op<int32_t>(a, b, outputs, integral_op);
+      break;
+    case int64:
+      binary_op<int64_t>(a, b, outputs, integral_op);
+      break;
+    case float16:
+      binary_op<float16_t>(a, b, outputs, float_op);
+      break;
+    case float32:
+      binary_op<float>(a, b, outputs, float_op);
+      break;
+    case bfloat16:
+      binary_op<bfloat16_t>(a, b, outputs, float_op);
+      break;
+    case complex64:
+      // Should never get here
+      throw std::runtime_error("[DivMod] Complex type not supported");
+      break;
+  }
+}
+
 void Divide::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
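A worked example of the two lambdas (values illustrative): both branches follow C++'s truncation-toward-zero convention, integers via / and %, floats via trunc and fmod.

  // integral_op(7, 2)      -> (3, 1)
  // integral_op(-7, 2)     -> (-3, -1)   // C++ % truncates toward zero
  // float_op(7.5f, 2.0f)   -> (3.0f, 1.5f)
  // float_op(-7.5f, 2.0f)  -> (-3.0f, -1.5f)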
@@ -73,6 +73,12 @@ struct UseDefaultBinaryOp {
     // Should we throw? This should normally never be called.
     assert(false);
   }
+
+  template <typename T, typename U>
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    // Should we throw? This should normally never be called.
+    assert(false);
+  }
 };
 
 template <typename T, typename U, typename Op>
@@ -89,6 +95,18 @@ struct DefaultVectorScalar {
       a++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    T scalar = *b;
+    while (size-- > 0) {
+      auto dst = op(*a, scalar);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      a++;
+    }
+  }
 };
 
 template <typename T, typename U, typename Op>
@@ -105,6 +123,18 @@ struct DefaultScalarVector {
       b++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    T scalar = *a;
+    while (size-- > 0) {
+      auto dst = op(scalar, *b);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      b++;
+    }
+  }
 };
 
 template <typename T, typename U, typename Op>
@@ -121,6 +151,18 @@ struct DefaultVectorVector {
       b++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    while (size-- > 0) {
+      auto dst = op(*a, *b);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      a++;
+      b++;
+    }
+  }
 };
 
 template <typename T, typename U, typename Op>
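These new overloads take a pair-returning Op and scatter first/second into two destinations. A minimal direct-use sketch (buffers and values illustrative; assumes the functor's existing single-argument constructor taking the op):

  auto op = [](int x, int y) { return std::make_pair(x / y, x % y); };
  DefaultVectorScalar<int, int, decltype(op)> vs(op);
  int a[4] = {7, 8, 9, 10}, b = 3;
  int q[4], r[4];
  vs(a, &b, q, r, 4);  // q = {2, 2, 3, 3}, r = {1, 2, 0, 1}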
mlx/backend/common/binary_two.h (new file, 536 lines)
@@ -0,0 +1,536 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

template <typename T, typename U, typename Op>
void binary_op_dims1(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < out_a.size(); ++i) {
    auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
    dst_a[i] = dst.first;
    dst_b[i] = dst.second;
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims1(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; i++) {
    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
    dst_a += stride;
    dst_b += stride;
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims2(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
      dst_a[out_idx] = dst.first;
      dst_b[out_idx++] = dst.second;
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims2(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
      dst_a += stride;
      dst_b += stride;
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims3(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
        dst_a[out_idx] = dst.first;
        dst_b[out_idx++] = dst.second;
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims4(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
          dst_a[out_idx] = dst.first;
          dst_b[out_idx++] = dst.second;
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  switch (out_a.ndim()) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
      return;
    case 3:
      binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
      return;
    case 4:
      binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  for (size_t i = 0; i < out_a.size(); i++) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    int dim,
    int stride) {
  // Number of dimensions to loop over for vectorized ops
  switch (dim) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst_a = out_a.data<U>();
  U* dst_b = out_b.data<U>();
  for (size_t i = 0; i < out_a.size(); i += stride) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
    dst_a += stride;
    dst_b += stride;
  }
}

template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out_a, bopt);
  set_binary_op_output_data(a, b, out_b, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == ScalarScalar) {
    std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
        op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == ScalarVector) {
    opsv(
        a.data<T>(),
        b.data<T>(),
        out_a.data<U>(),
        out_b.data<U>(),
        b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == VectorScalar) {
    opvs(
        a.data<T>(),
        b.data<T>(),
        out_a.data<U>(),
        out_b.data<U>(),
        a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == VectorVector) {
    opvv(
        a.data<T>(),
        b.data<T>(),
        out_a.data<U>(),
        out_b.data<U>(),
        out_a.size());
    return;
  }

  // General computation so let's try to optimize

  // Get the left-most dim such that the array is row contiguous after
  auto& strides = out_a.strides();
  auto leftmost_rc_dim = [&strides](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a);
  auto b_rc_dim = leftmost_rc_dim(b);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a);
  auto b_s_dim = leftmost_s_dim(b);

  auto ndim = out_a.ndim();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
      break;
    case VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
      break;
    case ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
      break;
    default:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
      break;
  }
}

template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  //       with template specializations and overloading. Would it be simpler?

  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            outputs[0],
            outputs[1],
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            outputs[0],
            outputs[1],
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          opvv);
    }
  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          outputs[0],
          outputs[1],
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          opvv);
    }
  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a,
        b,
        outputs[0],
        outputs[1],
        op,
        opsv,
        opvs,
        DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}

template <typename... Ops>
void binary(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Ops... ops) {
  switch (outputs[0].dtype()) {
    case bool_:
      binary_op<bool>(a, b, outputs, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, outputs, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, outputs, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, outputs, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, outputs, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, outputs, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, outputs, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, outputs, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, outputs, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, outputs, ops...);
      break;
    case float32:
      binary_op<float>(a, b, outputs, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, outputs, ops...);
      break;
  }
}

} // namespace

} // namespace mlx::core
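The general-case optimization hinges on leftmost_rc_dim: the leftmost dimension from which an input's strides match the output's, so the trailing block is row contiguous. A worked illustration with hypothetical shapes:

  // out: shape {2, 3, 64}, row major -> strides {192, 64, 1}
  // a broadcast from {1, 3, 64}: strides {0, 64, 1} -> leftmost_rc_dim = 1
  //   (dims 1..2 match the output's strides, dim 0 does not)
  // b broadcast from {2, 3, 1}:  strides {3, 1, 0}  -> leftmost_s_dim = 2
  //   (trailing zero-stride dims act like a scalar)
  // Case 1 fails (b is not row contiguous anywhere), but Case 2 gives
  // max(a_rc_dim, b_s_dim) = 2 < ndim = 3 -> VectorScalar over blocks of
  // stride = strides[dim - 1] = strides[1] = 64 (>= 16, so no fallback).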
@@ -16,6 +16,12 @@
     primitive::eval(inputs, out);                                      \
   }
 
+#define DEFAULT_MULTI(primitive)                                        \
+  void primitive::eval_cpu(                                             \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {  \
+    primitive::eval(inputs, outputs);                                   \
+  }
+
 namespace mlx::core {
 
 DEFAULT(Abs)
@@ -89,6 +95,7 @@ DEFAULT(Subtract)
 DEFAULT(Tan)
 DEFAULT(Tanh)
 DEFAULT(Transpose)
+DEFAULT_MULTI(DivMod)
 
 void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (out.dtype() != float32) {
@@ -14,6 +14,7 @@ set(
   "arange"
   "arg_reduce"
   "binary"
+  "binary_two"
   "conv"
   "copy"
   "gemm"
mlx/backend/metal/kernels/binary_two.metal (new file, 259 lines)
@@ -0,0 +1,259 @@
// Copyright © 2023 Apple Inc.

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"

struct FloorDivide {
  template <typename T> T operator()(T x, T y) { return x / y; }
  template <> float operator()(float x, float y) { return trunc(x / y); }
  template <> half operator()(half x, half y) { return trunc(x / y); }
  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
};

struct Remainder {
  template <typename T> T operator()(T x, T y) { return x % y; }
  template <> float operator()(float x, float y) { return fmod(x, y); }
  template <> half operator()(half x, half y) { return fmod(x, y); }
  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[0], b[0]);
  d[index] = Op2()(a[0], b[0]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[0], b[0]);
  d[index] = Op2()(a[0], b[0]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[0], b[index]);
  d[index] = Op2()(a[0], b[index]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[index], b[0]);
  d[index] = Op2()(a[index], b[0]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op1()(a[index], b[index]);
  d[index] = Op2()(a[index], b[index]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  c[index] = Op1()(a[a_idx], b[b_idx]);
  d[index] = Op2()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}

template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}

#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
  template [[host_name(name)]] \
  [[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      uint index [[thread_position_in_grid]]);

#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
  template [[host_name(name "_" #dims)]] \
  [[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const int shape[dims], \
      constant const size_t a_strides[dims], \
      constant const size_t b_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
  template [[host_name(name "_1")]] \
  [[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] \
  [[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const size_t a_strides[2], \
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const size_t a_strides[3], \
      constant const size_t b_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)


#define instantiate_binary_g(name, itype, otype, op1, op2) \
  template [[host_name(name)]] \
  [[kernel]] void binary_op_g<itype, otype, op1, op2>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      device otype* d, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
  instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
  instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)

#define instantiate_binary_float(name, op1, op2) \
  instantiate_binary_all(name, float16, half, half, op1, op2) \
  instantiate_binary_all(name, float32, float, float, op1, op2) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)

#define instantiate_binary_types(name, op1, op2) \
  instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
  instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
  instantiate_binary_float(name, op1, op2)

instantiate_binary_types(divmod, FloorDivide, Remainder)
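The host side (the Metal primitives change below) composes kernel names from a placement prefix, the op name, and the input dtype, matching the instantiate_binary_all pattern above. A few assembled examples, illustrative only:

  // bopt prefix + op + type_to_name(a)      -> host_name
  // VectorVector, "divmod", float32         -> "vvdivmodfloat32"
  // ScalarVector, "divmod", int32           -> "svdivmodint32"
  // General (3-D out), "divmod", float16    -> "gdivmodfloat16_3"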
@@ -4,7 +4,6 @@
 #include <future>
 #include <memory>
 
-#include "mlx/array.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/primitives.h"
 #include "mlx/scheduler.h"
@@ -54,7 +53,8 @@ std::function<void()> make_task(
       }
       auto s = arr.primitive().stream();
       auto command_buffer = increment_command_buffer(s);
-      arr.primitive().eval_gpu(arr.inputs(), arr);
+      auto outputs = arr.outputs();
+      arr.primitive().eval_gpu(arr.inputs(), outputs);
       if (p) {
         metal::device(s.device).end_encoding(s.index);
         scheduler::notify_new_task(s);
@@ -62,6 +62,9 @@ std::function<void()> make_task(
         [s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
           if (!arr.is_tracer()) {
             arr.detach();
+            for (auto s : arr.siblings()) {
+              s.detach();
+            }
           }
           p->set_value();
           scheduler::notify_task_completion(s);
@@ -19,6 +19,98 @@ namespace {
 
 static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
 
+void binary_op(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    const std::string op) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
+
+  auto& out = outputs[0];
+
+  // Try to collapse contiguous dims
+  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
+  auto& strides_a = strides[0];
+  auto& strides_b = strides[1];
+  auto& strides_out = strides[2];
+
+  std::ostringstream kname;
+  switch (bopt) {
+    case ScalarScalar:
+      kname << "ss";
+      break;
+    case ScalarVector:
+      kname << "sv";
+      break;
+    case VectorScalar:
+      kname << "vs";
+      break;
+    case VectorVector:
+      kname << "vv";
+      break;
+    case General:
+      kname << "g";
+      break;
+  }
+  kname << op << type_to_name(a);
+  if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
+    kname << "_" << shape.size();
+  }
+
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+  auto kernel = d.get_kernel(kname.str());
+  auto compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+  set_array_buffer(compute_encoder, a, 0);
+  set_array_buffer(compute_encoder, b, 1);
+  set_array_buffer(compute_encoder, outputs[0], 2);
+  set_array_buffer(compute_encoder, outputs[1], 3);
+
+  if (bopt == General) {
+    auto ndim = shape.size();
+    if (ndim > 3) {
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
+      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
+      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+    } else {
+      // The shape is implicit in the grid for <= 3D
+      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
+      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
+    }
+
+    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
+      compute_encoder->setBytes(&ndim, sizeof(int), 7);
+    }
+
+    // Launch up to 3D grid of threads
+    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+    size_t rest = out.size() / (dim0 * dim1);
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    if (thread_group_size != 1024) {
+      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+    }
+    auto group_dims = get_block_dims(dim0, dim1, rest);
+    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  } else {
+    // Launch a 1D grid of threads
+    size_t nthreads = out.data_size();
+    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    if (thread_group_size > nthreads) {
+      thread_group_size = nthreads;
+    }
+    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
+}
+
 void binary_op(
     const std::vector<array>& inputs,
     array& out,
@@ -364,6 +456,12 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "div");
 }
 
+void DivMod::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  binary_op(inputs, outputs, "divmod");
+}
+
 void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "rem");
 }
@@ -2,6 +2,12 @@
 
 #include "mlx/primitives.h"
 
+#define NO_GPU_MULTI(func)                                               \
+  void func::eval_gpu(                                                   \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {   \
+    throw std::runtime_error(#func " has no GPU implementation.");       \
+  }
+
 #define NO_GPU(func)                                                     \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) {    \
     throw std::runtime_error(#func " has no GPU implementation.");       \
@@ -81,5 +87,6 @@ NO_GPU(Subtract)
 NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
+NO_GPU_MULTI(DivMod)
 
 } // namespace mlx::core
@@ -12,13 +12,11 @@
 
 namespace mlx::core {
 
-using OptionalArrayRef = std::optional<std::reference_wrapper<const array>>;
-
-struct ArrayNames {
+struct NodeNamer {
   std::unordered_map<std::uintptr_t, std::string> names;
 
-  std::string get_name(const array& x) {
-    auto it = names.find(x.id());
+  std::string get_name(uintptr_t id) {
+    auto it = names.find(id);
     if (it == names.end()) {
       // Get the next name in the sequence
       // [A, B, ..., Z, AA, AB, ...]
@@ -29,45 +27,42 @@ struct NodeNamer {
       var_num = (var_num - 1) / 26;
     }
     std::string name(letters.rbegin(), letters.rend());
-      names.insert({x.id(), name});
+      names.insert({id, name});
       return name;
     }
     return it->second;
   }
 
+  std::string get_name(const array& x) {
+    return get_name(x.id());
+  }
 };
 
 void depth_first_traversal(
-    std::function<void(OptionalArrayRef, const array&, int)> callback,
+    std::function<void(array)> callback,
     const std::vector<array>& outputs) {
-  std::function<void(OptionalArrayRef, const array&, int)> recurse;
+  std::function<void(const array&)> recurse;
   std::unordered_set<std::uintptr_t> cache;
-  recurse = [&](OptionalArrayRef parent, const array& x, int input_index) {
+  recurse = [&](const array& x) {
     auto id = x.id();
     if (cache.find(id) != cache.end()) {
      return;
    }
    cache.insert(id);
-    for (int i = 0; i < x.inputs().size(); i++) {
-      recurse(x, x.inputs()[i], i);
+    for (auto& s : x.siblings()) {
+      cache.insert(s.id());
    }
-    callback(parent, x, input_index);
+    for (auto& in : x.inputs()) {
+      recurse(in);
+    }
+    callback(x);
  };
 
-  for (auto x : outputs) {
-    recurse(std::nullopt, x, 0);
+  for (auto& o : outputs) {
+    recurse(o);
  }
 }
 
-void depth_first_traversal(
-    std::function<void(const array&)> callback,
-    const std::vector<array>& outputs) {
-  depth_first_traversal(
-      [&callback](OptionalArrayRef p, const array& x, int input_index) {
-        callback(x);
-      },
-      outputs);
-}
-
 void print_graph(std::ostream& os, const std::vector<array>& outputs) {
   std::vector<array> tape;
   std::vector<array> inputs;
@@ -82,15 +77,11 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
       },
       outputs);
 
-  ArrayNames namer;
-  auto print_arr = [&namer, &os](const array& a) {
-    os << namer.get_name(a);
-    os << " [" << a.shape() << ", " << a.dtype() << "]";
-  };
-
-  auto print_arrs = [&](const std::vector<array>& arrs) {
+  NodeNamer namer;
+  auto print_arrs = [&namer, &os](std::vector<array> arrs) {
     for (auto& arr : arrs) {
-      print_arr(arr);
+      os << namer.get_name(arr);
+      os << " [" << arr.shape() << ", " << arr.dtype() << "]";
       if (&arr != &arrs.back()) {
         os << ", ";
       }
@@ -108,7 +99,7 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
     os << "  ";
     print_arrs(arr.inputs());
     os << " -> ";
-    print_arr(arr);
+    print_arrs(arr.outputs());
     os << "\n";
   }
 }
@@ -116,26 +107,47 @@ void print_graph(std::ostream& os, const std::vector<array>& outputs) {
 void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
   os << "digraph {" << std::endl;
 
-  ArrayNames namer;
+  std::unordered_set<std::uintptr_t> output_set;
+  for (auto& o : outputs) {
+    output_set.insert(o.id());
+  }
+  std::unordered_set<std::uintptr_t> input_set;
+  NodeNamer namer;
   depth_first_traversal(
-      [&namer, &os](auto parent, const array& x, int input_index) {
-        os << "{ ";
+      [&](const array& x) {
         if (!x.has_primitive()) {
-          os << "rank=source; ";
-        }
-        if (!parent) {
-          os << "rank=sink; ";
-        }
-        os << namer.get_name(x);
+          input_set.insert(x.id());
+          os << "{ rank=source; " << namer.get_name(x) << "; }" << std::endl;
+          return;
+        }
 
         // Node for primitive
-        if (x.has_primitive()) {
+        os << "{ ";
+        os << namer.get_name(x.primitive_id());
         os << " [label =\"";
         x.primitive().print(os);
-        os << "\"]";
-        }
-        os << "; }" << std::endl;
+        os << "\", shape=rectangle]";
+        os << "; }" << std::endl;
+        // Arrows to primitive's inputs
+        for (auto& a : x.inputs()) {
+          os << namer.get_name(x.primitive_id()) << " -> "
+             << namer.get_name(a) << std::endl;
+        }
 
-        for (auto c : x.inputs()) {
-          os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl;
+        // Point outputs to their primitive
+        for (auto& a : x.outputs()) {
+          os << "{ ";
+          if (output_set.find(a.id()) != output_set.end()) {
+            os << "rank=sink; ";
+          }
+          os << namer.get_name(a);
+          os << "; }" << std::endl;
+          if (x.has_primitive()) {
+            os << namer.get_name(a) << " -> "
+               << namer.get_name(x.primitive_id()) << std::endl;
+          }
+        }
       },
      outputs);
 
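The reworked exporter emits one rectangular node per primitive and routes edges through it, so a two-output divmod renders roughly like this (a hand-written approximation of the generated DOT; names follow the namer's traversal order):

  digraph {
    { rank=source; A; }
    { rank=source; B; }
    { C [label ="DivMod", shape=rectangle]; }
    C -> A
    C -> B
    { rank=sink; D; }
    D -> C
    { rank=sink; E; }
    E -> C
  }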
mlx/ops.cpp (15 changed lines)
@@ -1737,6 +1737,21 @@ array operator%(const array& a, const array& b) {
   return remainder(a, b);
 }
 
+std::vector<array>
+divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  auto dtype = promote_types(a.dtype(), b.dtype());
+  if (is_complex(dtype)) {
+    throw std::invalid_argument("[divmod] Complex type not supported.");
+  }
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
+  return array::make_arrays(
+      {inputs[0].shape(), inputs[0].shape()},
+      {inputs[0].dtype(), inputs[0].dtype()},
+      std::make_unique<DivMod>(to_stream(s)),
+      inputs);
+}
+
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
@@ -721,6 +721,10 @@ array operator/(const array& a, const array& b);
 array operator/(double a, const array& b);
 array operator/(const array& a, double b);
 
+/** Compute the element-wise quotient and remainder. */
+std::vector<array>
+divmod(const array& a, const array& b, StreamOrDevice s = {});
+
 /** Compute integer division. Equivalent to doing floor(a / b). */
 array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
 
(File diff suppressed because it is too large.)

mlx/primitives.h (724 changed lines): file diff suppressed because it is too large.
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <future>
|
||||
#include <map>
|
||||
@ -26,6 +25,16 @@ namespace mlx::core {
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
|
||||
void simplify(const std::vector<array>& outputs) {
|
||||
// Some notes about how this function works
|
||||
//
|
||||
// Step 1: Traverse the graph and build a tape. During the graph
|
||||
// traversal we:
|
||||
// - Build a map of inputs to their parents.
|
||||
// - Record scalar inputs in a map in order to fuse them.
|
||||
// Step 2: Process the tape. A node in the tape has inputs and outputs.
|
||||
// - Scalar inputs are replaced with their canonical scalar
|
||||
// - We check each inputs output nodes. Every output node that matches
|
||||
// the current node gets fused into the current node.
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
@ -54,7 +63,7 @@ void simplify(const std::vector<array>& outputs) {
|
||||
return std::make_pair(v, a.dtype().val);
|
||||
};
|
||||
|
||||
// DFS the graph to log the parents
|
||||
// DFS the graph to build the tape, and log parents and scalars
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
@ -63,9 +72,16 @@ void simplify(const std::vector<array>& outputs) {
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
auto& in = a.inputs()[i];
|
||||
parents_map[in.id()].push_back({a, i});
|
||||
for (auto& s : a.siblings()) {
|
||||
        parents_map[in.id()].push_back({s, i});
      }
      recurse(in);
    }
    cache.insert(id);
    for (auto& s : a.siblings()) {
      cache.insert(s.id());
    }

    tape.push(a);
    if (is_scalar(a)) {
      scalars.insert({get_scalar_rep(a), a});
@ -78,26 +94,33 @@ void simplify(const std::vector<array>& outputs) {
  // Helper that fuses two arrays in the graph by setting the parents of the
  // source to point to the destination
  auto fuse = [&](array& dst, array& src) {
    auto src_parents = parents_map.find(src.id());
    if (src_parents == parents_map.end()) {
      return;
    }

    auto& pairs = parents_map[dst.id()];
    for (auto& parent : src_parents->second) {
      parent.first.editable_inputs()[parent.second] = dst;
      pairs.push_back(parent);
    // Canonicalize the order of the primitives outputs
    auto sources = src.outputs();
    auto dests = dst.outputs();
    // For each src parent, point it to the corresponding dest
    for (int i = 0; i < sources.size(); ++i) {
      auto src_parents = parents_map.find(sources[i].id());
      if (src_parents == parents_map.end()) {
        continue;
      }
      auto& pairs = parents_map[dests[i].id()];
      for (auto& parent : src_parents->second) {
        parent.first.inputs()[parent.second] = dests[i];
        pairs.push_back(parent);
      }
      // Remove the source from the map to avoid fusing with it again
      parents_map.erase(src_parents);
    }
  };

  // Walk the graph
  cache.clear();

  // Depth-1 array equivalence check.
  auto array_equivalent = [](const array& a, const array& b) {
    if (!a.has_primitive() || !b.has_primitive()) {
      return false;
    }
    if (a.primitive_id() == b.primitive_id()) {
      return false;
    }
    const auto& pa = a.primitive();
    const auto& pb = b.primitive();
    if (typeid(pa) != typeid(pb)) {
@ -117,14 +140,11 @@ void simplify(const std::vector<array>& outputs) {
    return pa.is_equivalent(pb);
  };

  // Walk the graph
  while (!tape.empty()) {
    auto arr = std::move(tape.front());
    tape.pop();

    if (cache.find(arr.id()) != cache.end()) {
      continue;
    }

    // Check if we can fuse scalars
    if (is_scalar(arr)) {
      auto scalar = scalars.find(get_scalar_rep(arr));
@ -134,28 +154,35 @@ void simplify(const std::vector<array>& outputs) {
      }
    }

    // Check if we can fuse the parents of this array
    auto parents = parents_map.find(arr.id());
    if (parents != parents_map.end()) {
      std::vector<bool> mask(parents->second.size(), false);
      auto N = parents->second.size();
      for (int i = 0; i < N; i++) {
        if (mask[i]) {
          continue;
        }
        for (int j = i + 1; j < N; j++) {
          if (mask[j]) {
    // Helper to check if we can fuse the parents of the
    // given array
    auto maybe_fuse_parents = [&](auto& a) {
      auto parents = parents_map.find(a.id());
      if (parents != parents_map.end()) {
        auto N = parents->second.size();
        std::vector<bool> mask(N, false);
        for (int i = 0; i < N; i++) {
          if (mask[i]) {
            continue;
          }
          auto& src = parents->second[j].first;
          auto& dst = parents->second[i].first;
          if (src.id() != dst.id() && array_equivalent(src, dst)) {
            cache.insert(src.id());
            fuse(dst, src);
            mask[j] = true;
          for (int j = i + 1; j < N; j++) {
            if (mask[j]) {
              continue;
            }
            auto& src = parents->second[j].first;
            auto& dst = parents->second[i].first;
            if (src.id() != dst.id() && array_equivalent(src, dst)) {
              fuse(dst, src);
              mask[j] = true;
            }
          }
        }
      }
    };

    maybe_fuse_parents(arr);
    for (auto& s : arr.siblings()) {
      maybe_fuse_parents(s);
    }
  }
}
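
The fuse helper above now redirects parents output-by-output, so two equivalent multi-output primitives collapse into one as a unit. A minimal Python sketch of the effect; it mirrors the "test simplify multi output" case later in this diff, and assumes the Python binding mx.simplify forwards to the C++ simplify shown here:

    import mlx.core as mx

    a, b = mx.array(1.0), mx.array(2.0)
    c = mx.divmod(a, b)  # one DivMod primitive with two sibling outputs
    d = mx.divmod(a, b)  # an equivalent, duplicate DivMod
    e = c[0] + d[0]
    f = c[1] + d[1]
    mx.simplify(e, f)
    # after fusing, e and f read from the same DivMod outputs
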
@ -177,11 +204,14 @@ void eval(const std::vector<array>& outputs) {
      // stream, we need to manage the dependency.
      if (!in.is_evaled()) {
        if (a.primitive().stream() != in.primitive().stream()) {
          deps.insert({in.id(), std::shared_future<void>{}});
          deps.insert({in.primitive_id(), std::shared_future<void>{}});
        }
      }
    }
    cache.insert(id);
    for (auto& s : a.siblings()) {
      cache.insert(s.id());
    }
    if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
      if (!a.has_primitive()) {
        throw std::invalid_argument(
@ -191,17 +221,23 @@ void eval(const std::vector<array>& outputs) {
    }
  };

  // We have to store the output primitive ids because the arrays are
  // detached during eval and we need to use them for synchronization
  // at the end of this function
  std::vector<std::uintptr_t> output_primitive_ids;
  for (auto& arr : outputs) {
    if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
      recurse(arr);
      // Insert a dependency for every output to synchronize
      // with at the end.
      if (!arr.is_evaled()) {
        deps.insert({arr.id(), std::shared_future<void>{}});
      if (!arr.is_evaled() && deps.find(arr.primitive_id()) == deps.end()) {
        deps.insert({arr.primitive_id(), std::shared_future<void>{}});
        output_primitive_ids.push_back(arr.primitive_id());
      }
    }
  }

  std::vector<std::shared_ptr<std::promise<void>>> ps;
  while (!tape.empty()) {
    auto arr = std::move(tape.front());
    tape.pop();
@ -215,13 +251,14 @@ void eval(const std::vector<array>& outputs) {
    auto stream = arr.primitive().stream();
    std::vector<std::shared_future<void>> arr_deps;
    for (auto& in : arr.inputs()) {
      if (auto it = deps.find(in.id()); it != deps.end()) {
      if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
        arr_deps.push_back(it->second);
      }
    }
    std::shared_ptr<std::promise<void>> p{nullptr};
    if (auto it = deps.find(arr.id()); it != deps.end()) {
    if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
      p = std::make_unique<std::promise<void>>();
      ps.push_back(p);
      it->second = p->get_future().share();
    }

@ -234,15 +271,19 @@ void eval(const std::vector<array>& outputs) {
    } else {
      auto task = [arr,
                   stream,
                   arr_deps = std::move(arr_deps),
                   deps = std::move(arr_deps),
                   p = std::move(p)]() mutable {
        for (auto& d : arr_deps) {
        for (auto& d : deps) {
          d.wait();
        }
        scheduler::notify_new_task(stream);
        arr.primitive().eval_cpu(arr.inputs(), arr);
        auto outputs = arr.outputs();
        arr.primitive().eval_cpu(arr.inputs(), outputs);
        if (!arr.is_tracer()) {
          arr.detach();
          for (auto s : arr.siblings()) {
            s.detach();
          }
        }
        if (p) {
          p->set_value();
@ -252,10 +293,8 @@ void eval(const std::vector<array>& outputs) {
      scheduler::enqueue(stream, std::move(task));
    }
  }
  for (auto& arr : outputs) {
    if (auto it = deps.find(arr.id()); it != deps.end()) {
      it->second.wait();
    }
  for (auto id : output_primitive_ids) {
    deps[id].wait();
  }
}
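
Dependencies are now keyed on primitive ids rather than array ids: sibling outputs share a single primitive, so one future per primitive is enough to synchronize all of them. A small sketch of the user-visible consequence (evaluating one output of divmod runs the shared primitive once and materializes its sibling too):

    import mlx.core as mx

    a, b = mx.array(1.0), mx.array(2.0)
    q, r = mx.divmod(a, b)  # q and r are siblings of one DivMod
    mx.eval(q)              # the DivMod kernel runs once
    print(r.item())         # r is already computed; no second kernel
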
@ -301,8 +340,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
    output_cotan_pairs.emplace_back(i, cotan_index++);
  }

  // Topologically sort the compute graph, record outputs
  // in the tape if a gradient is needed.
  // Topologically sort the compute graph, add graph nodes
  // to the tape which need a gradient.
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_set<std::uintptr_t> calc_grad;
  for (auto& primal : primals_) {
@ -315,34 +354,41 @@ std::pair<std::vector<array>, std::vector<array>> vjp(

  std::function<void(array&)> recurse;
  recurse = [&](auto& a) {
    auto id = a.id();
    a.set_tracer(false);

    // Check if visited and add to cache if not
    if (auto inserted = cache.insert(id); !inserted.second) {
    if (auto inserted = cache.insert(a.id()); !inserted.second) {
      return;
    }
    a.set_tracer(false);
    for (auto s : a.siblings()) {
      s.set_tracer(false);
      cache.insert(s.id());
    }

    for (auto& input : a.editable_inputs()) {
    for (auto& input : a.inputs()) {
      recurse(input);
    }

    // Stop grad
    if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
      return;
    if (a.has_primitive()) {
      if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
        return;
      }
    }

    // Calculate gradient if any inputs require gradient
    for (auto& input : a.inputs()) {
      if (calc_grad.find(input.id()) != calc_grad.end()) {
        tape.push_back(a);
        calc_grad.insert(id);
        calc_grad.insert(a.id());
        for (auto& s : a.siblings()) {
          calc_grad.insert(s.id());
        }
        break;
      }
    }
  };

  for (auto& out : outputs) {
  for (auto out : outputs) {
    recurse(out);
  }

@ -363,14 +409,28 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
      }
    }

    auto cotan_it = cotan_map.find(a.id());
    if (cotan_it == cotan_map.end()) {
    // Check if any of the array or its siblings have cotangents,
    // if not, we can skip this primitive
    auto outputs = a.outputs();
    bool has_cotans =
        std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) {
          return cotan_map.find(s.id()) != cotan_map.end();
        });
    if (!has_cotans) {
      continue;
    }

    auto cotan = cotan_map.extract(cotan_it).mapped();
    auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums);
    auto s = a.primitive().stream();
    std::vector<array> cotangents{};
    for (auto& o : outputs) {
      if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) {
        cotangents.push_back(cotan_map.extract(cotan_it).mapped());
      } else {
        cotangents.push_back(zeros_like(o, s));
      }
    }

    auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums);
    // Accumulate the vector-jacobian products for each input
    for (int i = 0; i < argnums.size(); ++i) {
      auto in_id = a.inputs()[argnums[i]].id();
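
When only some outputs of a primitive are used downstream, the missing cotangents are filled with zeros_like before calling vjp, as the hunk above shows. A hedged Python sketch, assuming DivMod implements a vjp rule (not shown in this excerpt):

    import mlx.core as mx

    def f(x):
        q, r = mx.divmod(x, mx.array(2.0))
        return r.sum()  # only the remainder receives a cotangent

    # the quotient's cotangent is zero-filled internally
    g = mx.grad(f)(mx.array([5.0, 6.0, 7.0]))
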
@ -411,6 +471,9 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& tangents) {
  // Set the global tracing flag.
  detail::InTracing in_tracing;

  if (primals.size() != tangents.size()) {
    throw std::invalid_argument(
        "[jvp] Number of inputs does not match number of tangents.");
@ -422,9 +485,6 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
    }
  }

  // Set the global tracing flag.
  detail::InTracing in_tracing;

  std::vector<array> primals_;
  for (auto& p : primals) {
    auto s = p.has_primitive() ? p.primitive().stream()
@ -448,36 +508,44 @@ std::pair<std::vector<array>, std::vector<array>> jvp(

  std::function<void(array&)> recurse;
  recurse = [&](auto& a) {
    auto id = a.id();
    a.set_tracer(false);

    // Check if visited and add to cache if not
    if (auto inserted = cache.insert(id); !inserted.second) {
    if (auto inserted = cache.insert(a.id()); !inserted.second) {
      return;
    }
    a.set_tracer(false);
    for (auto s : a.siblings()) {
      s.set_tracer(false);
      cache.insert(s.id());
    }

    for (auto& input : a.editable_inputs()) {
    for (auto input : a.inputs()) {
      recurse(input);
    }

    // Stop grad
    if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
      return;
    if (a.has_primitive()) {
      if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
        return;
      }
    }

    // Calculate gradient if any inputs require gradient
    for (auto& input : a.inputs()) {
      if (calc_grad.find(input.id()) != calc_grad.end()) {
        tape.push_back(a);
        calc_grad.insert(id);
        calc_grad.insert(a.id());
        for (auto& s : a.siblings()) {
          calc_grad.insert(s.id());
        }
        break;
      }
    }
  };

  for (auto& out : outputs) {
  for (auto out : outputs) {
    recurse(out);
  }

  std::unordered_map<std::uintptr_t, array> tan_map;
  for (int i = 0; i < primals_.size(); ++i) {
    tan_map.insert({primals_[i].id(), tangents[i]});
@ -494,8 +562,11 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
      }
    }

    auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums);
    tan_map.insert({a.id(), jvp});
    auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
    auto outputs = a.outputs();
    for (int i = 0; i < jvps.size(); ++i) {
      tan_map.insert({outputs[i].id(), jvps[i]});
    }
  }

  std::vector<array> jvps;
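
The jvp transform likewise now maps one tangent to each output of a primitive instead of a single tangent per array. A sketch under the assumption that DivMod provides a jvp rule:

    import mlx.core as mx

    primals = [mx.array(5.0), mx.array(2.0)]
    tangents = [mx.array(1.0), mx.array(0.0)]
    outs, jvps = mx.jvp(lambda a, b: mx.divmod(a, b), primals, tangents)
    # outs and jvps each hold two arrays, one per DivMod output
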
@ -578,8 +649,8 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& inputs,
    const std::vector<int>& in_axes) {
  // Set the global tracing flag
  InTracing in_tracing;
  // Set the global tracing flag.
  detail::InTracing in_tracing;

  if (in_axes.size() != inputs.size()) {
    throw std::invalid_argument(
@ -627,48 +698,56 @@ std::vector<array> vmap_replace(

  std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
  std::unordered_set<std::uintptr_t> needs_vmap;
  for (int i = 0; i < s_inputs.size(); ++i) {
    if (in_axes[i] != -1) {
      tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}});
      needs_vmap.insert(s_inputs[i].id());
    }
  }

  // Topologically sort the graph
  std::unordered_set<std::uintptr_t> cache;
  for (int i = 0; i < s_inputs.size(); ++i) {
    auto in = s_inputs[i];
    if (in_axes[i] != -1) {
      tmap.insert({in.id(), {inputs[i], in_axes[i]}});
      needs_vmap.insert(in.id());
      in.set_tracer(false);
    }
    cache.insert(in.id());
  }

  // Topologically sort the graph
  std::vector<array> tape;

  std::function<void(const array&)> recurse;

  recurse = [&](const array& a) {
    // Stop at inputs to the vmap function
    auto id = a.id();
    if (cache.find(id) != cache.end()) {
      return;
    }
    cache.insert(id);
    for (auto& s : a.siblings()) {
      cache.insert(s.id());
    }

    // Recurse on inputs
    for (auto& input : a.inputs()) {
      recurse(input);
    }
    cache.insert(id);
    // If any input needs a vmap, then the outputs also need
    // a vmap
    for (auto& input : a.inputs()) {
      if (needs_vmap.find(input.id()) != needs_vmap.end()) {
        needs_vmap.insert(id);
        tape.push_back(a);
        tape.back().set_tracer(false);
        needs_vmap.insert(a.id());
        for (auto s : a.siblings()) {
          needs_vmap.insert(s.id());
          s.set_tracer(false);
        }
        break;
      }
    }
  };

  for (auto& out : s_outputs) {
    recurse(out);
    if (out.has_primitive()) {
      recurse(out);
    }
  }

  // Transform each primitive in the graph with
@ -686,16 +765,19 @@ std::vector<array> vmap_replace(
        v_axes.push_back(-1);
      }
    }
    auto out_and_axis = a.primitive().vmap(v_inputs, v_axes);
    tmap.insert({a.id(), out_and_axis});
    auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
    // For each primitive's outputs add its id, the vout id and the vax
    auto outputs = a.outputs();
    for (int i = 0; i < v_outputs.size(); ++i) {
      tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
    }
  }

  // Populate the outputs and make sure all the output axes are
  // in the right place
  std::vector<array> outputs;
  for (int i = 0; i < s_outputs.size(); ++i) {
    auto map_it = tmap.find(s_outputs[i].id());
    if (map_it != tmap.end()) {
    if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
      auto& [out, vdim] = map_it->second;
      if (vdim != out_axes[i]) {
        if (out_axes[i] >= out.ndim()) {
@ -704,11 +786,7 @@ std::vector<array> vmap_replace(
              << out.ndim() << " dimensions.";
          throw std::invalid_argument(msg.str());
        }
        std::vector<int> reorder(out.ndim());
        std::iota(reorder.begin(), reorder.end(), 0);
        reorder.erase(reorder.begin() + vdim);
        reorder.insert(reorder.begin() + out_axes[i], vdim);
        out = transpose(out, reorder);
        out = moveaxis(out, vdim, out_axes[i]);
      }
      outputs.push_back(out);
    } else {
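
The moveaxis call above replaces the manual reorder-and-transpose dance with a single op. A quick equivalence check in Python:

    import mlx.core as mx

    x = mx.zeros((2, 3, 4))
    y1 = mx.transpose(x, [1, 2, 0])  # old: erase axis 0, reinsert at 2
    y2 = mx.moveaxis(x, 0, 2)        # new: one call
    assert y1.shape == y2.shape
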
@ -303,6 +303,33 @@ void init_ops(py::module_& m) {
        Returns:
            array: The quotient ``a / b``.
      )pbdoc");
  m.def(
      "divmod",
      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
        auto [a, b] = to_arrays(a_, b_);
        return divmod(a, b, s);
      },
      "a"_a,
      "b"_a,
      py::pos_only(),
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> Tuple[array, array]

        Element-wise quotient and remainder.

        The function ``divmod(a, b)`` is equivalent to but faster than
        ``(a // b, a % b)``. The function uses numpy-style broadcasting
        semantics. Either or both input arrays can also be scalars.

        Args:
            a (array): Input array or scalar.
            b (array): Input array or scalar.

        Returns:
            tuple(array, array): The quotient ``a // b`` and remainder ``a % b``.
      )pbdoc");
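
A short usage example of the binding defined above; the values follow the docstring's stated equivalence:

    import mlx.core as mx

    a = mx.array([5, 6, 7])
    b = mx.array([2, 2, 2])
    q, r = mx.divmod(a, b)       # one fused DivMod primitive
    # equivalent, but runs two separate primitives:
    q2, r2 = a // b, a % b
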
  m.def(
      "floor_divide",
      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
python/tests/test_graph.py
@ -0,0 +1,37 @@
# Copyright © 2023 Apple Inc.

import io
import unittest

import mlx.core as mx
import mlx_tests


class TestGraph(mlx_tests.MLXTestCase):
    def test_to_dot(self):
        # Simply test that a few cases run.
        # Nothing too specific about the graph format
        # for now to keep it flexible
        a = mx.array(1.0)
        f = io.StringIO()
        mx.export_to_dot(f, a)
        f.seek(0)
        self.assertTrue(len(f.read()) > 0)

        b = mx.array(2.0)
        c = a + b
        f = io.StringIO()
        mx.export_to_dot(f, c)
        f.seek(0)
        self.assertTrue(len(f.read()) > 0)

        # Multi output case
        c = mx.divmod(a, b)
        f = io.StringIO()
        mx.export_to_dot(f, *c)
        f.seek(0)
        self.assertTrue(len(f.read()) > 0)


if __name__ == "__main__":
    unittest.main()
@ -1314,7 +1314,7 @@ class TestOps(mlx_tests.MLXTestCase):
        for axis in (None, 0, 1, 2):
            c_npy = npop(a_npy, axis=axis)
            c_mlx = mxop(a_mlx, axis=axis)
            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-4, atol=1e-4))
            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))

        for op in ["cumsum", "cumprod", "cummax", "cummin"]:
            c1 = mxop(a_mlx, axis=2)
@ -1597,6 +1597,28 @@ class TestOps(mlx_tests.MLXTestCase):
            np.outer,
        )

    def test_divmod(self):
        # A few sizes for the inputs with and without broadcasting
        sizes = [
            ((1,), (1,)),
            ((1,), (10,)),
            ((10,), (1,)),
            ((3,), (3,)),
            ((2, 2, 2), (1, 2, 1)),
            ((2, 1, 2), (1, 2, 1)),
            ((2, 2, 2, 2), (2, 2, 2, 2)),
        ]
        types = [np.uint16, np.uint32, np.int32, np.float16, np.float32]
        for s1, s2 in sizes:
            for t in types:
                a_np = np.random.uniform(1, 100, size=s1).astype(t)
                b_np = np.random.uniform(1, 100, size=s2).astype(t)
                np_out = np.divmod(a_np, b_np)
                mx_out = mx.divmod(mx.array(a_np), mx.array(b_np))
                self.assertTrue(
                    np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
                )


if __name__ == "__main__":
    unittest.main()
@ -7,19 +7,29 @@
using namespace mlx::core;

TEST_CASE("test simplify scalars") {
  auto a = array({-1.0f, 2.0f});
  auto b = maximum(a, array(0.0f));
  auto c = maximum(-a, array(0.0f));
  auto d = b + c;
  simplify({d});
  CHECK(b.inputs()[1].id() == c.inputs()[1].id());
  {
    auto a = array(-1.0f);
    auto b = array(-1.0f);
    auto c = abs(a);
    auto d = abs(b);
    simplify({c, d});
    CHECK(c.inputs()[0].id() == d.inputs()[0].id());
  }

  {
    auto a = array({-1.0f, 2.0f});
    auto b = maximum(a, array(0.0f));
    auto c = maximum(-a, array(0.0f));
    auto d = b + c;
    simplify({d});
    CHECK(b.inputs()[1].id() == c.inputs()[1].id());
  }
}

TEST_CASE("test simplify") {
  auto a = array({1.0f, 2.0f});
  auto b = exp(a) + exp(a);
  simplify(b);
  eval(b);
  CHECK(b.inputs()[0].id() == b.inputs()[1].id());
}

@ -27,6 +37,44 @@ TEST_CASE("test no simplify") {
  auto a = array({1.0f, 2.0f});
  auto b = cos(a) + sin(a);
  simplify(b);
  eval(b);
  CHECK(b.inputs()[0].id() != b.inputs()[1].id());
}

TEST_CASE("test simplify multi output") {
  {
    auto a = array(1.0);
    auto b = array(2.0);
    auto c = divmod(a, b);
    auto d = divmod(a, b);
    auto e = c[0] + d[0];
    auto f = c[1] + d[1];

    simplify({e, f});
    CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
    CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
  }

  {
    auto a = array(1.0);
    auto b = array(1.0);
    auto c = divmod(a, b);
    simplify(c);
    CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
    CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
    CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
  }

  // Make sure the output order of multi-output primitives
  // is respected in simplification
  {
    auto a = array(1.0);
    auto b = array(2.0);
    auto c = divmod(a, b);
    auto d = divmod(a, b);
    auto e = stack({c[0], c[1], d[0], d[1]});
    simplify(e);
    CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
    CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
    CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
  }
}
@ -30,7 +30,7 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
  CHECK_EQ(
      norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
  CHECK(array_equal(
  CHECK(allclose(
      norm(x, 0, false),
      array(
          {std::sqrt(0 + 3 * 3 + 6 * 6),
@ -2397,3 +2397,58 @@ TEST_CASE("inner") {
  expected = array({7., 0., 0., 7.}, {2, 2});
  CHECK(array_equal(z, expected).item<bool>());
}

TEST_CASE("test divmod") {
  auto x = array({1, 2, 3});
  auto y = array({1, 1, 1});
  auto out = divmod(x, y);
  CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());
  CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());

  x = array({5, 6, 7});
  y = array({2, 2, 2});
  out = divmod(x, y);
  CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());
  CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());

  // Siblings should be gone after evaling the graph
  CHECK(out[0].siblings().empty());
  CHECK(out[1].siblings().empty());

  x = array({5.0, 6.0, 7.0});
  y = array({2.0, 2.0, 2.0});
  out = divmod(x, y);
  CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());
  CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());

  x = array({1.0}, complex64);
  y = array({2.0}, complex64);
  CHECK_THROWS(divmod(x, y));

  // Check that we can eval on both outputs
  x = array({1.0});
  y = array({2.0});
  out = divmod(x, y);
  eval(out);
  CHECK_EQ(out[0].item<float>(), 0.0);
  CHECK_EQ(out[1].item<float>(), 1.0);

  // Check nested in the graph
  x = array({1.0});
  y = array({2.0});
  out = divmod(x, y);
  auto z = out[0] + out[1];
  CHECK_EQ(z.item<float>(), 1.0);

  // Check that we can still eval when one output goes out of scope
  std::vector<array> out_holder;
  { out_holder.push_back(divmod(x, y)[0]); }
  eval(out_holder);
  CHECK_EQ(out_holder[0].item<float>(), 0.0);

  // Check that we can still eval when the other output goes out of scope
  out_holder.clear();
  { out_holder.push_back(divmod(x, y)[1]); }
  eval(out_holder);
  CHECK_EQ(out_holder[0].item<float>(), 1.0);
}