reduce binary size (#1952)

This commit is contained in:
Awni Hannun 2025-03-11 06:30:44 -07:00 committed by GitHub
parent 117e1355a2
commit 736a340478
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 2145 additions and 2386 deletions

View File

@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
return outputs;
}
// Build a non-owning alias of `other`: the result shares the same buffer,
// strides, flags and data pointer, but installs a no-op deleter, so it never
// frees the storage and must not outlive `other`. Per the declaration's
// docs, the copy is detached from the graph (no inputs/siblings/primitive).
array array::unsafe_weak_copy(const array& other) {
// Construct a shell array with matching shape/dtype but no primitive/inputs.
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
cpy.set_data(
other.buffer(),
other.data_size(),
other.strides(),
other.flags(),
[](auto) {}); // no-op deleter: this copy does not own the buffer
// set_data does not carry over the raw data pointer (e.g. a view's offset
// into the buffer), so mirror it explicitly from the source descriptor.
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
Shape{static_cast<ShapeElem>(data.size())},

View File

@ -199,6 +199,13 @@ class array {
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
/**
* Get a new array that refers to the same data as the input but with a
* non-owning pointer to it. Note the array is detached from the graph and has
* no inputs, siblings or primitive.
*/
static array unsafe_weak_copy(const array& other);
/** A unique identifier for an array. */
std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());

View File

@ -11,12 +11,7 @@ namespace mlx::core {
namespace {
template <typename InT, typename OpT>
void arg_reduce(
const array& in,
array& out,
const OpT& op,
int axis,
Stream stream) {
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
@ -26,18 +21,7 @@ void arg_reduce(
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in_ptr,
out_ptr,
axis_size,
axis_stride,
op = std::move(op),
shape = std::move(shape),
strides = std::move(strides),
size = out.size()]() {
for (uint32_t i = 0; i < size; ++i) {
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
@ -47,7 +31,6 @@ void arg_reduce(
}
out_ptr[i] = ind_v;
}
});
}
template <typename InT>
@ -55,8 +38,7 @@ void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
int axis,
Stream stream) {
int axis) {
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
@ -65,7 +47,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis, stream);
arg_reduce<InT>(in, out, op, axis);
break;
}
case ArgReduce::ArgMax: {
@ -75,7 +57,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis, stream);
arg_reduce<InT>(in, out, op, axis);
break;
}
}
@ -87,51 +69,58 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
});
}
} // namespace mlx::core

View File

@ -8,6 +8,7 @@
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@ -16,51 +17,218 @@ namespace mlx::core {
namespace {
template <typename Op>
void comparison_op(const array& a, const array& b, array& out) {
switch (a.dtype()) {
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out);
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out);
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out);
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out);
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out);
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out);
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out);
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out);
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out);
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out);
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out);
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out);
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out);
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out);
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
// Elementwise comparison dispatcher: evaluates Op on (a, b) producing a
// bool-typed `out` (note the `bool` output template argument on every
// binary_op instantiation). `op` is passed only so Op can be deduced at the
// call site; the body never invokes it directly. The actual work is handed
// to the CPU command encoder for `stream`; the lambda captures weak
// (non-owning) copies of the arrays so no graph references are retained.
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
// Decide scalar/vector/general strategy and allocate/alias out accordingly
// BEFORE dispatching, so the weak copies below see the final data layout.
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
// Dispatch on the INPUT dtype (comparisons accept any dtype); the output
// type is fixed to bool in each instantiation.
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
// Floating-point-only binary dispatcher: evaluates Op on (a, b) for
// float16/float32/float64/bfloat16 outputs and throws for anything else
// (including complex64). `op` exists solely for template-argument deduction.
// Work is dispatched through the CPU command encoder for `stream` with weak
// (non-owning) array copies captured in the lambda.
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
// Choose the op strategy and set up the output's storage before capturing
// weak copies, so the dispatched lambda sees the final layout.
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types.");
}
});
}
// Integer/bool-only binary dispatcher: evaluates Op on (a, b) for bool and
// all signed/unsigned integer output dtypes, throwing for anything else.
// `op` exists solely so Op can be deduced at the call site; the body never
// invokes it. The computation runs via the CPU command encoder for `stream`,
// with non-owning weak copies of the arrays captured by the lambda.
//
// Fix: the original `case bool_` was missing its `break;`, so a bool-typed
// output computed the result and then fell through into the uint8 case,
// recomputing (and reinterpreting) the data as uint8_t.
template <typename Op>
void binary_int(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Stream stream) {
  // Pick the scalar/vector/general strategy and set up `out`'s storage
  // before the weak copies are captured, so the lambda sees the final layout.
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out = array::unsafe_weak_copy(out),
                    bopt]() mutable {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool, Op>(a, b, out, bopt);
        break; // was missing: previously fell through to uint8
      case uint8:
        binary_op<uint8_t, Op>(a, b, out, bopt);
        break;
      case uint16:
        binary_op<uint16_t, Op>(a, b, out, bopt);
        break;
      case uint32:
        binary_op<uint32_t, Op>(a, b, out, bopt);
        break;
      case uint64:
        binary_op<uint64_t, Op>(a, b, out, bopt);
        break;
      case int8:
        binary_op<int8_t, Op>(a, b, out, bopt);
        break;
      case int16:
        binary_op<int16_t, Op>(a, b, out, bopt);
        break;
      case int32:
        binary_op<int32_t, Op>(a, b, out, bopt);
        break;
      case int64:
        binary_op<int64_t, Op>(a, b, out, bopt);
        break;
      default:
        throw std::runtime_error("[binary_int] Type not supported");
        break;
    }
  });
}
} // namespace
@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
binary(a, b, out, detail::Add(), stream());
}
void DivMod::eval_cpu(
@ -78,70 +246,89 @@ void DivMod::eval_cpu(
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out_a = array::unsafe_weak_copy(out_a),
out_b = array::unsafe_weak_copy(out_b),
bopt]() mutable {
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
switch (out_a.dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
});
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
binary(a, b, out, detail::Divide(), stream());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
binary(a, b, out, detail::Remainder(), stream());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a = inputs[0];
auto& b = inputs[1];
if (equal_nan_) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, detail::NaNEqual>(a, b, out);
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, detail::NaNEqual>(a, b, out);
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
});
} else {
comparison_op<detail::Equal>(a, b, out);
comparison_op(a, b, out, detail::Equal(), stream());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op<detail::Greater>(inputs[0], inputs[1], out);
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op<detail::Less>(inputs[0], inputs[1], out);
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t, detail::LogAddExp>(a, b, out);
break;
case float32:
binary_op<float, detail::LogAddExp>(a, b, out);
break;
case float64:
binary_op<double, detail::LogAddExp>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
binary_float(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
binary(in1, in2, out, detail::LogicalAnd(), stream());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
binary(in1, in2, out, detail::LogicalOr(), stream());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
binary(a, b, out, detail::Maximum(), stream());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
binary(a, b, out, detail::Minimum(), stream());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
binary(a, b, out, detail::Multiply(), stream());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
binary(a, b, out, detail::Power(), stream());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
binary(a, b, out, detail::Subtract(), stream());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
binary_int(a, b, out, detail::BitwiseAnd(), stream());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
binary_int(a, b, out, detail::BitwiseOr(), stream());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
binary_int(a, b, out, detail::BitwiseXor(), stream());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
binary_int(a, b, out, detail::LeftShift(), stream());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
binary_int(a, b, out, detail::RightShift(), stream());
break;
}
}
@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
binary_float(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core

View File

@ -3,12 +3,9 @@
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/backend/cpu/simd/simd.h"
@ -152,30 +149,12 @@ void binary_op_dispatch_dims(
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
// The full computation is scalar scalar so call the base op once
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_ptr = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([bopt,
a_ptr,
b_ptr,
out_ptr,
a_data_size = a.data_size(),
b_data_size = b.data_size(),
size = a.size(),
shape = a.shape(),
a_strides = a.strides(),
b_strides = b.strides(),
strides = out.strides()]() mutable {
if (bopt == BinaryOpType::ScalarScalar) {
*out_ptr = Op{}(*a_ptr, *b_ptr);
return;
@ -183,29 +162,28 @@ void binary_op(const array& a, const array& b, array& out) {
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
shape,
{std::move(a_strides), std::move(b_strides), std::move(strides)});
a_strides = new_strides[0];
b_strides = new_strides[1];
strides = new_strides[2];
a.shape(), {a.strides(), b.strides(), out.strides()});
auto& a_strides = new_strides[0];
auto& b_strides = new_strides[1];
auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
@ -262,7 +240,7 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr,
out_ptr,
dim,
size,
a.size(),
new_shape,
a_strides,
b_strides,
@ -274,7 +252,7 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr,
out_ptr,
dim,
size,
a.size(),
new_shape,
a_strides,
b_strides,
@ -286,7 +264,7 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr,
out_ptr,
dim,
size,
a.size(),
new_shape,
a_strides,
b_strides,
@ -298,72 +276,18 @@ void binary_op(const array& a, const array& b, array& out) {
b_ptr,
out_ptr,
dim,
size,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
}
});
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out) {
binary_op<T, T, Op>(a, b, out);
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T, Op>(a, b, out);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out);
break;
case int8:
binary_op<int8_t, Op>(a, b, out);
break;
case int16:
binary_op<int16_t, Op>(a, b, out);
break;
case int32:
binary_op<int32_t, Op>(a, b, out);
break;
case int64:
binary_op<int64_t, Op>(a, b, out);
break;
case float16:
binary_op<float16_t, Op>(a, b, out);
break;
case float32:
binary_op<float, Op>(a, b, out);
break;
case float64:
binary_op<double, Op>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out);
break;
}
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
}
} // namespace mlx::core

View File

@ -4,8 +4,6 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@ -57,14 +55,7 @@ void binary_op_dispatch_dims(
const array& b,
array& out_a,
array& out_b,
Stream stream,
Op op) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const T* a_ptr = a.data<T>();
@ -72,14 +63,6 @@ void binary_op_dispatch_dims(
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
shape = std::move(shape),
strides = std::move(strides),
op = std::move(op)]() {
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
@ -116,7 +99,7 @@ void binary_op_dispatch_dims(
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < size; elem += stride) {
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@ -131,138 +114,50 @@ void binary_op_dispatch_dims(
a_it.step();
b_it.step();
}
});
}
template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto stream = out_a.primitive().stream();
array& out_a,
array& out_b,
Op op,
BinaryOpType bopt) {
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op);
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
return;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
encoder.dispatch(
[a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
});
} else if (bopt == BinaryOpType::ScalarVector) {
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = b.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
});
} else if (bopt == BinaryOpType::VectorScalar) {
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
});
} else { // VectorVector
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
});
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}

View File

@ -13,29 +13,20 @@ namespace mlx::core {
namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst, Stream stream) {
void copy_single(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
auto size = dst.size();
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
});
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst, Stream stream) {
void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
auto size = src.data_size();
std::copy(src_ptr, src_ptr + size, dst_ptr);
});
}
template <typename SrcT, typename DstT, int D>
@ -66,7 +57,6 @@ template <typename SrcT, typename DstT>
void copy_general_general(
const array& src,
array& dst,
Stream stream,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
@ -80,18 +70,7 @@ void copy_general_general(
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr,
dst_ptr,
size = src.size(),
data_shape = data_shape,
i_strides = i_strides,
o_strides = o_strides,
i_offset_ptr,
o_offset_ptr]() mutable {
auto size = src.size();
if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
@ -143,15 +122,13 @@ void copy_general_general(
in.step();
out.step();
}
});
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst, Stream stream) {
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
src.shape(),
src.strides(),
dst.strides(),
@ -165,7 +142,6 @@ template <typename SrcT, typename DstT>
void copy_general(
const array& src,
array& dst,
Stream stream,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
@ -176,7 +152,6 @@ void copy_general(
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
data_shape,
i_strides,
make_contiguous_strides(data_shape),
@ -187,11 +162,10 @@ void copy_general(
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst, Stream stream) {
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
src.shape(),
src.strides(),
make_contiguous_strides(src.shape()),
@ -202,84 +176,67 @@ inline void copy_general(const array& src, array& dst, Stream stream) {
}
template <typename SrcT, typename DstT, typename... Args>
void copy(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst, stream);
copy_single<SrcT, DstT>(src, dst);
return;
case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst, stream);
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(
src, dst, stream, std::forward<Args>(args)...);
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
}
}
template <typename SrcT, typename... Args>
void copy(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<SrcT, uint16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<SrcT, uint32_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<SrcT, uint64_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<SrcT, float16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<SrcT, complex64_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
@ -289,50 +246,49 @@ inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
@ -340,7 +296,13 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
copy_inplace_dispatch(src, dst, ctype, stream);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
@ -368,6 +330,27 @@ void copy_inplace(
Stream stream,
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
if (x) {
return array::unsafe_weak_copy(*x);
} else {
return std::nullopt;
}
};
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
ctype,
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
@ -375,7 +358,6 @@ void copy_inplace(
src,
dst,
ctype,
stream,
data_shape,
i_strides,
o_strides,
@ -386,8 +368,9 @@ void copy_inplace(
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype, stream);
copy_inplace_dispatch(src, dst, ctype);
}
});
}
} // namespace mlx::core

View File

@ -22,14 +22,47 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
struct None {
template <typename T>
void operator()(T x, T* y) {
(*y) = x;
}
};
struct Sum {
template <typename T>
void operator()(T x, T* y) {
(*y) += x;
}
};
struct Prod {
template <typename T>
void operator()(T x, T* y) {
(*y) *= x;
}
};
struct Max {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y > x) ? *y : x;
}
};
struct Min {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y < x) ? *y : x;
}
};
template <typename T, typename IdxT>
void gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& slice_sizes,
Stream stream) {
const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
@ -82,43 +115,23 @@ void gather(
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
std::vector<const IdxT*> ind_ptrs;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
for (auto& idx : inds) {
ind_ptrs.push_back(idx.data<IdxT>());
encoder.set_input_array(idx);
}
encoder.set_output_array(out);
encoder.dispatch([src_ptr,
dst_ptr,
ind_ptrs = std::move(ind_ptrs),
axes,
ind_size,
slice_size,
src_shape = src.shape(),
src_strides = src.strides(),
src_it = std::move(src_it),
its = std::move(its),
can_copy]() mutable {
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < ind_ptrs.size(); ++ii) {
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]);
src_idx += (idx_val * src_strides[ax]);
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx,
src_ptr + src_idx + slice_size,
dst_ptr + out_idx);
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
@ -128,7 +141,6 @@ void gather(
src_it.reset();
}
}
});
}
template <typename IdxT>
@ -137,50 +149,49 @@ void dispatch_gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& size,
Stream stream) {
const Shape& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size, stream);
gather<bool, IdxT>(src, inds, out, axes, size);
break;
case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size, stream);
gather<uint8_t, IdxT>(src, inds, out, axes, size);
break;
case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size, stream);
gather<uint16_t, IdxT>(src, inds, out, axes, size);
break;
case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size, stream);
gather<uint32_t, IdxT>(src, inds, out, axes, size);
break;
case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size, stream);
gather<uint64_t, IdxT>(src, inds, out, axes, size);
break;
case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size, stream);
gather<int8_t, IdxT>(src, inds, out, axes, size);
break;
case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size, stream);
gather<int16_t, IdxT>(src, inds, out, axes, size);
break;
case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size, stream);
gather<int32_t, IdxT>(src, inds, out, axes, size);
break;
case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size, stream);
gather<int64_t, IdxT>(src, inds, out, axes, size);
break;
case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size, stream);
gather<float16_t, IdxT>(src, inds, out, axes, size);
break;
case float32:
gather<float, IdxT>(src, inds, out, axes, size, stream);
gather<float, IdxT>(src, inds, out, axes, size);
break;
case float64:
gather<double, IdxT>(src, inds, out, axes, size, stream);
gather<double, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size, stream);
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size, stream);
gather<complex64_t, IdxT>(src, inds, out, axes, size);
break;
}
}
@ -189,51 +200,63 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
inds.push_back(array::unsafe_weak_copy(*it));
}
auto& encoder = cpu::get_command_encoder(stream());
for (auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
slice_sizes_ = slice_sizes_,
src = array::unsafe_weak_copy(src),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_, stream());
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename T, typename IdxT>
void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis,
Stream stream) {
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
@ -262,23 +285,6 @@ void gather_axis(
size_post *= ind.shape(i);
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_input_array(ind);
encoder.set_output_array(out);
encoder.dispatch([ind_ptr,
src_ptr,
dst_ptr,
size_pre,
size_post,
ind_ax_size,
src_ax_size,
ind_ax_stride,
src_ax_stride,
dst_ax_stride,
ind_it = std::move(ind_it),
src_it = std::move(src_it)]() mutable {
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
@ -293,7 +299,6 @@ void gather_axis(
}
dst_ptr += stride_pre;
}
});
}
template <typename IdxT>
@ -301,88 +306,97 @@ void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis,
Stream stream) {
const int axis) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis, stream);
gather_axis<bool, IdxT>(src, inds, out, axis);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis, stream);
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis, stream);
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis, stream);
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis, stream);
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis, stream);
gather_axis<int8_t, IdxT>(src, inds, out, axis);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis, stream);
gather_axis<int16_t, IdxT>(src, inds, out, axis);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis, stream);
gather_axis<int32_t, IdxT>(src, inds, out, axis);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis, stream);
gather_axis<int64_t, IdxT>(src, inds, out, axis);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis, stream);
gather_axis<float16_t, IdxT>(src, inds, out, axis);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis, stream);
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case float64:
gather_axis<double, IdxT>(src, inds, out, axis, stream);
gather_axis<double, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis, stream);
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis, stream);
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break;
}
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(src);
encoder.set_input_array(inds);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
src = array::unsafe_weak_copy(src),
inds = array::unsafe_weak_copy(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_, stream());
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename InT, typename IdxT, typename OpT>
@ -390,9 +404,7 @@ void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op,
Stream stream) {
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
@ -408,45 +420,27 @@ void scatter(
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
std::vector<const IdxT*> ind_ptrs;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(updates);
for (auto& idx : inds) {
ind_ptrs.push_back(idx.data<IdxT>());
encoder.set_input_array(idx);
}
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<InT>(),
upd_ptr = updates.data<InT>(),
ind_ptrs = std::move(ind_ptrs),
axes,
n_updates,
update_size,
op = std::move(op),
out_shape = out.shape(),
out_strides = out.strides(),
out_it = std::move(out_it),
update_it = std::move(update_it),
its = std::move(its)]() mutable {
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < ind_ptrs.size(); ++j) {
for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]);
out_offset += (idx_val * out_strides[ax]);
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
});
}
template <typename InT, typename IdxT>
@ -455,53 +449,22 @@ void dispatch_scatter_inds(
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype,
Stream stream) {
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = x; },
stream);
scatter<InT, IdxT, None>(updates, out, indices, axes);
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) += x; },
stream);
scatter<InT, IdxT, Sum>(updates, out, indices, axes);
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) *= x; },
stream);
scatter<InT, IdxT, Prod>(updates, out, indices, axes);
break;
case Scatter::Max:
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = (*y > x) ? *y : x; },
stream);
scatter<InT, IdxT, Max>(updates, out, indices, axes);
break;
case Scatter::Min:
scatter<InT, IdxT>(
updates,
out,
indices,
axes,
[](auto x, auto* y) { (*y) = (*y < x) ? *y : x; },
stream);
scatter<InT, IdxT, Min>(updates, out, indices, axes);
break;
}
}
@ -512,46 +475,36 @@ void dispatch_scatter(
const std::vector<array>& inds,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype,
Stream stream) {
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, uint8_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_scatter_inds<InT, uint8_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
case uint16:
dispatch_scatter_inds<InT, uint16_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
break;
case uint32:
dispatch_scatter_inds<InT, uint32_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
break;
case uint64:
dispatch_scatter_inds<InT, uint64_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
break;
case int8:
dispatch_scatter_inds<InT, int8_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
break;
case int16:
dispatch_scatter_inds<InT, int16_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
break;
case int32:
dispatch_scatter_inds<InT, int32_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
break;
case int64:
dispatch_scatter_inds<InT, int64_t>(
out, inds, updates, axes, rtype, stream);
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
default:
throw std::runtime_error(
@ -563,7 +516,6 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
@ -571,73 +523,68 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream());
switch (src.dtype()) {
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
encoder.set_input_array(*it);
inds.push_back(array::unsafe_weak_copy(*it));
}
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
reduce_type_ = reduce_type_,
updates = array::unsafe_weak_copy(updates),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(
out, inds, updates, axes_, reduce_type_, stream());
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
});
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op,
Stream stream) {
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
@ -657,11 +604,6 @@ void scatter_axis(
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(idx);
encoder.set_input_array(upd);
encoder.set_output_array(out);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
@ -670,26 +612,14 @@ void scatter_axis(
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
encoder.dispatch([idx_ptr,
upd_ptr,
dst_ptr,
size_pre,
size_post,
idx_ax_size,
dst_ax_size,
idx_ax_stride,
upd_ax_stride,
dst_ax_stride,
idx_it = std::move(idx_it),
upd_it = std::move(upd_it),
op = std::move(op)]() mutable {
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
OpT{}(
upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
@ -697,7 +627,6 @@ void scatter_axis(
}
dst_ptr += stride_pre;
}
});
}
template <typename InT, typename IdxT>
@ -706,16 +635,13 @@ void dispatch_scatter_axis_op(
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype,
Stream stream) {
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream);
scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream);
scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
break;
}
}
@ -726,40 +652,31 @@ void dispatch_scatter_axis(
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype,
Stream stream) {
ScatterAxis::ReduceType rtype) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(
out, idx, updates, axis, rtype, stream);
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
break;
default:
throw std::runtime_error(
@ -779,64 +696,63 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream());
switch (src.dtype()) {
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
reduce_type_ = reduce_type_,
idx = array::unsafe_weak_copy(idx),
updates = array::unsafe_weak_copy(updates),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_, stream());
out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(
out, idx, updates, axis_, reduce_type_, stream());
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_, stream());
out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_, stream());
out, idx, updates, axis_, reduce_type_);
break;
}
});
}
} // namespace mlx::core

View File

@ -326,8 +326,7 @@ void _qmm_dispatch_typed(
const array& biases,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
@ -335,48 +334,18 @@ void _qmm_dispatch_typed(
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
w_els,
g_els,
batch_size,
M,
N,
K,
bits,
group_size,
transposed_w] {
for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
M,
N,
K,
@ -384,7 +353,6 @@ void _qmm_dispatch_typed(
group_size,
transposed_w);
}
});
}
void _qmm_dispatch(
@ -395,20 +363,19 @@ void _qmm_dispatch(
const array& biases,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
default:
throw std::invalid_argument(
@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed(
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
@ -453,45 +410,19 @@ void _bs_qmm_dispatch_typed(
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
lhs_indices_ptr,
rhs_indices_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
w_els,
g_els,
indices_size = lhs_indices.size(),
M,
N,
K,
bits,
group_size,
transposed_w]() {
for (int i = 0; i < indices_size; i++) {
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];
i, rhs_indices.shape(), rhs_indices.strides())];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
biases_ptr +
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
M,
N,
K,
@ -499,7 +430,6 @@ void _bs_qmm_dispatch_typed(
group_size,
transposed_w);
}
});
}
void _bs_qmm_dispatch(
@ -512,8 +442,7 @@ void _bs_qmm_dispatch(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
switch (x.dtype()) {
case float32:
_bs_qmm_dispatch_typed<float>(
@ -526,8 +455,7 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w,
stream);
transposed_w);
break;
case float16:
_bs_qmm_dispatch_typed<float16_t>(
@ -540,8 +468,7 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w,
stream);
transposed_w);
break;
case bfloat16:
_bs_qmm_dispatch_typed<bfloat16_t>(
@ -554,8 +481,7 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w,
stream);
transposed_w);
break;
default:
throw std::invalid_argument(
@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -626,6 +566,26 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
@ -636,10 +596,8 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
rhs_indices,
group_size_,
bits_,
transpose_,
stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
transpose_);
});
}
template <typename T, typename U>
@ -709,27 +667,13 @@ void dispatch_quantize(
array& scales,
array& biases,
int bits,
int group_size,
Stream stream) {
int group_size) {
auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w_ptr,
out_ptr,
scales_ptr,
biases_ptr,
bits,
group_size,
w_size = w.size()]() {
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
});
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu(
@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu(
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
encoder.add_temporary(w);
}
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w = array::unsafe_weak_copy(w),
out = array::unsafe_weak_copy(out),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_]() mutable {
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(w);
}
});
}
} // namespace mlx::core

View File

@ -140,34 +140,23 @@ void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Stream stream) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U init) {
ReductionPlan plan = get_reduction_plan(x, axes);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_output_array(out);
auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) {
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
});
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
encoder.dispatch(
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
});
return;
}
@ -178,24 +167,14 @@ void reduction_op(
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i++, out_ptr++) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
}
} else {
for (int i = 0; i < size; i++, out_ptr++) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
@ -211,7 +190,6 @@ void reduction_op(
plan.strides);
}
}
});
return;
}
@ -220,20 +198,12 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size()]() mutable {
for (int i = 0; i < size; i += reduction_stride) {
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
});
return;
}
@ -245,17 +215,8 @@ void reduction_op(
plan.strides.pop_back();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i += reduction_stride) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
@ -263,7 +224,7 @@ void reduction_op(
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < size; i += reduction_stride) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
@ -280,21 +241,13 @@ void reduction_op(
out_ptr += reduction_stride;
}
}
});
return;
}
if (plan.type == GeneralReduce) {
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
for (int i = 0; i < size; i++, out_ptr++) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
@ -305,7 +258,6 @@ void reduction_op(
plan.strides);
*out_ptr = val;
}
});
}
}
@ -434,12 +386,11 @@ void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes,
Stream stream) {
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
} else {
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
}
}
@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes,
Stream stream) {
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
} else {
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
} else {
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
}
}
}
@ -470,20 +420,27 @@ void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes,
Stream stream) {
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
}
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axes_ = axes_]() mutable {
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
@ -491,28 +448,24 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
@ -523,43 +476,34 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
@ -568,64 +512,52 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(
in, out, reduce_type_, axes_, stream());
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
}
});
}
} // namespace mlx::core

View File

@ -160,38 +160,29 @@ void scan_op(
bool reverse,
bool inclusive,
const Op& op,
U init,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
U init) {
if (in.flags().row_contiguous) {
if (in.strides()[axis] == 1) {
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis),
stride = in.shape(axis),
reverse,
inclusive,
op = std::move(op),
init]() {
contiguous_scan(
in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
});
} else {
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis) / in.strides()[axis],
size = in.shape(axis),
stride = in.strides()[axis],
in.data<T>(),
out.data<U>(),
in.size() / in.shape(axis),
in.shape(axis),
reverse,
inclusive,
op = std::move(op),
init]() {
op,
init);
} else {
strided_scan(
in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
});
in.data<T>(),
out.data<U>(),
in.size() / in.shape(axis) / in.strides()[axis],
in.shape(axis),
in.strides()[axis],
reverse,
inclusive,
op,
init);
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
@ -205,19 +196,18 @@ void scan_dispatch(
array& out,
int axis,
bool reverse,
bool inclusive,
Stream stream) {
bool inclusive) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::Prod: {
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::Min: {
@ -225,7 +215,7 @@ void scan_dispatch(
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::Max: {
@ -233,7 +223,7 @@ void scan_dispatch(
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
@ -244,17 +234,26 @@ void scan_dispatch(
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& encoder = cpu::get_command_encoder(stream());
// Ensure contiguity
auto in = inputs[0];
bool copied = false;
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
copied = true;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
reduce_type_ = reduce_type_,
reverse_ = reverse_,
inclusive_ = inclusive_]() mutable {
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
@ -264,68 +263,66 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in));
}
});
}
} // namespace mlx::core

View File

@ -16,51 +16,70 @@ void select_op(
const array& b,
const array& c,
array& out,
Op op) {
Op op,
Stream stream) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
c = array::unsafe_weak_copy(c),
out = array::unsafe_weak_copy(out),
op,
topt]() mutable {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
ternary_op<bool, float16_t, float16_t, float16_t>(
a, b, c, out, op, topt);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op);
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
a, b, c, out, op, topt);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
a, b, c, out, op, topt);
break;
}
});
}
} // namespace
@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
select_op(condition, a, b, out, detail::Select(), stream());
}
} // namespace mlx::core

View File

@ -105,15 +105,11 @@ struct StridedIterator {
};
template <typename T>
void sort(const array& in, array& out, int axis, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream);
void sort(array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@ -127,13 +123,7 @@ void sort(const array& in, array& out, int axis, Stream stream) {
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride]() mutable {
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
@ -143,14 +133,10 @@ void sort(const array& in, array& out, int axis, Stream stream) {
std::stable_sort(st, ed);
src_it.step();
}
});
}
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis, Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
void argsort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@ -176,17 +162,8 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride]() mutable {
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
@ -210,42 +187,30 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
template <typename T>
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream);
void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = in.shape();
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
auto axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride,
kth]() mutable {
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
@ -256,19 +221,10 @@ void partition(const array& in, array& out, int axis, int kth, Stream stream) {
std::nth_element(st, md, ed);
}
});
}
template <typename T, typename IdxT = uint32_t>
void argpartition(
const array& in,
array& out,
int axis,
int kth,
Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
void argpartition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@ -297,18 +253,9 @@ void argpartition(
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride,
kth]() mutable {
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
@ -332,7 +279,6 @@ void argpartition(
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
} // namespace
@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_, stream());
return argsort<bool>(in, out, axis_);
case uint8:
return argsort<uint8_t>(in, out, axis_, stream());
return argsort<uint8_t>(in, out, axis_);
case uint16:
return argsort<uint16_t>(in, out, axis_, stream());
return argsort<uint16_t>(in, out, axis_);
case uint32:
return argsort<uint32_t>(in, out, axis_, stream());
return argsort<uint32_t>(in, out, axis_);
case uint64:
return argsort<uint64_t>(in, out, axis_, stream());
return argsort<uint64_t>(in, out, axis_);
case int8:
return argsort<int8_t>(in, out, axis_, stream());
return argsort<int8_t>(in, out, axis_);
case int16:
return argsort<int16_t>(in, out, axis_, stream());
return argsort<int16_t>(in, out, axis_);
case int32:
return argsort<int32_t>(in, out, axis_, stream());
return argsort<int32_t>(in, out, axis_);
case int64:
return argsort<int64_t>(in, out, axis_, stream());
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_, stream());
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_, stream());
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_, stream());
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_, stream());
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_, stream());
return argsort<complex64_t>(in, out, axis_);
}
});
}
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(in, out, axis_, stream());
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(in, out, axis_, stream());
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(in, out, axis_, stream());
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(in, out, axis_, stream());
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(in, out, axis_, stream());
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(in, out, axis_, stream());
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(in, out, axis_, stream());
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(in, out, axis_, stream());
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(in, out, axis_, stream());
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(in, out, axis_, stream());
return sort<float>(out, axis_);
case float64:
return sort<double>(in, out, axis_, stream());
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(in, out, axis_, stream());
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(in, out, axis_, stream());
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(in, out, axis_, stream());
return sort<complex64_t>(out, axis_);
}
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_, stream());
return argpartition<bool>(in, out, axis_, kth_);
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
return argpartition<uint64_t>(in, out, axis_, kth_);
case int8:
return argpartition<int8_t>(in, out, axis_, kth_, stream());
return argpartition<int8_t>(in, out, axis_, kth_);
case int16:
return argpartition<int16_t>(in, out, axis_, kth_, stream());
return argpartition<int16_t>(in, out, axis_, kth_);
case int32:
return argpartition<int32_t>(in, out, axis_, kth_, stream());
return argpartition<int32_t>(in, out, axis_, kth_);
case int64:
return argpartition<int64_t>(in, out, axis_, kth_, stream());
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_, stream());
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_, stream());
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_, stream());
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
return argpartition<complex64_t>(in, out, axis_, kth_);
}
});
}
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (out.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_, stream());
return partition<bool>(out, axis_, kth_);
case uint8:
return partition<uint8_t>(in, out, axis_, kth_, stream());
return partition<uint8_t>(out, axis_, kth_);
case uint16:
return partition<uint16_t>(in, out, axis_, kth_, stream());
return partition<uint16_t>(out, axis_, kth_);
case uint32:
return partition<uint32_t>(in, out, axis_, kth_, stream());
return partition<uint32_t>(out, axis_, kth_);
case uint64:
return partition<uint64_t>(in, out, axis_, kth_, stream());
return partition<uint64_t>(out, axis_, kth_);
case int8:
return partition<int8_t>(in, out, axis_, kth_, stream());
return partition<int8_t>(out, axis_, kth_);
case int16:
return partition<int16_t>(in, out, axis_, kth_, stream());
return partition<int16_t>(out, axis_, kth_);
case int32:
return partition<int32_t>(in, out, axis_, kth_, stream());
return partition<int32_t>(out, axis_, kth_);
case int64:
return partition<int64_t>(in, out, axis_, kth_, stream());
return partition<int64_t>(out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_, stream());
return partition<float>(out, axis_, kth_);
case float64:
return partition<double>(in, out, axis_, kth_, stream());
return partition<double>(out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_, stream());
return partition<float16_t>(out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
return partition<bfloat16_t>(out, axis_, kth_);
case complex64:
return partition<complex64_t>(in, out, axis_, kth_, stream());
return partition<complex64_t>(out, axis_, kth_);
}
});
}
} // namespace mlx::core

View File

@ -1,12 +1,10 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@ -128,57 +126,28 @@ void ternary_op(
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
Op op,
TernaryOpType topt) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
if (topt == TernaryOpType::ScalarScalarScalar) {
encoder.dispatch(
[a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
});
} else if (topt == TernaryOpType::VectorVectorVector) {
encoder.dispatch([a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size()]() mutable {
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
});
} else {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
encoder.dispatch(
[a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size(),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
});
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
}
}

View File

@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
auto op = detail::Abs{};
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
unary_signed(in, out, detail::Abs(), stream());
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos());
unary_fp(in, out, detail::ArcCos(), stream());
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh());
unary_fp(in, out, detail::ArcCosh(), stream());
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin());
unary_fp(in, out, detail::ArcSin(), stream());
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh());
unary_fp(in, out, detail::ArcSinh(), stream());
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan());
unary_fp(in, out, detail::ArcTan(), stream());
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh());
unary_fp(in, out, detail::ArcTanh(), stream());
}
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert());
unary_int(in, out, detail::BitwiseInvert(), stream());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
unary_fp(in, out, detail::Ceil(), stream());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
unary_complex(inputs[0], out, detail::Conjugate(), stream());
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cos());
unary_fp(in, out, detail::Cos(), stream());
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh());
unary_fp(in, out, detail::Cosh(), stream());
}
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
unary_real_fp(in, out, detail::Erf(), stream());
}
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
unary_real_fp(in, out, detail::ErfInv(), stream());
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Exp());
unary_fp(in, out, detail::Exp(), stream());
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1());
unary_fp(in, out, detail::Expm1(), stream());
}
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
unary_fp(in, out, detail::Floor(), stream());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
unary_fp(in, out, detail::Log(), stream());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
unary_fp(in, out, detail::Log2(), stream());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
unary_fp(in, out, detail::Log10(), stream());
break;
}
}
@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p());
unary_fp(in, out, detail::Log1p(), stream());
}
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
unary(in, out, detail::LogicalNot(), stream());
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
unary(in, out, detail::Negative(), stream());
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
unary_complex_to_float(inputs[0], out, detail::Real(), stream());
}
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
unary_fp(in, out, detail::Round(), stream());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid());
unary_fp(in, out, detail::Sigmoid(), stream());
}
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
unary(in, out, detail::Sign(), stream());
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sin());
unary_fp(in, out, detail::Sin(), stream());
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh());
unary_fp(in, out, detail::Sinh(), stream());
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
unary(in, out, detail::Square(), stream());
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
unary_fp(in, out, detail::Rsqrt(), stream());
} else {
unary_fp(in, out, detail::Sqrt());
unary_fp(in, out, detail::Sqrt(), stream());
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tan());
unary_fp(in, out, detail::Tan(), stream());
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh());
unary_fp(in, out, detail::Tanh(), stream());
}
} // namespace mlx::core

View File

@ -7,7 +7,6 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
@ -39,53 +38,48 @@ void unary_op(const T* a, U* out, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op) {
set_unary_output_data(a, out);
const T* src = a.data<T>();
U* dst = out.data<U>();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([src,
dst,
contig = a.flags().contiguous,
data_size = a.data_size(),
size = a.size(),
shapes = a.shape(),
strides = a.strides()]() mutable {
auto ndim = shapes.size();
if (contig) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
while (data_size >= N) {
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
data_size -= N;
size -= N;
src += N;
dst += N;
}
while (data_size > 0) {
while (size > 0) {
*dst = Op{}(*src);
data_size--;
size--;
dst++;
src++;
}
} else {
size_t shape = ndim > 0 ? shapes.back() : 1;
size_t stride = ndim > 0 ? strides.back() : 1;
size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(shapes, strides, ndim - 1);
for (size_t elem = 0; elem < size; elem += shape) {
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}
}
});
}
template <typename Op>
void unary(const array& a, array& out, Op op) {
void unary(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
@ -130,10 +124,47 @@ void unary(const array& a, array& out, Op op) {
unary_op<complex64_t>(a, out, op);
break;
}
});
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op) {
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_real] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
@ -155,10 +186,84 @@ void unary_fp(const array& a, array& out, Op op) {
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_int(const array& a, array& out, Op op) {
void unary_signed(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
});
}
template <typename Op>
void unary_complex(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
}
template <typename Op>
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch(
[a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
}
template <typename Op>
void unary_int(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
@ -189,6 +294,7 @@ void unary_int(const array& a, array& out, Op op) {
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
} // namespace mlx::core