reduce binary size (#1952)

This commit is contained in:
Awni Hannun
2025-03-11 06:30:44 -07:00
committed by GitHub
parent 117e1355a2
commit 736a340478
16 changed files with 2145 additions and 2386 deletions

View File

@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
return outputs; return outputs;
} }
array array::unsafe_weak_copy(const array& other) {
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
cpy.set_data(
other.buffer(),
other.data_size(),
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
array::array(std::initializer_list<float> data) array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>( : array_desc_(std::make_shared<ArrayDesc>(
Shape{static_cast<ShapeElem>(data.size())}, Shape{static_cast<ShapeElem>(data.size())},

View File

@@ -199,6 +199,13 @@ class array {
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs); const std::vector<array>& inputs);
/**
* Get a new array that refers to the same data as the input but with a
* non-owning pointer to it. Note the array is detached from the graph and has
* no inputs, siblings or primitive.
*/
static array unsafe_weak_copy(const array& other);
/** A unique identifier for an array. */ /** A unique identifier for an array. */
std::uintptr_t id() const { std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get()); return reinterpret_cast<std::uintptr_t>(array_desc_.get());

View File

@@ -11,12 +11,7 @@ namespace mlx::core {
namespace { namespace {
template <typename InT, typename OpT> template <typename InT, typename OpT>
void arg_reduce( void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
const array& in,
array& out,
const OpT& op,
int axis,
Stream stream) {
auto axis_size = in.shape()[axis]; auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis]; auto axis_stride = in.strides()[axis];
Strides strides = in.strides(); Strides strides = in.strides();
@@ -26,18 +21,7 @@ void arg_reduce(
auto in_ptr = in.data<InT>(); auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>(); auto out_ptr = out.data<uint32_t>();
auto& encoder = cpu::get_command_encoder(stream); for (uint32_t i = 0; i < out.size(); ++i) {
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in_ptr,
out_ptr,
axis_size,
axis_stride,
op = std::move(op),
shape = std::move(shape),
strides = std::move(strides),
size = out.size()]() {
for (uint32_t i = 0; i < size; ++i) {
auto loc = elem_to_loc(i, shape, strides); auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc; auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0; uint32_t ind_v = 0;
@@ -47,7 +31,6 @@ void arg_reduce(
} }
out_ptr[i] = ind_v; out_ptr[i] = ind_v;
} }
});
} }
template <typename InT> template <typename InT>
@@ -55,8 +38,7 @@ void arg_reduce_dispatch(
const array& in, const array& in,
array& out, array& out,
ArgReduce::ReduceType rtype, ArgReduce::ReduceType rtype,
int axis, int axis) {
Stream stream) {
switch (rtype) { switch (rtype) {
case ArgReduce::ArgMin: { case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) { auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
@@ -65,7 +47,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x; (*ind_y) = ind_x;
} }
}; };
arg_reduce<InT>(in, out, op, axis, stream); arg_reduce<InT>(in, out, op, axis);
break; break;
} }
case ArgReduce::ArgMax: { case ArgReduce::ArgMax: {
@@ -75,7 +57,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x; (*ind_y) = ind_x;
} }
}; };
arg_reduce<InT>(in, out, op, axis, stream); arg_reduce<InT>(in, out, op, axis);
break; break;
} }
} }
@@ -87,51 +69,58 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axis_ = axis_]() mutable {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break; break;
case uint8: case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break; break;
case uint16: case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break; break;
case uint32: case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break; break;
case uint64: case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break; break;
case int8: case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break; break;
case int16: case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break; break;
case int32: case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break; break;
case int64: case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break; break;
case float16: case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break; break;
case float32: case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break; break;
case bfloat16: case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break; break;
case float64: case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break; break;
case complex64: case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream()); arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break; break;
} }
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/cpu/binary.h" #include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h" #include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h" #include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -16,51 +17,218 @@ namespace mlx::core {
namespace { namespace {
template <typename Op> template <typename Op>
void comparison_op(const array& a, const array& b, array& out) { void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
switch (a.dtype()) { auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_: case bool_:
binary_op<bool, bool, Op>(a, b, out); binary_op<bool, Op>(a, b, out, bopt);
break; break;
case uint8: case uint8:
binary_op<uint8_t, bool, Op>(a, b, out); binary_op<uint8_t, Op>(a, b, out, bopt);
break; break;
case uint16: case uint16:
binary_op<uint16_t, bool, Op>(a, b, out); binary_op<uint16_t, Op>(a, b, out, bopt);
break; break;
case uint32: case uint32:
binary_op<uint32_t, bool, Op>(a, b, out); binary_op<uint32_t, Op>(a, b, out, bopt);
break; break;
case uint64: case uint64:
binary_op<uint64_t, bool, Op>(a, b, out); binary_op<uint64_t, Op>(a, b, out, bopt);
break; break;
case int8: case int8:
binary_op<int8_t, bool, Op>(a, b, out); binary_op<int8_t, Op>(a, b, out, bopt);
break; break;
case int16: case int16:
binary_op<int16_t, bool, Op>(a, b, out); binary_op<int16_t, Op>(a, b, out, bopt);
break; break;
case int32: case int32:
binary_op<int32_t, bool, Op>(a, b, out); binary_op<int32_t, Op>(a, b, out, bopt);
break; break;
case int64: case int64:
binary_op<int64_t, bool, Op>(a, b, out); binary_op<int64_t, Op>(a, b, out, bopt);
break; break;
case float16: case float16:
binary_op<float16_t, bool, Op>(a, b, out); binary_op<float16_t, Op>(a, b, out, bopt);
break; break;
case float32: case float32:
binary_op<float, bool, Op>(a, b, out); binary_op<float, Op>(a, b, out, bopt);
break; break;
case float64: case float64:
binary_op<double, bool, Op>(a, b, out); binary_op<double, Op>(a, b, out, bopt);
break; break;
case bfloat16: case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out); binary_op<bfloat16_t, Op>(a, b, out, bopt);
break; break;
case complex64: case complex64:
binary_op<complex64_t, bool, Op>(a, b, out); binary_op<complex64_t, Op>(a, b, out, bopt);
break; break;
} }
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
} }
} // namespace } // namespace
@@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Add()); binary(a, b, out, detail::Add(), stream());
} }
void DivMod::eval_cpu( void DivMod::eval_cpu(
@@ -78,70 +246,89 @@ void DivMod::eval_cpu(
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out_a = array::unsafe_weak_copy(out_a),
out_b = array::unsafe_weak_copy(out_b),
bopt]() mutable {
auto integral_op = [](auto x, auto y) { auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y); return std::make_pair(x / y, x % y);
}; };
auto float_op = [](auto x, auto y) { auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y)); return std::make_pair(std::trunc(x / y), std::fmod(x, y));
}; };
switch (outputs[0].dtype()) {
switch (out_a.dtype()) {
case bool_: case bool_:
binary_op<bool>(a, b, outputs, integral_op); binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
case uint8: case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op); binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case uint16: case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op); binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case uint32: case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op); binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case uint64: case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op); binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case int8: case int8:
binary_op<int8_t>(a, b, outputs, integral_op); binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case int16: case int16:
binary_op<int16_t>(a, b, outputs, integral_op); binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case int32: case int32:
binary_op<int32_t>(a, b, outputs, integral_op); binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case int64: case int64:
binary_op<int64_t>(a, b, outputs, integral_op); binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
break; break;
case float16: case float16:
binary_op<float16_t>(a, b, outputs, float_op); binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
break; break;
case float32: case float32:
binary_op<float>(a, b, outputs, float_op); binary_op<float>(a, b, out_a, out_b, float_op, bopt);
break; break;
case float64: case float64:
binary_op<double>(a, b, outputs, float_op); binary_op<double>(a, b, out_a, out_b, float_op, bopt);
break; break;
case bfloat16: case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op); binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
break; break;
case complex64: case complex64:
// Should never get here // Should never get here
throw std::runtime_error("[DivMod] Complex type not supported"); throw std::runtime_error("[DivMod] Complex type not supported");
break; break;
} }
});
} }
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) { void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Divide()); binary(a, b, out, detail::Divide(), stream());
} }
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) { void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Remainder()); binary(a, b, out, detail::Remainder(), stream());
} }
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) { void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
if (equal_nan_) { if (equal_nan_) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) { switch (a.dtype()) {
case float16: case float16:
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out); binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break; break;
case float32: case float32:
binary_op<float, bool, detail::NaNEqual>(a, b, out); binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
break; break;
case float64: case float64:
binary_op<double, bool, detail::NaNEqual>(a, b, out); binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
break; break;
case bfloat16: case bfloat16:
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out); binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break; break;
case complex64: case complex64:
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out); binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types."); "[NanEqual::eval_cpu] Only for floating point types.");
} }
});
} else { } else {
comparison_op<detail::Equal>(a, b, out); comparison_op(a, b, out, detail::Equal(), stream());
} }
} }
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) { void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op<detail::Greater>(inputs[0], inputs[1], out); comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
} }
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out); comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
} }
void Less::eval_cpu(const std::vector<array>& inputs, array& out) { void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op<detail::Less>(inputs[0], inputs[1], out); comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
} }
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out); comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
} }
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) { void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
switch (out.dtype()) { binary_float(a, b, out, detail::LogAddExp(), stream());
case float16:
binary_op<float16_t, detail::LogAddExp>(a, b, out);
break;
case float32:
binary_op<float, detail::LogAddExp>(a, b, out);
break;
case float64:
binary_op<double, detail::LogAddExp>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
} }
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd()); binary(in1, in2, out, detail::LogicalAnd(), stream());
} }
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr()); binary(in1, in2, out, detail::LogicalOr(), stream());
} }
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) { void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Maximum()); binary(a, b, out, detail::Maximum(), stream());
} }
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) { void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Minimum()); binary(a, b, out, detail::Minimum(), stream());
} }
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) { void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Multiply()); binary(a, b, out, detail::Multiply(), stream());
} }
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out); comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
} }
void Power::eval_cpu(const std::vector<array>& inputs, array& out) { void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Power()); binary(a, b, out, detail::Power(), stream());
} }
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) { void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Subtract()); binary(a, b, out, detail::Subtract(), stream());
} }
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd()); binary_int(a, b, out, detail::BitwiseAnd(), stream());
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr()); binary_int(a, b, out, detail::BitwiseOr(), stream());
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor()); binary_int(a, b, out, detail::BitwiseXor(), stream());
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift()); binary_int(a, b, out, detail::LeftShift(), stream());
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift()); binary_int(a, b, out, detail::RightShift(), stream());
break; break;
} }
} }
@@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
const auto& a = inputs[0]; const auto& a = inputs[0];
const auto& b = inputs[1]; const auto& b = inputs[1];
switch (out.dtype()) { binary_float(a, b, out, detail::ArcTan2(), stream());
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -3,12 +3,9 @@
#pragma once #pragma once
#include <cassert> #include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
@@ -152,30 +149,12 @@ void binary_op_dispatch_dims(
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out) { void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once // The full computation is scalar scalar so call the base op once
auto a_ptr = a.data<T>(); auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>(); auto b_ptr = b.data<T>();
auto out_ptr = out.data<U>(); auto out_ptr = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([bopt,
a_ptr,
b_ptr,
out_ptr,
a_data_size = a.data_size(),
b_data_size = b.data_size(),
size = a.size(),
shape = a.shape(),
a_strides = a.strides(),
b_strides = b.strides(),
strides = out.strides()]() mutable {
if (bopt == BinaryOpType::ScalarScalar) { if (bopt == BinaryOpType::ScalarScalar) {
*out_ptr = Op{}(*a_ptr, *b_ptr); *out_ptr = Op{}(*a_ptr, *b_ptr);
return; return;
@@ -183,29 +162,28 @@ void binary_op(const array& a, const array& b, array& out) {
// The full computation is scalar vector so delegate to the op // The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size); ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
return; return;
} }
// The full computation is vector scalar so delegate to the op // The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) { if (bopt == BinaryOpType::VectorScalar) {
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size); VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
return; return;
} }
// The full computation is vector vector so delegate to the op // The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) { if (bopt == BinaryOpType::VectorVector) {
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size); VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
return; return;
} }
// General computation so let's try to optimize // General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims( auto [new_shape, new_strides] = collapse_contiguous_dims(
shape, a.shape(), {a.strides(), b.strides(), out.strides()});
{std::move(a_strides), std::move(b_strides), std::move(strides)}); auto& a_strides = new_strides[0];
a_strides = new_strides[0]; auto& b_strides = new_strides[1];
b_strides = new_strides[1]; auto& strides = new_strides[2];
strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after // Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) { auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
@@ -262,7 +240,7 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr, b_ptr,
out_ptr, out_ptr,
dim, dim,
size, a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
@@ -274,7 +252,7 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr, b_ptr,
out_ptr, out_ptr,
dim, dim,
size, a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
@@ -286,7 +264,7 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr, b_ptr,
out_ptr, out_ptr,
dim, dim,
size, a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
@@ -298,72 +276,18 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr, b_ptr,
out_ptr, out_ptr,
dim, dim,
size, a.size(),
new_shape, new_shape,
a_strides, a_strides,
b_strides, b_strides,
strides); strides);
break; break;
} }
});
} }
template <typename T, typename Op> template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out) { void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out); binary_op<T, T, Op>(a, b, out, bopt);
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T, Op>(a, b, out);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out);
break;
case int8:
binary_op<int8_t, Op>(a, b, out);
break;
case int16:
binary_op<int16_t, Op>(a, b, out);
break;
case int32:
binary_op<int32_t, Op>(a, b, out);
break;
case int64:
binary_op<int64_t, Op>(a, b, out);
break;
case float16:
binary_op<float16_t, Op>(a, b, out);
break;
case float32:
binary_op<float, Op>(a, b, out);
break;
case float64:
binary_op<double, Op>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out);
break;
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -4,8 +4,6 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h" #include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@@ -57,14 +55,7 @@ void binary_op_dispatch_dims(
const array& b, const array& b,
array& out_a, array& out_a,
array& out_b, array& out_b,
Stream stream,
Op op) { Op op) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto [shape, strides] = collapse_contiguous_dims( auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()}); a.shape(), {a.strides(), b.strides(), out_a.strides()});
const T* a_ptr = a.data<T>(); const T* a_ptr = a.data<T>();
@@ -72,14 +63,6 @@ void binary_op_dispatch_dims(
U* out_a_ptr = out_a.data<U>(); U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>(); U* out_b_ptr = out_b.data<U>();
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
shape = std::move(shape),
strides = std::move(strides),
op = std::move(op)]() {
const auto& a_strides = strides[0]; const auto& a_strides = strides[0];
const auto& b_strides = strides[1]; const auto& b_strides = strides[1];
const auto& out_strides = strides[2]; const auto& out_strides = strides[2];
@@ -116,7 +99,7 @@ void binary_op_dispatch_dims(
ContiguousIterator a_it(shape, a_strides, ndim - 2); ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3]; auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < size; elem += stride) { for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>( binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc, a_ptr + a_it.loc,
b_ptr + b_it.loc, b_ptr + b_it.loc,
@@ -131,138 +114,50 @@ void binary_op_dispatch_dims(
a_it.step(); a_it.step();
b_it.step(); b_it.step();
} }
});
} }
template <typename T, typename U = T, typename Op> template <typename T, typename U = T, typename Op>
void binary_op( void binary_op(
const array& a, const array& a,
const array& b, const array& b,
std::vector<array>& outputs, array& out_a,
Op op) { array& out_b,
auto bopt = get_binary_op_type(a, b); Op op,
auto& out_a = outputs[0]; BinaryOpType bopt) {
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto stream = out_a.primitive().stream();
// The full computation is scalar scalar so call the base op once // The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op); binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
return; return;
} }
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto a_ptr = a.data<T>(); auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>(); auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>(); auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>(); auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) { if (bopt == BinaryOpType::ScalarScalar) {
encoder.dispatch(
[a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
});
} else if (bopt == BinaryOpType::ScalarVector) { } else if (bopt == BinaryOpType::ScalarVector) {
encoder.dispatch([a_ptr, for (size_t i = 0; i < b.data_size(); ++i) {
b_ptr,
out_a_ptr,
out_b_ptr,
size = b.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++; out_a_ptr++;
out_b_ptr++; out_b_ptr++;
b_ptr++; b_ptr++;
} }
});
} else if (bopt == BinaryOpType::VectorScalar) { } else if (bopt == BinaryOpType::VectorScalar) {
encoder.dispatch([a_ptr, for (size_t i = 0; i < a.data_size(); ++i) {
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++; out_a_ptr++;
out_b_ptr++; out_b_ptr++;
a_ptr++; a_ptr++;
} }
});
} else { // VectorVector } else { // VectorVector
encoder.dispatch([a_ptr, for (size_t i = 0; i < a.size(); ++i) {
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++; out_a_ptr++;
out_b_ptr++; out_b_ptr++;
a_ptr++; a_ptr++;
b_ptr++; b_ptr++;
} }
});
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
} }
} }

View File

@@ -13,29 +13,20 @@ namespace mlx::core {
namespace { namespace {
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst, Stream stream) { void copy_single(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
auto& encoder = cpu::get_command_encoder(stream); auto size = dst.size();
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
auto val = static_cast<DstT>(src_ptr[0]); auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val); std::fill_n(dst_ptr, size, val);
});
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst, Stream stream) { void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size(); auto size = src.data_size();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
std::copy(src_ptr, src_ptr + size, dst_ptr); std::copy(src_ptr, src_ptr + size, dst_ptr);
});
} }
template <typename SrcT, typename DstT, int D> template <typename SrcT, typename DstT, int D>
@@ -66,7 +57,6 @@ template <typename SrcT, typename DstT>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
Stream stream,
const Shape& data_shape, const Shape& data_shape,
const Strides& i_strides, const Strides& i_strides,
const Strides& o_strides, const Strides& o_strides,
@@ -80,18 +70,7 @@ void copy_general_general(
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr; dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr = auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr; dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr,
dst_ptr,
size = src.size(),
data_shape = data_shape,
i_strides = i_strides,
o_strides = o_strides,
i_offset_ptr,
o_offset_ptr]() mutable {
if (data_shape.empty()) { if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr); auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val; *dst_ptr = val;
@@ -143,15 +122,13 @@ void copy_general_general(
in.step(); in.step();
out.step(); out.step();
} }
});
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst, Stream stream) { inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
stream,
src.shape(), src.shape(),
src.strides(), src.strides(),
dst.strides(), dst.strides(),
@@ -165,7 +142,6 @@ template <typename SrcT, typename DstT>
void copy_general( void copy_general(
const array& src, const array& src,
array& dst, array& dst,
Stream stream,
const Shape& data_shape, const Shape& data_shape,
const Strides& i_strides, const Strides& i_strides,
const Strides&, const Strides&,
@@ -176,7 +152,6 @@ void copy_general(
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
stream,
data_shape, data_shape,
i_strides, i_strides,
make_contiguous_strides(data_shape), make_contiguous_strides(data_shape),
@@ -187,11 +162,10 @@ void copy_general(
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst, Stream stream) { inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
stream,
src.shape(), src.shape(),
src.strides(), src.strides(),
make_contiguous_strides(src.shape()), make_contiguous_strides(src.shape()),
@@ -202,84 +176,67 @@ inline void copy_general(const array& src, array& dst, Stream stream) {
} }
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
void copy( void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (ctype) { switch (ctype) {
case CopyType::Scalar: case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst, stream); copy_single<SrcT, DstT>(src, dst);
return; return;
case CopyType::Vector: case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst, stream); copy_vector<SrcT, DstT>(src, dst);
return; return;
case CopyType::General: case CopyType::General:
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...); copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
src, dst, stream, std::forward<Args>(args)...);
return; return;
} }
} }
template <typename SrcT, typename... Args> template <typename SrcT, typename... Args>
void copy( void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (dst.dtype()) { switch (dst.dtype()) {
case bool_: case bool_:
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<SrcT, uint16_t>( copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<SrcT, uint32_t>( copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<SrcT, uint64_t>( copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<SrcT, float16_t>( copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float64: case float64:
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<SrcT, bfloat16_t>( copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<SrcT, complex64_t>( copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
} }
} }
@@ -289,50 +246,49 @@ inline void copy_inplace_dispatch(
const array& src, const array& src,
array& dst, array& dst,
CopyType ctype, CopyType ctype,
Stream stream,
Args&&... args) { Args&&... args) {
switch (src.dtype()) { switch (src.dtype()) {
case bool_: case bool_:
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float64: case float64:
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<double>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
} }
} }
@@ -340,7 +296,13 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) { void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
copy_inplace_dispatch(src, dst, ctype, stream); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
} }
void copy(const array& src, array& dst, CopyType ctype, Stream stream) { void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
@@ -368,6 +330,27 @@ void copy_inplace(
Stream stream, Stream stream,
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */ const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) { const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
if (x) {
return array::unsafe_weak_copy(*x);
} else {
return std::nullopt;
}
};
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
ctype,
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
switch (ctype) { switch (ctype) {
case CopyType::General: case CopyType::General:
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
@@ -375,7 +358,6 @@ void copy_inplace(
src, src,
dst, dst,
ctype, ctype,
stream,
data_shape, data_shape,
i_strides, i_strides,
o_strides, o_strides,
@@ -386,8 +368,9 @@ void copy_inplace(
break; break;
case CopyType::Scalar: case CopyType::Scalar:
case CopyType::Vector: case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype, stream); copy_inplace_dispatch(src, dst, ctype);
} }
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -22,14 +22,47 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx; return idx;
} }
struct None {
template <typename T>
void operator()(T x, T* y) {
(*y) = x;
}
};
struct Sum {
template <typename T>
void operator()(T x, T* y) {
(*y) += x;
}
};
struct Prod {
template <typename T>
void operator()(T x, T* y) {
(*y) *= x;
}
};
struct Max {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y > x) ? *y : x;
}
};
struct Min {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y < x) ? *y : x;
}
};
template <typename T, typename IdxT> template <typename T, typename IdxT>
void gather( void gather(
const array& src, const array& src,
const std::vector<array>& inds, const std::vector<array>& inds,
array& out, array& out,
const std::vector<int>& axes, const std::vector<int>& axes,
const Shape& slice_sizes, const Shape& slice_sizes) {
Stream stream) {
// If the array is row contiguous then we can do a contiguous copy given // If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size: // two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed // - Any number of leading ones in the slice sizes are allowed
@@ -82,43 +115,23 @@ void gather(
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim()); src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
} }
std::vector<const IdxT*> ind_ptrs;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
for (auto& idx : inds) {
ind_ptrs.push_back(idx.data<IdxT>());
encoder.set_input_array(idx);
}
encoder.set_output_array(out);
encoder.dispatch([src_ptr,
dst_ptr,
ind_ptrs = std::move(ind_ptrs),
axes,
ind_size,
slice_size,
src_shape = src.shape(),
src_strides = src.strides(),
src_it = std::move(src_it),
its = std::move(its),
can_copy]() mutable {
size_t out_idx = 0; size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
for (int ii = 0; ii < ind_ptrs.size(); ++ii) { for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii]; auto ax = axes[ii];
auto idx_loc = its[ii].loc; auto idx_loc = its[ii].loc;
its[ii].step(); its[ii].step();
auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]); auto idx_val =
src_idx += (idx_val * src_strides[ax]); offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
} }
if (slice_size == 1) { if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx]; dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) { } else if (can_copy) {
std::copy( std::copy(
src_ptr + src_idx, src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
src_ptr + src_idx + slice_size,
dst_ptr + out_idx);
out_idx += slice_size; out_idx += slice_size;
} else { } else {
for (int jj = 0; jj < slice_size; jj++) { for (int jj = 0; jj < slice_size; jj++) {
@@ -128,7 +141,6 @@ void gather(
src_it.reset(); src_it.reset();
} }
} }
});
} }
template <typename IdxT> template <typename IdxT>
@@ -137,50 +149,49 @@ void dispatch_gather(
const std::vector<array>& inds, const std::vector<array>& inds,
array& out, array& out,
const std::vector<int>& axes, const std::vector<int>& axes,
const Shape& size, const Shape& size) {
Stream stream) {
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
gather<bool, IdxT>(src, inds, out, axes, size, stream); gather<bool, IdxT>(src, inds, out, axes, size);
break; break;
case uint8: case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size, stream); gather<uint8_t, IdxT>(src, inds, out, axes, size);
break; break;
case uint16: case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size, stream); gather<uint16_t, IdxT>(src, inds, out, axes, size);
break; break;
case uint32: case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size, stream); gather<uint32_t, IdxT>(src, inds, out, axes, size);
break; break;
case uint64: case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size, stream); gather<uint64_t, IdxT>(src, inds, out, axes, size);
break; break;
case int8: case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size, stream); gather<int8_t, IdxT>(src, inds, out, axes, size);
break; break;
case int16: case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size, stream); gather<int16_t, IdxT>(src, inds, out, axes, size);
break; break;
case int32: case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size, stream); gather<int32_t, IdxT>(src, inds, out, axes, size);
break; break;
case int64: case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size, stream); gather<int64_t, IdxT>(src, inds, out, axes, size);
break; break;
case float16: case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size, stream); gather<float16_t, IdxT>(src, inds, out, axes, size);
break; break;
case float32: case float32:
gather<float, IdxT>(src, inds, out, axes, size, stream); gather<float, IdxT>(src, inds, out, axes, size);
break; break;
case float64: case float64:
gather<double, IdxT>(src, inds, out, axes, size, stream); gather<double, IdxT>(src, inds, out, axes, size);
break; break;
case bfloat16: case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size, stream); gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break; break;
case complex64: case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size, stream); gather<complex64_t, IdxT>(src, inds, out, axes, size);
break; break;
} }
} }
@@ -189,51 +200,63 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end()); std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
inds.push_back(array::unsafe_weak_copy(*it));
}
auto& encoder = cpu::get_command_encoder(stream());
for (auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
slice_sizes_ = slice_sizes_,
src = array::unsafe_weak_copy(src),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
if (inds.empty()) { if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return; return;
} }
switch (inds[0].dtype()) { switch (inds[0].dtype()) {
case uint8: case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case uint16: case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case uint32: case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case uint64: case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int8: case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int16: case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int32: case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break; break;
case int64: case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_, stream()); dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type."); "[Gather::eval_cpu] Cannot gather with indices type.");
break; break;
} }
});
} }
template <typename T, typename IdxT> template <typename T, typename IdxT>
void gather_axis( void gather_axis(
const array& src, const array& src,
const array& ind, const array& ind,
array& out, array& out,
const int axis, const int axis) {
Stream stream) {
auto strides = ind.strides(); auto strides = ind.strides();
strides.erase(strides.begin() + axis); strides.erase(strides.begin() + axis);
auto shape = ind.shape(); auto shape = ind.shape();
@@ -262,23 +285,6 @@ void gather_axis(
size_post *= ind.shape(i); size_post *= ind.shape(i);
} }
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_input_array(ind);
encoder.set_output_array(out);
encoder.dispatch([ind_ptr,
src_ptr,
dst_ptr,
size_pre,
size_post,
ind_ax_size,
src_ax_size,
ind_ax_stride,
src_ax_stride,
dst_ax_stride,
ind_it = std::move(ind_it),
src_it = std::move(src_it)]() mutable {
size_t stride_pre = size_post * ind_ax_size; size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) { for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) { for (size_t k = 0; k < size_post; k++) {
@@ -293,7 +299,6 @@ void gather_axis(
} }
dst_ptr += stride_pre; dst_ptr += stride_pre;
} }
});
} }
template <typename IdxT> template <typename IdxT>
@@ -301,88 +306,97 @@ void dispatch_gather_axis(
const array& src, const array& src,
const array& inds, const array& inds,
array& out, array& out,
const int axis, const int axis) {
Stream stream) {
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis, stream); gather_axis<bool, IdxT>(src, inds, out, axis);
break; break;
case uint8: case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis, stream); gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break; break;
case uint16: case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis, stream); gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break; break;
case uint32: case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis, stream); gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break; break;
case uint64: case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis, stream); gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break; break;
case int8: case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis, stream); gather_axis<int8_t, IdxT>(src, inds, out, axis);
break; break;
case int16: case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis, stream); gather_axis<int16_t, IdxT>(src, inds, out, axis);
break; break;
case int32: case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis, stream); gather_axis<int32_t, IdxT>(src, inds, out, axis);
break; break;
case int64: case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis, stream); gather_axis<int64_t, IdxT>(src, inds, out, axis);
break; break;
case float16: case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis, stream); gather_axis<float16_t, IdxT>(src, inds, out, axis);
break; break;
case float32: case float32:
gather_axis<float, IdxT>(src, inds, out, axis, stream); gather_axis<float, IdxT>(src, inds, out, axis);
break; break;
case float64: case float64:
gather_axis<double, IdxT>(src, inds, out, axis, stream); gather_axis<double, IdxT>(src, inds, out, axis);
break; break;
case bfloat16: case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis, stream); gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break; break;
case complex64: case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis, stream); gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break; break;
} }
} }
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
auto& inds = inputs[1]; auto& inds = inputs[1];
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(src);
encoder.set_input_array(inds);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
src = array::unsafe_weak_copy(src),
inds = array::unsafe_weak_copy(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (inds.dtype()) { switch (inds.dtype()) {
case uint8: case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break; break;
case uint16: case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break; break;
case uint32: case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break; break;
case uint64: case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break; break;
case int8: case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break; break;
case int16: case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break; break;
case int32: case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break; break;
case int64: case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_, stream()); dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type."); "[GatherAxis::eval_cpu] Cannot gather with indices type.");
break; break;
} }
});
} }
template <typename InT, typename IdxT, typename OpT> template <typename InT, typename IdxT, typename OpT>
@@ -390,9 +404,7 @@ void scatter(
const array& updates, const array& updates,
array& out, array& out,
const std::vector<array>& inds, const std::vector<array>& inds,
const std::vector<int>& axes, const std::vector<int>& axes) {
const OpT& op,
Stream stream) {
int nind = inds.size(); int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim(); auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1; size_t n_updates = nind ? inds[0].size() : 1;
@@ -408,45 +420,27 @@ void scatter(
ContiguousIterator update_it(updates); ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
std::vector<const IdxT*> ind_ptrs; auto out_ptr = out.data<InT>();
auto& encoder = cpu::get_command_encoder(stream); auto upd_ptr = updates.data<InT>();
encoder.set_input_array(updates);
for (auto& idx : inds) {
ind_ptrs.push_back(idx.data<IdxT>());
encoder.set_input_array(idx);
}
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<InT>(),
upd_ptr = updates.data<InT>(),
ind_ptrs = std::move(ind_ptrs),
axes,
n_updates,
update_size,
op = std::move(op),
out_shape = out.shape(),
out_strides = out.strides(),
out_it = std::move(out_it),
update_it = std::move(update_it),
its = std::move(its)]() mutable {
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;
for (int j = 0; j < ind_ptrs.size(); ++j) { for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j]; auto ax = axes[j];
auto idx_loc = its[j].loc; auto idx_loc = its[j].loc;
its[j].step(); its[j].step();
auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]); auto idx_val =
out_offset += (idx_val * out_strides[ax]); offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
} }
update_it.seek(i * update_size); update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) { for (int j = 0; j < update_size; ++j) {
op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc); OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step(); update_it.step();
out_it.step(); out_it.step();
} }
out_it.reset(); out_it.reset();
update_it.reset(); update_it.reset();
} }
});
} }
template <typename InT, typename IdxT> template <typename InT, typename IdxT>
@@ -455,53 +449,22 @@ void dispatch_scatter_inds(
const std::vector<array>& indices, const std::vector<array>& indices,
const array& updates, const array& updates,
const std::vector<int>& axes, const std::vector<int>& axes,
Scatter::ReduceType rtype, Scatter::ReduceType rtype) {
Stream stream) {
switch (rtype) { switch (rtype) {
case Scatter::None: case Scatter::None:
scatter<InT, IdxT>( scatter<InT, IdxT, None>(updates, out, indices, axes);
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = x; },
stream);
break; break;
case Scatter::Sum: case Scatter::Sum:
scatter<InT, IdxT>( scatter<InT, IdxT, Sum>(updates, out, indices, axes);
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) += x; },
stream);
break; break;
case Scatter::Prod: case Scatter::Prod:
scatter<InT, IdxT>( scatter<InT, IdxT, Prod>(updates, out, indices, axes);
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) *= x; },
stream);
break; break;
case Scatter::Max: case Scatter::Max:
scatter<InT, IdxT>( scatter<InT, IdxT, Max>(updates, out, indices, axes);
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = (*y > x) ? *y : x; },
stream);
break; break;
case Scatter::Min: case Scatter::Min:
scatter<InT, IdxT>( scatter<InT, IdxT, Min>(updates, out, indices, axes);
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = (*y < x) ? *y : x; },
stream);
break; break;
} }
} }
@@ -512,46 +475,36 @@ void dispatch_scatter(
const std::vector<array>& inds, const std::vector<array>& inds,
const array& updates, const array& updates,
const std::vector<int>& axes, const std::vector<int>& axes,
Scatter::ReduceType rtype, Scatter::ReduceType rtype) {
Stream stream) {
if (inds.empty()) { if (inds.empty()) {
dispatch_scatter_inds<InT, uint8_t>( dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
return; return;
} }
switch (inds[0].dtype()) { switch (inds[0].dtype()) {
case uint8: case uint8:
dispatch_scatter_inds<InT, uint8_t>( dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case uint16: case uint16:
dispatch_scatter_inds<InT, uint16_t>( dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case uint32: case uint32:
dispatch_scatter_inds<InT, uint32_t>( dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case uint64: case uint64:
dispatch_scatter_inds<InT, uint64_t>( dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case int8: case int8:
dispatch_scatter_inds<InT, int8_t>( dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case int16: case int16:
dispatch_scatter_inds<InT, int16_t>( dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case int32: case int32:
dispatch_scatter_inds<InT, int32_t>( dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
case int64: case int64:
dispatch_scatter_inds<InT, int64_t>( dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
out, inds, updates, axes, rtype, stream);
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
@@ -563,7 +516,6 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2); assert(inputs.size() >= 2);
auto& src = inputs[0]; auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back(); auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out) // Copy src into out (copy allocates memory for out)
@@ -571,73 +523,68 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream()); copy(src, out, ctype, stream());
switch (src.dtype()) { auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
encoder.set_input_array(*it);
inds.push_back(array::unsafe_weak_copy(*it));
}
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
reduce_type_ = reduce_type_,
updates = array::unsafe_weak_copy(updates),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_: case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_, stream()); dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break; break;
case uint8: case uint8:
dispatch_scatter<uint8_t>( dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case uint16: case uint16:
dispatch_scatter<uint16_t>( dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case uint32: case uint32:
dispatch_scatter<uint32_t>( dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case uint64: case uint64:
dispatch_scatter<uint64_t>( dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case int8: case int8:
dispatch_scatter<int8_t>( dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case int16: case int16:
dispatch_scatter<int16_t>( dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case int32: case int32:
dispatch_scatter<int32_t>( dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case int64: case int64:
dispatch_scatter<int64_t>( dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case float16: case float16:
dispatch_scatter<float16_t>( dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case float32: case float32:
dispatch_scatter<float>( dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case float64: case float64:
dispatch_scatter<double>( dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case bfloat16: case bfloat16:
dispatch_scatter<bfloat16_t>( dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
case complex64: case complex64:
dispatch_scatter<complex64_t>( dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
out, inds, updates, axes_, reduce_type_, stream());
break; break;
} }
});
} }
template <typename T, typename IdxT, typename OpT> template <typename T, typename IdxT, typename OpT>
void scatter_axis( void scatter_axis(array& out, const array idx, const array& upd, int axis) {
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op,
Stream stream) {
auto strides = idx.strides(); auto strides = idx.strides();
strides.erase(strides.begin() + axis); strides.erase(strides.begin() + axis);
auto shape = idx.shape(); auto shape = idx.shape();
@@ -657,11 +604,6 @@ void scatter_axis(
auto idx_ax_size = idx.shape(axis); auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis); auto dst_ax_size = out.shape(axis);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(idx);
encoder.set_input_array(upd);
encoder.set_output_array(out);
size_t size_pre = 1; size_t size_pre = 1;
size_t size_post = 1; size_t size_post = 1;
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
@@ -670,26 +612,14 @@ void scatter_axis(
for (int i = axis + 1; i < idx.ndim(); ++i) { for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i); size_post *= idx.shape(i);
} }
encoder.dispatch([idx_ptr,
upd_ptr,
dst_ptr,
size_pre,
size_post,
idx_ax_size,
dst_ax_size,
idx_ax_stride,
upd_ax_stride,
dst_ax_stride,
idx_it = std::move(idx_it),
upd_it = std::move(upd_it),
op = std::move(op)]() mutable {
size_t stride_pre = size_post * dst_ax_size; size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) { for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) { for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) { for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx( auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride], OpT{}(
upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride); dst_ptr + k + ind_val * dst_ax_stride);
} }
idx_it.step(); idx_it.step();
@@ -697,7 +627,6 @@ void scatter_axis(
} }
dst_ptr += stride_pre; dst_ptr += stride_pre;
} }
});
} }
template <typename InT, typename IdxT> template <typename InT, typename IdxT>
@@ -706,16 +635,13 @@ void dispatch_scatter_axis_op(
const array& idx, const array& idx,
const array& updates, const array& updates,
int axis, int axis,
ScatterAxis::ReduceType rtype, ScatterAxis::ReduceType rtype) {
Stream stream) {
switch (rtype) { switch (rtype) {
case ScatterAxis::None: case ScatterAxis::None:
scatter_axis<InT, IdxT>( scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream);
break; break;
case ScatterAxis::Sum: case ScatterAxis::Sum:
scatter_axis<InT, IdxT>( scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream);
break; break;
} }
} }
@@ -726,40 +652,31 @@ void dispatch_scatter_axis(
const array& idx, const array& idx,
const array& updates, const array& updates,
int axis, int axis,
ScatterAxis::ReduceType rtype, ScatterAxis::ReduceType rtype) {
Stream stream) {
switch (idx.dtype()) { switch (idx.dtype()) {
case uint8: case uint8:
dispatch_scatter_axis_op<InT, uint8_t>( dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case uint16: case uint16:
dispatch_scatter_axis_op<InT, uint16_t>( dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case uint32: case uint32:
dispatch_scatter_axis_op<InT, uint32_t>( dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case uint64: case uint64:
dispatch_scatter_axis_op<InT, uint64_t>( dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case int8: case int8:
dispatch_scatter_axis_op<InT, int8_t>( dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case int16: case int16:
dispatch_scatter_axis_op<InT, int16_t>( dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case int32: case int32:
dispatch_scatter_axis_op<InT, int32_t>( dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
case int64: case int64:
dispatch_scatter_axis_op<InT, int64_t>( dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
out, idx, updates, axis, rtype, stream);
break; break;
default: default:
throw std::runtime_error( throw std::runtime_error(
@@ -779,64 +696,63 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream()); copy(src, out, ctype, stream());
switch (src.dtype()) { auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
reduce_type_ = reduce_type_,
idx = array::unsafe_weak_copy(idx),
updates = array::unsafe_weak_copy(updates),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_: case bool_:
dispatch_scatter_axis<bool>( dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case uint8: case uint8:
dispatch_scatter_axis<uint8_t>( dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case uint16: case uint16:
dispatch_scatter_axis<uint16_t>( dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case uint32: case uint32:
dispatch_scatter_axis<uint32_t>( dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case uint64: case uint64:
dispatch_scatter_axis<uint64_t>( dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case int8: case int8:
dispatch_scatter_axis<int8_t>( dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case int16: case int16:
dispatch_scatter_axis<int16_t>( dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case int32: case int32:
dispatch_scatter_axis<int32_t>( dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case int64: case int64:
dispatch_scatter_axis<int64_t>( dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case float16: case float16:
dispatch_scatter_axis<float16_t>( dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_, stream()); out, idx, updates, axis_, reduce_type_);
break; break;
case float32: case float32:
dispatch_scatter_axis<float>( dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case float64: case float64:
dispatch_scatter_axis<double>( dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
out, idx, updates, axis_, reduce_type_, stream());
break; break;
case bfloat16: case bfloat16:
dispatch_scatter_axis<bfloat16_t>( dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_, stream()); out, idx, updates, axis_, reduce_type_);
break; break;
case complex64: case complex64:
dispatch_scatter_axis<complex64_t>( dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_, stream()); out, idx, updates, axis_, reduce_type_);
break; break;
} }
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -326,8 +326,7 @@ void _qmm_dispatch_typed(
const array& biases, const array& biases,
int bits, int bits,
int group_size, int group_size,
bool transposed_w, bool transposed_w) {
Stream stream) {
int K = x.shape(-1); int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1; int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1); int N = out.shape(-1);
@@ -335,48 +334,18 @@ void _qmm_dispatch_typed(
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M); int batch_size = x.size() / (K * M);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
auto out_ptr = out.data<T>(); auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>(); auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>(); auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>(); auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>(); auto biases_ptr = biases.data<T>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
w_els,
g_els,
batch_size,
M,
N,
K,
bits,
group_size,
transposed_w] {
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>( _qmm_dispatch_typed<T>(
out_ptr + i * M * N, out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides), x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides), w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides), scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides), biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
M, M,
N, N,
K, K,
@@ -384,7 +353,6 @@ void _qmm_dispatch_typed(
group_size, group_size,
transposed_w); transposed_w);
} }
});
} }
void _qmm_dispatch( void _qmm_dispatch(
@@ -395,20 +363,19 @@ void _qmm_dispatch(
const array& biases, const array& biases,
int bits, int bits,
int group_size, int group_size,
bool transposed_w, bool transposed_w) {
Stream stream) {
switch (x.dtype()) { switch (x.dtype()) {
case float32: case float32:
_qmm_dispatch_typed<float>( _qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream); out, x, w, scales, biases, bits, group_size, transposed_w);
break; break;
case float16: case float16:
_qmm_dispatch_typed<float16_t>( _qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream); out, x, w, scales, biases, bits, group_size, transposed_w);
break; break;
case bfloat16: case bfloat16:
_qmm_dispatch_typed<bfloat16_t>( _qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream); out, x, w, scales, biases, bits, group_size, transposed_w);
break; break;
default: default:
throw std::invalid_argument( throw std::invalid_argument(
@@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed(
const array& rhs_indices, const array& rhs_indices,
int bits, int bits,
int group_size, int group_size,
bool transposed_w, bool transposed_w) {
Stream stream) {
int K = x.shape(-1); int K = x.shape(-1);
int M = x.shape(-2); int M = x.shape(-2);
int N = out.shape(-1); int N = out.shape(-1);
@@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed(
int w_els = w.shape(-1) * w.shape(-2); int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2); int g_els = scales.shape(-1) * scales.shape(-2);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto out_ptr = out.data<T>(); auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>(); auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>(); auto w_ptr = w.data<uint32_t>();
@@ -453,45 +410,19 @@ void _bs_qmm_dispatch_typed(
auto lhs_indices_ptr = lhs_indices.data<uint32_t>(); auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>(); auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.dispatch([out_ptr, for (int i = 0; i < lhs_indices.size(); i++) {
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
lhs_indices_ptr,
rhs_indices_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
w_els,
g_els,
indices_size = lhs_indices.size(),
M,
N,
K,
bits,
group_size,
transposed_w]() {
for (int i = 0; i < indices_size; i++) {
int x_idx = lhs_indices_ptr[elem_to_loc( int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)]; i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc( int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)]; i, rhs_indices.shape(), rhs_indices.strides())];
_qmm_dispatch_typed<T>( _qmm_dispatch_typed<T>(
out_ptr + i * M * N, out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides), x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides), w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides), scales_ptr +
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides), elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
biases_ptr +
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
M, M,
N, N,
K, K,
@@ -499,7 +430,6 @@ void _bs_qmm_dispatch_typed(
group_size, group_size,
transposed_w); transposed_w);
} }
});
} }
void _bs_qmm_dispatch( void _bs_qmm_dispatch(
@@ -512,8 +442,7 @@ void _bs_qmm_dispatch(
const array& rhs_indices, const array& rhs_indices,
int bits, int bits,
int group_size, int group_size,
bool transposed_w, bool transposed_w) {
Stream stream) {
switch (x.dtype()) { switch (x.dtype()) {
case float32: case float32:
_bs_qmm_dispatch_typed<float>( _bs_qmm_dispatch_typed<float>(
@@ -526,8 +455,7 @@ void _bs_qmm_dispatch(
rhs_indices, rhs_indices,
bits, bits,
group_size, group_size,
transposed_w, transposed_w);
stream);
break; break;
case float16: case float16:
_bs_qmm_dispatch_typed<float16_t>( _bs_qmm_dispatch_typed<float16_t>(
@@ -540,8 +468,7 @@ void _bs_qmm_dispatch(
rhs_indices, rhs_indices,
bits, bits,
group_size, group_size,
transposed_w, transposed_w);
stream);
break; break;
case bfloat16: case bfloat16:
_bs_qmm_dispatch_typed<bfloat16_t>( _bs_qmm_dispatch_typed<bfloat16_t>(
@@ -554,8 +481,7 @@ void _bs_qmm_dispatch(
rhs_indices, rhs_indices,
bits, bits,
group_size, group_size,
transposed_w, transposed_w);
stream);
break; break;
default: default:
throw std::invalid_argument( throw std::invalid_argument(
@@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous(biases_pre); auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(
out, x, w, scales, biases, group_size_, bits_, transpose_, stream()); auto& encoder = cpu::get_command_encoder(stream());
auto& enc = cpu::get_command_encoder(stream()); encoder.add_temporaries(std::move(temps));
enc.add_temporaries(std::move(temps)); encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} }
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -626,6 +566,26 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous_last_dims(biases_pre); auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch( _bs_qmm_dispatch(
out, out,
x, x,
@@ -636,10 +596,8 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
rhs_indices, rhs_indices,
group_size_, group_size_,
bits_, bits_,
transpose_, transpose_);
stream()); });
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
} }
template <typename T, typename U> template <typename T, typename U>
@@ -709,27 +667,13 @@ void dispatch_quantize(
array& scales, array& scales,
array& biases, array& biases,
int bits, int bits,
int group_size, int group_size) {
Stream stream) {
auto w_ptr = w.data<T>(); auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>(); auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>(); auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>(); auto biases_ptr = biases.data<T>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w_ptr,
out_ptr,
scales_ptr,
biases_ptr,
bits,
group_size,
w_size = w.size()]() {
quantize<T, U>( quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size); w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
});
} }
void fast::AffineQuantize::eval_cpu( void fast::AffineQuantize::eval_cpu(
@@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu(
auto& biases = outputs[2]; auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes())); scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes())); biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
encoder.add_temporary(w);
}
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w = array::unsafe_weak_copy(w),
out = array::unsafe_weak_copy(out),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_]() mutable {
if (w.dtype() == float16) { if (w.dtype() == float16) {
if (is_power_of_2(bits_)) { if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>( dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream()); w, out, scales, biases, bits_, group_size_);
} else { } else {
dispatch_quantize<float16_t, uint8_t>( dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream()); w, out, scales, biases, bits_, group_size_);
} }
} else if (w.dtype() == bfloat16) { } else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) { if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>( dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream()); w, out, scales, biases, bits_, group_size_);
} else { } else {
dispatch_quantize<bfloat16_t, uint8_t>( dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream()); w, out, scales, biases, bits_, group_size_);
} }
} else if (w.dtype() == float32) { } else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) { if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>( dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream()); w, out, scales, biases, bits_, group_size_);
} else { } else {
dispatch_quantize<float, uint8_t>( dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream()); w, out, scales, biases, bits_, group_size_);
} }
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
} }
if (copied) { });
cpu::get_command_encoder(stream()).add_temporary(w);
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -140,34 +140,23 @@ void reduction_op(
const array& x, const array& x,
array& out, array& out,
const std::vector<int>& axes, const std::vector<int>& axes,
U init, U init) {
Stream stream) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes); ReductionPlan plan = get_reduction_plan(x, axes);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_output_array(out);
auto in_ptr = x.data<T>(); auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>(); auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) { if (plan.type == ContiguousAllReduce) {
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
*out_ptr = init; *out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init); contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
});
return; return;
} }
if (plan.type == ContiguousReduce && plan.shape.size() == 1) { if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0]; int reduction_size = plan.shape[0];
encoder.dispatch( for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init; *out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init); contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
} }
});
return; return;
} }
@@ -178,24 +167,14 @@ void reduction_op(
// Unrolling the following loop (and implementing it in order for // Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost. // ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) { if (plan.shape.size() == 0) {
for (int i = 0; i < size; i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
*out_ptr = init; *out_ptr = init;
contiguous_reduce( contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
} }
} else { } else {
for (int i = 0; i < size; i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
*out_ptr = init; *out_ptr = init;
nd_loop( nd_loop(
@@ -211,7 +190,6 @@ void reduction_op(
plan.strides); plan.strides);
} }
} }
});
return; return;
} }
@@ -220,20 +198,12 @@ void reduction_op(
size_t reduction_stride = plan.strides.back(); size_t reduction_stride = plan.strides.back();
plan.shape.pop_back(); plan.shape.pop_back();
plan.strides.pop_back(); plan.strides.pop_back();
for (int i = 0; i < out.size(); i += reduction_stride) {
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size()]() mutable {
for (int i = 0; i < size; i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init); std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{}); strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size; in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride; out_ptr += reduction_stride;
} }
});
return; return;
} }
@@ -245,17 +215,8 @@ void reduction_op(
plan.strides.pop_back(); plan.strides.pop_back();
auto [shape, strides] = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) { if (plan.shape.size() == 0) {
for (int i = 0; i < size; i += reduction_stride) { for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init); std::fill_n(out_ptr, reduction_stride, init);
strided_reduce( strided_reduce(
@@ -263,7 +224,7 @@ void reduction_op(
out_ptr += reduction_stride; out_ptr += reduction_stride;
} }
} else { } else {
for (int i = 0; i < size; i += reduction_stride) { for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init); std::fill_n(out_ptr, reduction_stride, init);
nd_loop( nd_loop(
@@ -280,21 +241,13 @@ void reduction_op(
out_ptr += reduction_stride; out_ptr += reduction_stride;
} }
} }
});
return; return;
} }
if (plan.type == GeneralReduce) { if (plan.type == GeneralReduce) {
auto [shape, strides] = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr, for (int i = 0; i < out.size(); i++, out_ptr++) {
out_ptr,
init,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
U val = init; U val = init;
nd_loop( nd_loop(
@@ -305,7 +258,6 @@ void reduction_op(
plan.strides); plan.strides);
*out_ptr = val; *out_ptr = val;
} }
});
} }
} }
@@ -434,12 +386,11 @@ void reduce_dispatch_and_or(
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType rtype, Reduce::ReduceType rtype,
const std::vector<int>& axes, const std::vector<int>& axes) {
Stream stream) {
if (rtype == Reduce::And) { if (rtype == Reduce::And) {
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream); reduction_op<InT, bool, AndReduce>(in, out, axes, true);
} else { } else {
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream); reduction_op<InT, bool, OrReduce>(in, out, axes, false);
} }
} }
@@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod(
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType rtype, Reduce::ReduceType rtype,
const std::vector<int>& axes, const std::vector<int>& axes) {
Stream stream) {
if (rtype == Reduce::Sum) { if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) { if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream); reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
} else { } else {
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream); reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
} }
} else { } else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) { if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream); reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
} else { } else {
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream); reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
} }
} }
} }
@@ -470,20 +420,27 @@ void reduce_dispatch_min_max(
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType rtype, Reduce::ReduceType rtype,
const std::vector<int>& axes, const std::vector<int>& axes) {
Stream stream) {
if (rtype == Reduce::Max) { if (rtype == Reduce::Max) {
auto init = Limits<InT>::min; auto init = Limits<InT>::min;
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream); reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
} else { } else {
auto init = Limits<InT>::max; auto init = Limits<InT>::max;
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream); reduction_op<InT, InT, MinReduce>(in, out, axes, init);
} }
} }
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axes_ = axes_]() mutable {
switch (reduce_type_) { switch (reduce_type_) {
case Reduce::And: case Reduce::And:
case Reduce::Or: { case Reduce::Or: {
@@ -491,28 +448,24 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bool_: case bool_:
case uint8: case uint8:
case int8: case int8:
reduce_dispatch_and_or<int8_t>( reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int16: case int16:
case uint16: case uint16:
case float16: case float16:
case bfloat16: case bfloat16:
reduce_dispatch_and_or<int16_t>( reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case uint32: case uint32:
case int32: case int32:
case float32: case float32:
reduce_dispatch_and_or<int32_t>( reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case uint64: case uint64:
case int64: case int64:
case float64: case float64:
case complex64: case complex64:
reduce_dispatch_and_or<int64_t>( reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
} }
break; break;
@@ -523,43 +476,34 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bool_: case bool_:
case uint8: case uint8:
case int8: case int8:
reduce_dispatch_sum_prod<int8_t>( reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int16: case int16:
case uint16: case uint16:
reduce_dispatch_sum_prod<int16_t>( reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int32: case int32:
case uint32: case uint32:
reduce_dispatch_sum_prod<int32_t>( reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int64: case int64:
case uint64: case uint64:
reduce_dispatch_sum_prod<int64_t>( reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case float16: case float16:
reduce_dispatch_sum_prod<float16_t>( reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case bfloat16: case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>( reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case float32: case float32:
reduce_dispatch_sum_prod<float>( reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case float64: case float64:
reduce_dispatch_sum_prod<double>( reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case complex64: case complex64:
reduce_dispatch_sum_prod<complex64_t>( reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
} }
break; break;
@@ -568,64 +512,52 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case Reduce::Min: { case Reduce::Min: {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream()); reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break; break;
case uint8: case uint8:
reduce_dispatch_min_max<uint8_t>( reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case uint16: case uint16:
reduce_dispatch_min_max<uint16_t>( reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case uint32: case uint32:
reduce_dispatch_min_max<uint32_t>( reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case uint64: case uint64:
reduce_dispatch_min_max<uint64_t>( reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int8: case int8:
reduce_dispatch_min_max<uint8_t>( reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int16: case int16:
reduce_dispatch_min_max<uint16_t>( reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int32: case int32:
reduce_dispatch_min_max<int32_t>( reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case int64: case int64:
reduce_dispatch_min_max<int64_t>( reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case float16: case float16:
reduce_dispatch_min_max<float16_t>( reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case float32: case float32:
reduce_dispatch_min_max<float>( reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case float64: case float64:
reduce_dispatch_min_max<double>( reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case bfloat16: case bfloat16:
reduce_dispatch_min_max<bfloat16_t>( reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
case complex64: case complex64:
reduce_dispatch_min_max<complex64_t>( reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
in, out, reduce_type_, axes_, stream());
break; break;
} }
break; break;
} }
} }
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -160,38 +160,29 @@ void scan_op(
bool reverse, bool reverse,
bool inclusive, bool inclusive,
const Op& op, const Op& op,
U init, U init) {
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (in.flags().row_contiguous) { if (in.flags().row_contiguous) {
if (in.strides()[axis] == 1) { if (in.strides()[axis] == 1) {
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis),
stride = in.shape(axis),
reverse,
inclusive,
op = std::move(op),
init]() {
contiguous_scan( contiguous_scan(
in_ptr, out_ptr, count, stride, reverse, inclusive, op, init); in.data<T>(),
}); out.data<U>(),
} else { in.size() / in.shape(axis),
encoder.dispatch([in_ptr = in.data<T>(), in.shape(axis),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis) / in.strides()[axis],
size = in.shape(axis),
stride = in.strides()[axis],
reverse, reverse,
inclusive, inclusive,
op = std::move(op), op,
init]() { init);
} else {
strided_scan( strided_scan(
in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init); in.data<T>(),
}); out.data<U>(),
in.size() / in.shape(axis) / in.strides()[axis],
in.shape(axis),
in.strides()[axis],
reverse,
inclusive,
op,
init);
} }
} else { } else {
throw std::runtime_error("Scan op supports only contiguous inputs"); throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -205,19 +196,18 @@ void scan_dispatch(
array& out, array& out,
int axis, int axis,
bool reverse, bool reverse,
bool inclusive, bool inclusive) {
Stream stream) {
switch (rtype) { switch (rtype) {
case Scan::Sum: { case Scan::Sum: {
auto op = [](U y, T x) { return y + x; }; auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0); auto init = static_cast<U>(0);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::Prod: { case Scan::Prod: {
auto op = [](U y, T x) { return y * x; }; auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1); auto init = static_cast<U>(1);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::Min: { case Scan::Min: {
@@ -225,7 +215,7 @@ void scan_dispatch(
auto init = (issubdtype(in.dtype(), floating)) auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity()) ? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max(); : std::numeric_limits<U>::max();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
case Scan::Max: { case Scan::Max: {
@@ -233,7 +223,7 @@ void scan_dispatch(
auto init = (issubdtype(in.dtype(), floating)) auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity()) ? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min(); : std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream); scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break; break;
} }
} }
@@ -244,17 +234,26 @@ void scan_dispatch(
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) { void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& encoder = cpu::get_command_encoder(stream());
// Ensure contiguity // Ensure contiguity
auto in = inputs[0]; auto in = inputs[0];
bool copied = false;
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, stream()); copy(in, arr_copy, CopyType::General, stream());
in = arr_copy; in = arr_copy;
copied = true; encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
reduce_type_ = reduce_type_,
reverse_ = reverse_,
inclusive_ = inclusive_]() mutable {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: { case bool_: {
// We could do a full dtype x dtype switch but this is the only case // We could do a full dtype x dtype switch but this is the only case
@@ -264,68 +263,66 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// floats perhaps we should add the full all-to-all dispatch. // floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) { if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>( scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
} else { } else {
scan_dispatch<bool, bool>( scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
} }
break; break;
} }
case uint8: case uint8:
scan_dispatch<uint8_t, uint8_t>( scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case uint16: case uint16:
scan_dispatch<uint16_t, uint16_t>( scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case uint32: case uint32:
scan_dispatch<uint32_t, uint32_t>( scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case uint64: case uint64:
scan_dispatch<uint64_t, uint64_t>( scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case int8: case int8:
scan_dispatch<int8_t, int8_t>( scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case int16: case int16:
scan_dispatch<int16_t, int16_t>( scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case int32: case int32:
scan_dispatch<int32_t, int32_t>( scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case int64: case int64:
scan_dispatch<int64_t, int64_t>( scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case float16: case float16:
scan_dispatch<float16_t, float16_t>( scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case float32: case float32:
scan_dispatch<float, float>( scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case float64: case float64:
scan_dispatch<double, double>( scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case bfloat16: case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>( scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); reduce_type_, in, out, axis_, reverse_, inclusive_);
break; break;
case complex64: case complex64:
throw std::runtime_error("Scan ops do not support complex types yet"); throw std::runtime_error("Scan ops do not support complex types yet");
break; break;
} }
if (copied) { });
cpu::get_command_encoder(stream()).add_temporary(std::move(in));
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -16,51 +16,70 @@ void select_op(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
Op op) { Op op,
Stream stream) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
c = array::unsafe_weak_copy(c),
out = array::unsafe_weak_copy(out),
op,
topt]() mutable {
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op); ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
break; break;
case uint8: case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op); ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
break; break;
case uint16: case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op); ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
break; break;
case uint32: case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op); ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
break; break;
case uint64: case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op); ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
break; break;
case int8: case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op); ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
break; break;
case int16: case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op); ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
break; break;
case int32: case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op); ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
break; break;
case int64: case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op); ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
break; break;
case float16: case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op); ternary_op<bool, float16_t, float16_t, float16_t>(
a, b, c, out, op, topt);
break; break;
case float32: case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op); ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
break; break;
case float64: case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op); ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
break; break;
case bfloat16: case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op); ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
a, b, c, out, op, topt);
break; break;
case complex64: case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op); ternary_op<bool, complex64_t, complex64_t, complex64_t>(
a, b, c, out, op, topt);
break; break;
} }
});
} }
} // namespace } // namespace
@@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& condition = inputs[0]; const auto& condition = inputs[0];
const auto& a = inputs[1]; const auto& a = inputs[1];
const auto& b = inputs[2]; const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select()); select_op(condition, a, b, out, detail::Select(), stream());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -105,15 +105,11 @@ struct StridedIterator {
}; };
template <typename T> template <typename T>
void sort(const array& in, array& out, int axis, Stream stream) { void sort(array& out, int axis) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream);
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t in_size = out.size();
size_t n_rows = in_size / in.shape(axis); size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape(); auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
@@ -127,13 +123,7 @@ void sort(const array& in, array& out, int axis, Stream stream) {
// Perform sorting in place // Perform sorting in place
ContiguousIterator src_it( ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size()); remaining_shape, remaining_strides, remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream); auto out_ptr = out.data<T>();
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride]() mutable {
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc; T* data_ptr = out_ptr + src_it.loc;
@@ -143,14 +133,10 @@ void sort(const array& in, array& out, int axis, Stream stream) {
std::stable_sort(st, ed); std::stable_sort(st, ed);
src_it.step(); src_it.step();
} }
});
} }
template <typename T, typename IdxT = uint32_t> template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis, Stream stream) { void argsort(const array& in, array& out, int axis) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
@@ -176,17 +162,8 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it( ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream); auto in_ptr = in.data<T>();
encoder.set_input_array(in); auto out_ptr = out.data<IdxT>();
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride]() mutable {
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc; const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc;
@@ -210,42 +187,30 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
});
} }
template <typename T> template <typename T>
void partition(const array& in, array& out, int axis, int kth, Stream stream) { void partition(array& out, int axis, int kth) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream);
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t in_size = out.size();
size_t n_rows = in_size / in.shape(axis); size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = in.shape(); auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides(); auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = in.strides()[axis]; auto axis_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = out.shape(axis);
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place // Perform partition in place
ContiguousIterator src_it( ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size()); remaining_shape, remaining_strides, remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream); auto out_ptr = out.data<T>();
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc; T* data_ptr = out_ptr + src_it.loc;
src_it.step(); src_it.step();
@@ -256,19 +221,10 @@ void partition(const array& in, array& out, int axis, int kth, Stream stream) {
std::nth_element(st, md, ed); std::nth_element(st, md, ed);
} }
});
} }
template <typename T, typename IdxT = uint32_t> template <typename T, typename IdxT = uint32_t>
void argpartition( void argpartition(const array& in, array& out, int axis, int kth) {
const array& in,
array& out,
int axis,
int kth,
Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
@@ -297,18 +253,9 @@ void argpartition(
ContiguousIterator out_it( ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream); auto in_ptr = in.data<T>();
encoder.set_input_array(in); auto out_ptr = out.data<IdxT>();
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc; const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc; IdxT* idx_ptr = out_ptr + out_it.loc;
@@ -332,7 +279,6 @@ void argpartition(
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
});
} }
} // namespace } // namespace
@@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_]() mutable {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
return argsort<bool>(in, out, axis_, stream()); return argsort<bool>(in, out, axis_);
case uint8: case uint8:
return argsort<uint8_t>(in, out, axis_, stream()); return argsort<uint8_t>(in, out, axis_);
case uint16: case uint16:
return argsort<uint16_t>(in, out, axis_, stream()); return argsort<uint16_t>(in, out, axis_);
case uint32: case uint32:
return argsort<uint32_t>(in, out, axis_, stream()); return argsort<uint32_t>(in, out, axis_);
case uint64: case uint64:
return argsort<uint64_t>(in, out, axis_, stream()); return argsort<uint64_t>(in, out, axis_);
case int8: case int8:
return argsort<int8_t>(in, out, axis_, stream()); return argsort<int8_t>(in, out, axis_);
case int16: case int16:
return argsort<int16_t>(in, out, axis_, stream()); return argsort<int16_t>(in, out, axis_);
case int32: case int32:
return argsort<int32_t>(in, out, axis_, stream()); return argsort<int32_t>(in, out, axis_);
case int64: case int64:
return argsort<int64_t>(in, out, axis_, stream()); return argsort<int64_t>(in, out, axis_);
case float32: case float32:
return argsort<float>(in, out, axis_, stream()); return argsort<float>(in, out, axis_);
case float64: case float64:
return argsort<double>(in, out, axis_, stream()); return argsort<double>(in, out, axis_);
case float16: case float16:
return argsort<float16_t>(in, out, axis_, stream()); return argsort<float16_t>(in, out, axis_);
case bfloat16: case bfloat16:
return argsort<bfloat16_t>(in, out, axis_, stream()); return argsort<bfloat16_t>(in, out, axis_);
case complex64: case complex64:
return argsort<complex64_t>(in, out, axis_, stream()); return argsort<complex64_t>(in, out, axis_);
} }
});
} }
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) { void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_: case bool_:
return sort<bool>(in, out, axis_, stream()); return sort<bool>(out, axis_);
case uint8: case uint8:
return sort<uint8_t>(in, out, axis_, stream()); return sort<uint8_t>(out, axis_);
case uint16: case uint16:
return sort<uint16_t>(in, out, axis_, stream()); return sort<uint16_t>(out, axis_);
case uint32: case uint32:
return sort<uint32_t>(in, out, axis_, stream()); return sort<uint32_t>(out, axis_);
case uint64: case uint64:
return sort<uint64_t>(in, out, axis_, stream()); return sort<uint64_t>(out, axis_);
case int8: case int8:
return sort<int8_t>(in, out, axis_, stream()); return sort<int8_t>(out, axis_);
case int16: case int16:
return sort<int16_t>(in, out, axis_, stream()); return sort<int16_t>(out, axis_);
case int32: case int32:
return sort<int32_t>(in, out, axis_, stream()); return sort<int32_t>(out, axis_);
case int64: case int64:
return sort<int64_t>(in, out, axis_, stream()); return sort<int64_t>(out, axis_);
case float32: case float32:
return sort<float>(in, out, axis_, stream()); return sort<float>(out, axis_);
case float64: case float64:
return sort<double>(in, out, axis_, stream()); return sort<double>(out, axis_);
case float16: case float16:
return sort<float16_t>(in, out, axis_, stream()); return sort<float16_t>(out, axis_);
case bfloat16: case bfloat16:
return sort<bfloat16_t>(in, out, axis_, stream()); return sort<bfloat16_t>(out, axis_);
case complex64: case complex64:
return sort<complex64_t>(in, out, axis_, stream()); return sort<complex64_t>(out, axis_);
} }
});
} }
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) { void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (in.dtype()) { switch (in.dtype()) {
case bool_: case bool_:
return argpartition<bool>(in, out, axis_, kth_, stream()); return argpartition<bool>(in, out, axis_, kth_);
case uint8: case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_, stream()); return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16: case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_, stream()); return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32: case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_, stream()); return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64: case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_, stream()); return argpartition<uint64_t>(in, out, axis_, kth_);
case int8: case int8:
return argpartition<int8_t>(in, out, axis_, kth_, stream()); return argpartition<int8_t>(in, out, axis_, kth_);
case int16: case int16:
return argpartition<int16_t>(in, out, axis_, kth_, stream()); return argpartition<int16_t>(in, out, axis_, kth_);
case int32: case int32:
return argpartition<int32_t>(in, out, axis_, kth_, stream()); return argpartition<int32_t>(in, out, axis_, kth_);
case int64: case int64:
return argpartition<int64_t>(in, out, axis_, kth_, stream()); return argpartition<int64_t>(in, out, axis_, kth_);
case float32: case float32:
return argpartition<float>(in, out, axis_, kth_, stream()); return argpartition<float>(in, out, axis_, kth_);
case float64: case float64:
return argpartition<double>(in, out, axis_, kth_, stream()); return argpartition<double>(in, out, axis_, kth_);
case float16: case float16:
return argpartition<float16_t>(in, out, axis_, kth_, stream()); return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16: case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream()); return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64: case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_, stream()); return argpartition<complex64_t>(in, out, axis_, kth_);
} }
});
} }
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) { void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (out.dtype()) {
case bool_: case bool_:
return partition<bool>(in, out, axis_, kth_, stream()); return partition<bool>(out, axis_, kth_);
case uint8: case uint8:
return partition<uint8_t>(in, out, axis_, kth_, stream()); return partition<uint8_t>(out, axis_, kth_);
case uint16: case uint16:
return partition<uint16_t>(in, out, axis_, kth_, stream()); return partition<uint16_t>(out, axis_, kth_);
case uint32: case uint32:
return partition<uint32_t>(in, out, axis_, kth_, stream()); return partition<uint32_t>(out, axis_, kth_);
case uint64: case uint64:
return partition<uint64_t>(in, out, axis_, kth_, stream()); return partition<uint64_t>(out, axis_, kth_);
case int8: case int8:
return partition<int8_t>(in, out, axis_, kth_, stream()); return partition<int8_t>(out, axis_, kth_);
case int16: case int16:
return partition<int16_t>(in, out, axis_, kth_, stream()); return partition<int16_t>(out, axis_, kth_);
case int32: case int32:
return partition<int32_t>(in, out, axis_, kth_, stream()); return partition<int32_t>(out, axis_, kth_);
case int64: case int64:
return partition<int64_t>(in, out, axis_, kth_, stream()); return partition<int64_t>(out, axis_, kth_);
case float32: case float32:
return partition<float>(in, out, axis_, kth_, stream()); return partition<float>(out, axis_, kth_);
case float64: case float64:
return partition<double>(in, out, axis_, kth_, stream()); return partition<double>(out, axis_, kth_);
case float16: case float16:
return partition<float16_t>(in, out, axis_, kth_, stream()); return partition<float16_t>(out, axis_, kth_);
case bfloat16: case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_, stream()); return partition<bfloat16_t>(out, axis_, kth_);
case complex64: case complex64:
return partition<complex64_t>(in, out, axis_, kth_, stream()); return partition<complex64_t>(out, axis_, kth_);
} }
});
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,12 +1,10 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/ternary.h" #include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@@ -128,57 +126,28 @@ void ternary_op(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
Op op) { Op op,
TernaryOpType topt = get_ternary_op_type(a, b, c); TernaryOpType topt) {
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
const T1* a_ptr = a.data<T1>(); const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>(); const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>(); const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>(); U* out_ptr = out.data<U>();
if (topt == TernaryOpType::ScalarScalarScalar) { if (topt == TernaryOpType::ScalarScalarScalar) {
encoder.dispatch(
[a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr); *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
});
} else if (topt == TernaryOpType::VectorVectorVector) { } else if (topt == TernaryOpType::VectorVectorVector) {
encoder.dispatch([a_ptr, for (size_t i = 0; i < out.size(); ++i) {
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size()]() mutable {
for (size_t i = 0; i < size; ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr); *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++; a_ptr++;
b_ptr++; b_ptr++;
c_ptr++; c_ptr++;
out_ptr++; out_ptr++;
} }
});
} else { } else {
auto [shape, strides] = collapse_contiguous_dims( auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
encoder.dispatch(
[a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size(),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
ternary_op_dispatch_dims<T1, T2, T3, U>( ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides); a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
});
} }
} }

View File

@@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
// No-op for unsigned types // No-op for unsigned types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
auto op = detail::Abs{}; unary_signed(in, out, detail::Abs(), stream());
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
} }
} }
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos()); unary_fp(in, out, detail::ArcCos(), stream());
} }
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh()); unary_fp(in, out, detail::ArcCosh(), stream());
} }
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin()); unary_fp(in, out, detail::ArcSin(), stream());
} }
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh()); unary_fp(in, out, detail::ArcSinh(), stream());
} }
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan()); unary_fp(in, out, detail::ArcTan(), stream());
} }
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) { void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh()); unary_fp(in, out, detail::ArcTanh(), stream());
} }
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) { void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert()); unary_int(in, out, detail::BitwiseInvert(), stream());
} }
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) { void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil()); unary_fp(in, out, detail::Ceil(), stream());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) { void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate()); unary_complex(inputs[0], out, detail::Conjugate(), stream());
} }
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) { void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Cos()); unary_fp(in, out, detail::Cos(), stream());
} }
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) { void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh()); unary_fp(in, out, detail::Cosh(), stream());
} }
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) { void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { unary_real_fp(in, out, detail::Erf(), stream());
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
} }
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) { void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { unary_real_fp(in, out, detail::ErfInv(), stream());
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
} }
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) { void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Exp()); unary_fp(in, out, detail::Exp(), stream());
} }
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) { void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1()); unary_fp(in, out, detail::Expm1(), stream());
} }
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) { void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor()); unary_fp(in, out, detail::Floor(), stream());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) { void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag()); unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
} }
void Log::eval_cpu(const std::vector<array>& inputs, array& out) { void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_fp(in, out, detail::Log()); unary_fp(in, out, detail::Log(), stream());
break; break;
case Base::two: case Base::two:
unary_fp(in, out, detail::Log2()); unary_fp(in, out, detail::Log2(), stream());
break; break;
case Base::ten: case Base::ten:
unary_fp(in, out, detail::Log10()); unary_fp(in, out, detail::Log10(), stream());
break; break;
} }
} }
@@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) { void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p()); unary_fp(in, out, detail::Log1p(), stream());
} }
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
unary(in, out, detail::LogicalNot()); unary(in, out, detail::LogicalNot(), stream());
} }
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) { void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
unary(in, out, detail::Negative()); unary(in, out, detail::Negative(), stream());
} }
void Real::eval_cpu(const std::vector<array>& inputs, array& out) { void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real()); unary_complex_to_float(inputs[0], out, detail::Real(), stream());
} }
void Round::eval_cpu(const std::vector<array>& inputs, array& out) { void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round()); unary_fp(in, out, detail::Round(), stream());
} else { } else {
// No-op integer types // No-op integer types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) { void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid()); unary_fp(in, out, detail::Sigmoid(), stream());
} }
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) { void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
unary(in, out, detail::Sign()); unary(in, out, detail::Sign(), stream());
} }
} }
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) { void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Sin()); unary_fp(in, out, detail::Sin(), stream());
} }
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) { void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh()); unary_fp(in, out, detail::Sinh(), stream());
} }
void Square::eval_cpu(const std::vector<array>& inputs, array& out) { void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
unary(in, out, detail::Square()); unary(in, out, detail::Square(), stream());
} }
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) { void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (recip_) { if (recip_) {
unary_fp(in, out, detail::Rsqrt()); unary_fp(in, out, detail::Rsqrt(), stream());
} else { } else {
unary_fp(in, out, detail::Sqrt()); unary_fp(in, out, detail::Sqrt(), stream());
} }
} }
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) { void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Tan()); unary_fp(in, out, detail::Tan(), stream());
} }
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) { void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh()); unary_fp(in, out, detail::Tanh(), stream());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -7,7 +7,6 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -39,53 +38,48 @@ void unary_op(const T* a, U* out, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op> template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op) { void unary_op(const array& a, array& out, Op) {
set_unary_output_data(a, out);
const T* src = a.data<T>(); const T* src = a.data<T>();
U* dst = out.data<U>(); U* dst = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream()); auto ndim = a.ndim();
encoder.set_input_array(a); if (a.flags().contiguous) {
encoder.set_output_array(out); auto size = a.data_size();
encoder.dispatch([src,
dst,
contig = a.flags().contiguous,
data_size = a.data_size(),
size = a.size(),
shapes = a.shape(),
strides = a.strides()]() mutable {
auto ndim = shapes.size();
if (contig) {
constexpr int N = simd::max_size<T>; constexpr int N = simd::max_size<T>;
while (data_size >= N) { while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src))); simd::store(dst, Op{}(simd::load<T, N>(src)));
data_size -= N; size -= N;
src += N; src += N;
dst += N; dst += N;
} }
while (data_size > 0) { while (size > 0) {
*dst = Op{}(*src); *dst = Op{}(*src);
data_size--; size--;
dst++; dst++;
src++; src++;
} }
} else { } else {
size_t shape = ndim > 0 ? shapes.back() : 1; size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? strides.back() : 1; size_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) { if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride); unary_op<T, U, Op>(src, dst, shape, stride);
return; return;
} }
auto it = ContiguousIterator(shapes, strides, ndim - 1); auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (size_t elem = 0; elem < size; elem += shape) { for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride); unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step(); it.step();
} }
} }
});
} }
// NOTE(review): the lines below are side-by-side diff residue — each line
// carries the pre- and post-commit text concatenated — and the interior of
// the dtype switch is elided at the "@@ -130,10 +124,47 @@" hunk marker, so
// the complete post-commit body of `unary` cannot be reconstructed from this
// view alone; the text is left untouched.
template <typename Op> template <typename Op>
void unary(const array& a, array& out, Op op) { void unary(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
unary_op<bool>(a, out, op); unary_op<bool>(a, out, op);
@@ -130,10 +124,47 @@ void unary(const array& a, array& out, Op op) {
unary_op<complex64_t>(a, out, op); unary_op<complex64_t>(a, out, op);
break; break;
} }
});
} }
template <typename Op> template <typename Op>
void unary_fp(const array& a, array& out, Op op) { void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_real] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
// NOTE(review): side-by-side diff residue below — the interior of the dtype
// switch is elided at the "@@ -155,10 +186,84 @@" hunk marker, so the
// complete post-commit body of `unary_fp` is not reconstructible from this
// view; the text is left untouched.
template <typename Op>
void unary_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) { switch (out.dtype()) {
case bfloat16: case bfloat16:
unary_op<bfloat16_t>(a, out, op); unary_op<bfloat16_t>(a, out, op);
@@ -155,10 +186,84 @@ void unary_fp(const array& a, array& out, Op op) {
err << "[unary_fp] Does not support " << out.dtype(); err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
});
} }
template <typename Op> template <typename Op>
void unary_int(const array& a, array& out, Op op) { void unary_signed(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
});
}
// Dispatch a unary op defined only for complex64 input and output.
// No dtype switch is needed; the caller guarantees complex64.
template <typename Op>
void unary_complex(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  // Detached, non-owning copies (see array::unsafe_weak_copy).
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    out = array::unsafe_weak_copy(out),
                    op = op]() mutable { unary_op<complex64_t>(a, out, op); });
}
// Dispatch a unary op that maps a complex64 input to a float output
// (e.g. real/imag/abs-style ops); `out` is expected to be float32.
template <typename Op>
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
  set_unary_output_data(a, out);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_output_array(out);
  // Detached, non-owning copies (see array::unsafe_weak_copy).
  encoder.dispatch(
      [a = array::unsafe_weak_copy(a),
       out = array::unsafe_weak_copy(out),
       op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
}
// NOTE(review): side-by-side diff residue below — the interior of the dtype
// switch is elided at the "@@ -189,6 +294,7 @@" hunk marker, so the complete
// post-commit body of `unary_int` is not reconstructible from this view; the
// text is left untouched.
template <typename Op>
void unary_int(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) { switch (out.dtype()) {
case uint8: case uint8:
unary_op<uint8_t>(a, out, op); unary_op<uint8_t>(a, out, op);
@@ -189,6 +294,7 @@ void unary_int(const array& a, array& out, Op op) {
err << "[unary_int] Does not support " << out.dtype(); err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
});
} }
} // namespace mlx::core } // namespace mlx::core