reduce binary size (#1952)

Awni Hannun 2025-03-11 06:30:44 -07:00 committed by GitHub
parent 117e1355a2
commit 736a340478
16 changed files with 2145 additions and 2386 deletions

View File: mlx/array.cpp

@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
  return outputs;
}

array array::unsafe_weak_copy(const array& other) {
  auto cpy = array(other.shape(), other.dtype(), nullptr, {});
  cpy.set_data(
      other.buffer(),
      other.data_size(),
      other.strides(),
      other.flags(),
      [](auto) {});
  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
  return cpy;
}

array::array(std::initializer_list<float> data)
    : array_desc_(std::make_shared<ArrayDesc>(
          Shape{static_cast<ShapeElem>(data.size())},
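
unsafe_weak_copy aliases the buffer of `other` through a no-op deleter, so the copy neither retains nor frees the storage; the command encoder is what keeps the owning arrays alive until the dispatched task has run. A minimal sketch of the same aliasing idea with a plain std::shared_ptr (illustrative names, not MLX API):

#include <cstdio>
#include <memory>

int main() {
  auto owner = std::make_shared<int>(42); // owning reference keeps data alive
  // Aliasing pointer with a no-op deleter: shares the data, never frees it.
  std::shared_ptr<int> weak_alias(owner.get(), [](int*) {});
  std::printf("%d\n", *weak_alias); // valid only while `owner` is alive
  return 0;
}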

View File: mlx/array.h

@@ -199,6 +199,13 @@ class array {
      const std::shared_ptr<Primitive>& primitive,
      const std::vector<array>& inputs);

  /**
   * Get a new array that refers to the same data as the input but with a
   * non-owning pointer to it. Note the array is detached from the graph and
   * has no inputs, siblings or primitive.
   */
  static array unsafe_weak_copy(const array& other);

  /** A unique identifier for an array. */
  std::uintptr_t id() const {
    return reinterpret_cast<std::uintptr_t>(array_desc_.get());

View File: mlx/backend/cpu/arg_reduce.cpp

@@ -11,12 +11,7 @@ namespace mlx::core {

namespace {

template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto axis_size = in.shape()[axis];
  auto axis_stride = in.strides()[axis];
  Strides strides = in.strides();
@@ -26,28 +21,16 @@ void arg_reduce(
  auto in_ptr = in.data<InT>();
  auto out_ptr = out.data<uint32_t>();
  for (uint32_t i = 0; i < out.size(); ++i) {
    auto loc = elem_to_loc(i, shape, strides);
    auto local_in_ptr = in_ptr + loc;
    uint32_t ind_v = 0;
    InT v = (*local_in_ptr);
    for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
      op(j, (*local_in_ptr), &ind_v, &v);
    }
    out_ptr[i] = ind_v;
  }
}

template <typename InT>
@@ -55,8 +38,7 @@ void arg_reduce_dispatch(
    const array& in,
    array& out,
    ArgReduce::ReduceType rtype,
    int axis) {
  switch (rtype) {
    case ArgReduce::ArgMin: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
@@ -65,7 +47,7 @@ void arg_reduce_dispatch(
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
    case ArgReduce::ArgMax: {
@@ -75,7 +57,7 @@ void arg_reduce_dispatch(
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
  }
@@ -87,51 +69,58 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.dispatch([in = array::unsafe_weak_copy(in),
                    out = array::unsafe_weak_copy(out),
                    reduce_type_ = reduce_type_,
                    axis_ = axis_]() mutable {
    switch (in.dtype()) {
      case bool_:
        arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
        break;
      case uint8:
        arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
        break;
      case uint16:
        arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
        break;
      case uint32:
        arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
        break;
      case uint64:
        arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
        break;
      case int8:
        arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
        break;
      case int16:
        arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
        break;
      case int32:
        arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
        break;
      case int64:
        arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
        break;
      case float16:
        arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
        break;
      case float32:
        arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
        break;
      case bfloat16:
        arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
        break;
      case float64:
        arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
        break;
      case complex64:
        arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
        break;
    }
  });
}

} // namespace mlx::core
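
The pattern above recurs through the rest of the commit: the dtype switch moves inside a single dispatched lambda, so each op enqueues one task type instead of one capture-heavy closure per dtype, which is plausibly where the binary-size saving comes from. A toy illustration of the difference in instantiated callable types (illustrative sketch, not MLX code):

#include <functional>
#include <vector>

std::vector<std::function<void()>> tasks;

template <typename T>
void kernel() { /* per-dtype work */ }

// Before: a distinct closure type is enqueued per dtype, so the task
// machinery is instantiated once for every T.
template <typename T>
void enqueue_per_dtype() {
  tasks.push_back([] { kernel<T>(); });
}

// After: one closure owns the switch, so the task machinery is
// instantiated exactly once per op.
void enqueue_once(int dtype_tag) {
  tasks.push_back([dtype_tag] {
    switch (dtype_tag) {
      case 0: kernel<int>(); break;
      case 1: kernel<float>(); break;
      default: kernel<double>(); break;
    }
  });
}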

View File: mlx/backend/cpu/binary.cpp

@@ -8,6 +8,7 @@
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

@@ -16,51 +17,218 @@ namespace mlx::core {

namespace {

template <typename Op>
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out = array::unsafe_weak_copy(out),
                    bopt]() mutable {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool, Op>(a, b, out, bopt);
        break;
      case uint8:
        binary_op<uint8_t, Op>(a, b, out, bopt);
        break;
      case uint16:
        binary_op<uint16_t, Op>(a, b, out, bopt);
        break;
      case uint32:
        binary_op<uint32_t, Op>(a, b, out, bopt);
        break;
      case uint64:
        binary_op<uint64_t, Op>(a, b, out, bopt);
        break;
      case int8:
        binary_op<int8_t, Op>(a, b, out, bopt);
        break;
      case int16:
        binary_op<int16_t, Op>(a, b, out, bopt);
        break;
      case int32:
        binary_op<int32_t, Op>(a, b, out, bopt);
        break;
      case int64:
        binary_op<int64_t, Op>(a, b, out, bopt);
        break;
      case float16:
        binary_op<float16_t, Op>(a, b, out, bopt);
        break;
      case float32:
        binary_op<float, Op>(a, b, out, bopt);
        break;
      case float64:
        binary_op<double, Op>(a, b, out, bopt);
        break;
      case bfloat16:
        binary_op<bfloat16_t, Op>(a, b, out, bopt);
        break;
      case complex64:
        binary_op<complex64_t, Op>(a, b, out, bopt);
        break;
    }
  });
}

template <typename Op>
void comparison_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out = array::unsafe_weak_copy(out),
                    bopt]() mutable {
    switch (a.dtype()) {
      case bool_:
        binary_op<bool, bool, Op>(a, b, out, bopt);
        break;
      case uint8:
        binary_op<uint8_t, bool, Op>(a, b, out, bopt);
        break;
      case uint16:
        binary_op<uint16_t, bool, Op>(a, b, out, bopt);
        break;
      case uint32:
        binary_op<uint32_t, bool, Op>(a, b, out, bopt);
        break;
      case uint64:
        binary_op<uint64_t, bool, Op>(a, b, out, bopt);
        break;
      case int8:
        binary_op<int8_t, bool, Op>(a, b, out, bopt);
        break;
      case int16:
        binary_op<int16_t, bool, Op>(a, b, out, bopt);
        break;
      case int32:
        binary_op<int32_t, bool, Op>(a, b, out, bopt);
        break;
      case int64:
        binary_op<int64_t, bool, Op>(a, b, out, bopt);
        break;
      case float16:
        binary_op<float16_t, bool, Op>(a, b, out, bopt);
        break;
      case float32:
        binary_op<float, bool, Op>(a, b, out, bopt);
        break;
      case float64:
        binary_op<double, bool, Op>(a, b, out, bopt);
        break;
      case bfloat16:
        binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
        break;
      case complex64:
        binary_op<complex64_t, bool, Op>(a, b, out, bopt);
        break;
    }
  });
}

template <typename Op>
void binary_float(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out = array::unsafe_weak_copy(out),
                    bopt]() mutable {
    switch (out.dtype()) {
      case float16:
        binary_op<float16_t, Op>(a, b, out, bopt);
        break;
      case float32:
        binary_op<float, Op>(a, b, out, bopt);
        break;
      case float64:
        binary_op<double, Op>(a, b, out, bopt);
        break;
      case bfloat16:
        binary_op<bfloat16_t, Op>(a, b, out, bopt);
        break;
      default:
        throw std::runtime_error(
            "[binary_float] Only supports non-complex floating point types.");
    }
  });
}

template <typename Op>
void binary_int(
    const array& a,
    const array& b,
    array& out,
    Op op,
    Stream stream) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out = array::unsafe_weak_copy(out),
                    bopt]() mutable {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool, Op>(a, b, out, bopt);
        break;
      case uint8:
        binary_op<uint8_t, Op>(a, b, out, bopt);
        break;
      case uint16:
        binary_op<uint16_t, Op>(a, b, out, bopt);
        break;
      case uint32:
        binary_op<uint32_t, Op>(a, b, out, bopt);
        break;
      case uint64:
        binary_op<uint64_t, Op>(a, b, out, bopt);
        break;
      case int8:
        binary_op<int8_t, Op>(a, b, out, bopt);
        break;
      case int16:
        binary_op<int16_t, Op>(a, b, out, bopt);
        break;
      case int32:
        binary_op<int32_t, Op>(a, b, out, bopt);
        break;
      case int64:
        binary_op<int64_t, Op>(a, b, out, bopt);
        break;
      default:
        throw std::runtime_error("[binary_int] Type not supported");
        break;
    }
  });
}

} // namespace
@@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Add(), stream());
}

void DivMod::eval_cpu(
@@ -78,70 +246,89 @@ void DivMod::eval_cpu(
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  auto& out_a = outputs[0];
  auto& out_b = outputs[1];
  set_binary_op_output_data(a, b, out_a, bopt);
  set_binary_op_output_data(a, b, out_b, bopt);

  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out_a);
  encoder.set_output_array(out_b);

  encoder.dispatch([a = array::unsafe_weak_copy(a),
                    b = array::unsafe_weak_copy(b),
                    out_a = array::unsafe_weak_copy(out_a),
                    out_b = array::unsafe_weak_copy(out_b),
                    bopt]() mutable {
    auto integral_op = [](auto x, auto y) {
      return std::make_pair(x / y, x % y);
    };
    auto float_op = [](auto x, auto y) {
      return std::make_pair(std::trunc(x / y), std::fmod(x, y));
    };

    switch (out_a.dtype()) {
      case bool_:
        binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case uint8:
        binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case uint16:
        binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case uint32:
        binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case uint64:
        binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case int8:
        binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case int16:
        binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case int32:
        binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case int64:
        binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
        break;
      case float16:
        binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
        break;
      case float32:
        binary_op<float>(a, b, out_a, out_b, float_op, bopt);
        break;
      case float64:
        binary_op<double>(a, b, out_a, out_b, float_op, bopt);
        break;
      case bfloat16:
        binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
        break;
      case complex64:
        // Should never get here
        throw std::runtime_error("[DivMod] Complex type not supported");
        break;
    }
  });
}

void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Divide(), stream());
}

void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Remainder(), stream());
}

void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (equal_nan_) {
    auto bopt = get_binary_op_type(a, b);
    set_binary_op_output_data(a, b, out, bopt);

    auto& encoder = cpu::get_command_encoder(stream());
    encoder.set_input_array(a);
    encoder.set_input_array(b);
    encoder.set_output_array(out);
    encoder.dispatch([a = array::unsafe_weak_copy(a),
                      b = array::unsafe_weak_copy(b),
                      out = array::unsafe_weak_copy(out),
                      bopt]() mutable {
      switch (a.dtype()) {
        case float16:
          binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case float32:
          binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case float64:
          binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case bfloat16:
          binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        case complex64:
          binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
          break;
        default:
          throw std::runtime_error(
              "[NanEqual::eval_cpu] Only for floating point types.");
      }
    });
  } else {
    comparison_op(a, b, out, detail::Equal(), stream());
  }
}

void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
}

void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}

void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
}

void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}

void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary_float(a, b, out, detail::LogAddExp(), stream());
}

void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalAnd(), stream());
}

void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalOr requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalOr(), stream());
}

void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Maximum(), stream());
}

void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Minimum(), stream());
}

void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Multiply(), stream());
}

void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}

void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Power(), stream());
}

void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Subtract(), stream());
}

void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  switch (op_) {
    case BitwiseBinary::And:
      binary_int(a, b, out, detail::BitwiseAnd(), stream());
      break;
    case BitwiseBinary::Or:
      binary_int(a, b, out, detail::BitwiseOr(), stream());
      break;
    case BitwiseBinary::Xor:
      binary_int(a, b, out, detail::BitwiseXor(), stream());
      break;
    case BitwiseBinary::LeftShift:
      binary_int(a, b, out, detail::LeftShift(), stream());
      break;
    case BitwiseBinary::RightShift:
      binary_int(a, b, out, detail::RightShift(), stream());
      break;
  }
}

@@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  binary_float(a, b, out, detail::ArcTan2(), stream());
}

} // namespace mlx::core
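
Every element-wise primitive now funnels through one of three type-family helpers (binary, binary_float, binary_int), each of which enqueues a single lambda. The Op arguments are stateless functors with a templated call operator; a toy functor in that shape (illustrative, not the MLX detail:: definition):

#include <cstdio>

// Stateless functor with a templated call operator, so a single Op type
// serves every dtype the switch selects.
struct ToyAdd {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};

int main() {
  ToyAdd op;
  std::printf("%d %f\n", op(1, 2), op(1.5, 2.25)); // prints "3 3.750000"
  return 0;
}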

View File: mlx/backend/cpu/binary.h

@@ -3,12 +3,9 @@

#pragma once

#include <cassert>

#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/simd/simd.h"

@@ -152,218 +149,145 @@ void binary_op_dispatch_dims(
}

template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
  auto a_ptr = a.data<T>();
  auto b_ptr = b.data<T>();
  auto out_ptr = out.data<U>();

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *out_ptr = Op{}(*a_ptr, *b_ptr);
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
    return;
  }

  // General computation so let's try to optimize
  auto [new_shape, new_strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out.strides()});
  auto& a_strides = new_strides[0];
  auto& b_strides = new_strides[1];
  auto& strides = new_strides[2];

  // Get the left-most dim such that the array is row contiguous after
  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a_strides);
  auto b_rc_dim = leftmost_rc_dim(b_strides);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a_strides);
  auto b_s_dim = leftmost_s_dim(b_strides);

  auto ndim = new_shape.size();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row
  // contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  if (dim == 0 || strides[dim - 1] < 16) {
    bopt = BinaryOpType::General;
    dim = ndim;
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
    default:
      binary_op_dispatch_dims<T, U, false, Op>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
  }
}

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
  binary_op<T, T, Op>(a, b, out, bopt);
}

} // namespace mlx::core
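
The leftmost_rc_dim and leftmost_s_dim helpers scan the strides from the innermost dimension outward to find how deep each input stays row contiguous (or broadcast like a scalar). A self-contained sketch of the same scan on sample strides (hypothetical values, plain C++):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Output strides of a contiguous 2x3x4 array, and input strides of an
  // array broadcast along the middle axis (stride 0 there).
  std::vector<int64_t> out_strides{12, 4, 1};
  std::vector<int64_t> a_strides{4, 0, 1};

  // Left-most dim such that the input is row contiguous after it.
  int d = static_cast<int>(a_strides.size()) - 1;
  for (; d >= 0 && a_strides[d] == out_strides[d]; d--) {
  }
  std::printf("row contiguous from dim %d\n", d + 1); // prints 2
  return 0;
}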

View File: mlx/backend/cpu/binary_two.h

@@ -4,8 +4,6 @@

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"

namespace mlx::core {

@@ -57,14 +55,7 @@ void binary_op_dispatch_dims(
    const array& b,
    array& out_a,
    array& out_b,
    Op op) {
  auto [shape, strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out_a.strides()});
  const T* a_ptr = a.data<T>();
@@ -72,197 +63,101 @@ void binary_op_dispatch_dims(
  U* out_a_ptr = out_a.data<U>();
  U* out_b_ptr = out_b.data<U>();

  const auto& a_strides = strides[0];
  const auto& b_strides = strides[1];
  const auto& out_strides = strides[2];
  int ndim = shape.size();
  switch (ndim) {
    case 1:
      binary_op_dims<T, U, Op, 1>(
          a_ptr,
          b_ptr,
          out_a_ptr,
          out_b_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 2:
      binary_op_dims<T, U, Op, 2>(
          a_ptr,
          b_ptr,
          out_a_ptr,
          out_b_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
  }

  ContiguousIterator a_it(shape, a_strides, ndim - 2);
  ContiguousIterator b_it(shape, b_strides, ndim - 2);
  auto stride = out_strides[ndim - 3];
  for (size_t elem = 0; elem < a.size(); elem += stride) {
    binary_op_dims<T, U, Op, 2>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,
        out_a_ptr + elem,
        out_b_ptr + elem,
        op,
        shape,
        a_strides,
        b_strides,
        out_strides,
        ndim - 2);
    a_it.step();
    b_it.step();
  }
}

template <typename T, typename U = T, typename Op>
void binary_op(
    const array& a,
    const array& b,
    array& out_a,
    array& out_b,
    Op op,
    BinaryOpType bopt) {
  if (bopt == BinaryOpType::General) {
    binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
    return;
  }

  auto a_ptr = a.data<T>();
  auto b_ptr = b.data<T>();
  auto out_a_ptr = out_a.data<U>();
  auto out_b_ptr = out_b.data<U>();
  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
  } else if (bopt == BinaryOpType::ScalarVector) {
    for (size_t i = 0; i < b.data_size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
      b_ptr++;
    }
  } else if (bopt == BinaryOpType::VectorScalar) {
    for (size_t i = 0; i < a.data_size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
      a_ptr++;
    }
  } else { // VectorVector
    for (size_t i = 0; i < a.size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
      a_ptr++;
      b_ptr++;
    }
  }
}
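
These two-output kernels expect op to return a std::pair that std::tie unpacks into both destinations. A minimal sketch of that contract with a divmod-style op (illustrative only):

#include <cstdio>
#include <tuple>
#include <utility>

int main() {
  // Pair-returning op in the shape the two-output kernels consume.
  auto divmod = [](int x, int y) { return std::make_pair(x / y, x % y); };
  int q = 0;
  int r = 0;
  std::tie(q, r) = divmod(7, 3); // unpack both results at once
  std::printf("%d %d\n", q, r);  // prints "2 1"
  return 0;
}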

View File: mlx/backend/cpu/copy.cpp

@ -13,29 +13,20 @@ namespace mlx::core {
namespace { namespace {
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst, Stream stream) { void copy_single(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
auto& encoder = cpu::get_command_encoder(stream); auto size = dst.size();
encoder.set_input_array(src); auto val = static_cast<DstT>(src_ptr[0]);
encoder.set_output_array(dst); std::fill_n(dst_ptr, size, val);
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
});
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst, Stream stream) { void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size(); auto size = src.data_size();
auto& encoder = cpu::get_command_encoder(stream); std::copy(src_ptr, src_ptr + size, dst_ptr);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
std::copy(src_ptr, src_ptr + size, dst_ptr);
});
} }
template <typename SrcT, typename DstT, int D> template <typename SrcT, typename DstT, int D>
@ -66,7 +57,6 @@ template <typename SrcT, typename DstT>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
Stream stream,
const Shape& data_shape, const Shape& data_shape,
const Strides& i_strides, const Strides& i_strides,
const Strides& o_strides, const Strides& o_strides,
@ -80,47 +70,17 @@ void copy_general_general(
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr; dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr = auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr; dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto& encoder = cpu::get_command_encoder(stream); int ndim = shape.size();
encoder.set_input_array(src); if (ndim < 3) {
encoder.set_output_array(dst);
encoder.dispatch([src_ptr,
dst_ptr,
size = src.size(),
data_shape = data_shape,
i_strides = i_strides,
o_strides = o_strides,
i_offset_ptr,
o_offset_ptr]() mutable {
if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
int ndim = shape.size();
if (ndim < 3) {
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return;
}
if (i_offset_ptr) { if (i_offset_ptr) {
src_ptr += i_offset_ptr[0]; src_ptr += i_offset_ptr[0];
} }
@ -128,30 +88,47 @@ void copy_general_general(
dst_ptr += o_offset_ptr[0]; dst_ptr += o_offset_ptr[0];
} }
ContiguousIterator in(shape, strides[0], ndim - 3); if (ndim == 1) {
ContiguousIterator out(shape, strides[1], ndim - 3); copy_dims<SrcT, DstT, 1>(
auto stride = std::accumulate( src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>()); } else if (ndim == 2) {
for (int64_t elem = 0; elem < size; elem += stride) { copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>( copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc, src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
} }
}); return;
}
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst, Stream stream) { inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
stream,
src.shape(), src.shape(),
src.strides(), src.strides(),
dst.strides(), dst.strides(),
@ -165,7 +142,6 @@ template <typename SrcT, typename DstT>
void copy_general( void copy_general(
const array& src, const array& src,
array& dst, array& dst,
Stream stream,
const Shape& data_shape, const Shape& data_shape,
const Strides& i_strides, const Strides& i_strides,
const Strides&, const Strides&,
@ -176,7 +152,6 @@ void copy_general(
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
stream,
data_shape, data_shape,
i_strides, i_strides,
make_contiguous_strides(data_shape), make_contiguous_strides(data_shape),
@ -187,11 +162,10 @@ void copy_general(
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst, Stream stream) { inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
stream,
src.shape(), src.shape(),
src.strides(), src.strides(),
make_contiguous_strides(src.shape()), make_contiguous_strides(src.shape()),
@ -202,84 +176,67 @@ inline void copy_general(const array& src, array& dst, Stream stream) {
} }
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
void copy( void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (ctype) { switch (ctype) {
case CopyType::Scalar: case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst, stream); copy_single<SrcT, DstT>(src, dst);
return; return;
case CopyType::Vector: case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst, stream); copy_vector<SrcT, DstT>(src, dst);
return; return;
case CopyType::General: case CopyType::General:
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...); copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
src, dst, stream, std::forward<Args>(args)...);
return; return;
} }
} }
template <typename SrcT, typename... Args> template <typename SrcT, typename... Args>
void copy( void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (dst.dtype()) { switch (dst.dtype()) {
case bool_: case bool_:
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<SrcT, uint16_t>( copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<SrcT, uint32_t>( copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<SrcT, uint64_t>( copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<SrcT, float16_t>( copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float64: case float64:
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<SrcT, bfloat16_t>( copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<SrcT, complex64_t>( copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
src, dst, ctype, stream, std::forward<Args>(args)...);
break; break;
} }
} }
@ -289,50 +246,49 @@ inline void copy_inplace_dispatch(
const array& src, const array& src,
array& dst, array& dst,
CopyType ctype, CopyType ctype,
Stream stream,
Args&&... args) { Args&&... args) {
switch (src.dtype()) { switch (src.dtype()) {
case bool_: case bool_:
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float64: case float64:
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<double>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...); copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
} }
} }
@@ -340,7 +296,13 @@ inline void copy_inplace_dispatch(
 } // namespace

 void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
-  copy_inplace_dispatch(src, dst, ctype, stream);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(src);
+  encoder.set_output_array(dst);
+  encoder.dispatch(
+      [src = array::unsafe_weak_copy(src),
+       dst = array::unsafe_weak_copy(dst),
+       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
 }

 void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
@@ -368,26 +330,47 @@ void copy_inplace(
     Stream stream,
     const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
     const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
-  switch (ctype) {
-    case CopyType::General:
-    case CopyType::GeneralGeneral:
-      copy_inplace_dispatch(
-          src,
-          dst,
-          ctype,
-          stream,
-          data_shape,
-          i_strides,
-          o_strides,
-          i_offset,
-          o_offset,
-          dynamic_i_offset,
-          dynamic_o_offset);
-      break;
-    case CopyType::Scalar:
-    case CopyType::Vector:
-      copy_inplace_dispatch(src, dst, ctype, stream);
-  }
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(src);
+  encoder.set_output_array(dst);
+  auto weak_copy_if_set = [](auto x) -> std::optional<array> {
+    if (x) {
+      return array::unsafe_weak_copy(*x);
+    } else {
+      return std::nullopt;
+    }
+  };
+  encoder.dispatch(
+      [src = array::unsafe_weak_copy(src),
+       dst = array::unsafe_weak_copy(dst),
+       data_shape,
+       i_strides,
+       o_strides,
+       i_offset,
+       o_offset,
+       ctype,
+       dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
+       dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
+        switch (ctype) {
+          case CopyType::General:
+          case CopyType::GeneralGeneral:
+            copy_inplace_dispatch(
+                src,
+                dst,
+                ctype,
+                data_shape,
+                i_strides,
+                o_strides,
+                i_offset,
+                o_offset,
+                dynamic_i_offset,
+                dynamic_o_offset);
+            break;
+          case CopyType::Scalar:
+          case CopyType::Vector:
+            copy_inplace_dispatch(src, dst, ctype);
+        }
+      });
 }

 } // namespace mlx::core
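array::unsafe_weak_copy plus encoder.dispatch is the load-bearing pattern here: the encoder retains the real arrays (set_input_array / set_output_array) while the queued lambda captures only non-owning views, so the task body needs neither the graph nor a Stream. A self-contained toy model of that ownership split; Buffer, unsafe_weak_copy, and Encoder below are stand-ins, not the MLX types:

// Toy model of the weak-copy capture pattern used throughout this commit.
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

struct Buffer {
  std::shared_ptr<std::vector<float>> data;
};

// A non-owning alias: shares the pointer but not the ownership; the no-op
// deleter drops lifetime management, like array::unsafe_weak_copy.
Buffer unsafe_weak_copy(const Buffer& b) {
  return {std::shared_ptr<std::vector<float>>(b.data.get(), [](auto*) {})};
}

struct Encoder {
  std::vector<Buffer> retained;              // lifetime owner (set_input_array)
  std::vector<std::function<void()>> tasks;  // queued work (dispatch)
  void set_input_array(const Buffer& b) { retained.push_back(b); }
  void dispatch(std::function<void()> f) { tasks.push_back(std::move(f)); }
  void run() {
    for (auto& t : tasks) {
      t();
    }
  }
};

int main() {
  Encoder enc;
  Buffer x{std::make_shared<std::vector<float>>(3, 2.f)};
  enc.set_input_array(x);                     // encoder keeps x alive
  enc.dispatch([x = unsafe_weak_copy(x)]() {  // task holds only a view
    std::printf("sum = %g\n", (*x.data)[0] + (*x.data)[1] + (*x.data)[2]);
  });
  enc.run();
}

The no-op deleter is what makes the copy "weak": the queued task can read the data but never frees it or extends its lifetime; the encoder's retained inputs and outputs own that.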
@@ -326,8 +326,7 @@ void _qmm_dispatch_typed(
     const array& biases,
     int bits,
     int group_size,
-    bool transposed_w,
-    Stream stream) {
+    bool transposed_w) {
   int K = x.shape(-1);
   int M = x.ndim() > 1 ? x.shape(-2) : 1;
   int N = out.shape(-1);
@@ -335,56 +334,25 @@ void _qmm_dispatch_typed(
   int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
   int batch_size = x.size() / (K * M);

-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(x);
-  encoder.set_input_array(w);
-  encoder.set_input_array(scales);
-  encoder.set_input_array(biases);
-  encoder.set_output_array(out);
-
   auto out_ptr = out.data<T>();
   auto x_ptr = x.data<T>();
   auto w_ptr = w.data<uint32_t>();
   auto scales_ptr = scales.data<T>();
   auto biases_ptr = biases.data<T>();
-
-  encoder.dispatch([out_ptr,
-                    x_ptr,
-                    w_ptr,
-                    scales_ptr,
-                    biases_ptr,
-                    x_shape = x.shape(),
-                    x_strides = x.strides(),
-                    w_shape = w.shape(),
-                    w_strides = w.strides(),
-                    scales_shape = scales.shape(),
-                    scales_strides = scales.strides(),
-                    biases_shape = biases.shape(),
-                    biases_strides = biases.strides(),
-                    w_els,
-                    g_els,
-                    batch_size,
-                    M,
-                    N,
-                    K,
-                    bits,
-                    group_size,
-                    transposed_w] {
-    for (int i = 0; i < batch_size; i++) {
-      _qmm_dispatch_typed<T>(
-          out_ptr + i * M * N,
-          x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
-          w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
-          scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
-          biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
-          M,
-          N,
-          K,
-          bits,
-          group_size,
-          transposed_w);
-    }
-  });
+  for (int i = 0; i < batch_size; i++) {
+    _qmm_dispatch_typed<T>(
+        out_ptr + i * M * N,
+        x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
+        w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
+        scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
+        biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
+        M,
+        N,
+        K,
+        bits,
+        group_size,
+        transposed_w);
+  }
 }

 void _qmm_dispatch(
@@ -395,20 +363,19 @@ void _qmm_dispatch(
     const array& biases,
     int bits,
     int group_size,
-    bool transposed_w,
-    Stream stream) {
+    bool transposed_w) {
   switch (x.dtype()) {
     case float32:
       _qmm_dispatch_typed<float>(
-          out, x, w, scales, biases, bits, group_size, transposed_w, stream);
+          out, x, w, scales, biases, bits, group_size, transposed_w);
       break;
     case float16:
       _qmm_dispatch_typed<float16_t>(
-          out, x, w, scales, biases, bits, group_size, transposed_w, stream);
+          out, x, w, scales, biases, bits, group_size, transposed_w);
       break;
     case bfloat16:
       _qmm_dispatch_typed<bfloat16_t>(
-          out, x, w, scales, biases, bits, group_size, transposed_w, stream);
+          out, x, w, scales, biases, bits, group_size, transposed_w);
       break;
     default:
       throw std::invalid_argument(
@@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed(
     const array& rhs_indices,
     int bits,
     int group_size,
-    bool transposed_w,
-    Stream stream) {
+    bool transposed_w) {
   int K = x.shape(-1);
   int M = x.shape(-2);
   int N = out.shape(-1);
@@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed(
   int w_els = w.shape(-1) * w.shape(-2);
   int g_els = scales.shape(-1) * scales.shape(-2);

-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(x);
-  encoder.set_input_array(w);
-  encoder.set_input_array(scales);
-  encoder.set_input_array(biases);
-  encoder.set_input_array(lhs_indices);
-  encoder.set_input_array(rhs_indices);
-  encoder.set_output_array(out);
-
   auto out_ptr = out.data<T>();
   auto x_ptr = x.data<T>();
   auto w_ptr = w.data<uint32_t>();
@@ -453,53 +410,26 @@ void _bs_qmm_dispatch_typed(
   auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
   auto rhs_indices_ptr = rhs_indices.data<uint32_t>();

-  encoder.dispatch([out_ptr,
-                    x_ptr,
-                    w_ptr,
-                    scales_ptr,
-                    biases_ptr,
-                    lhs_indices_ptr,
-                    rhs_indices_ptr,
-                    x_shape = x.shape(),
-                    x_strides = x.strides(),
-                    w_shape = w.shape(),
-                    w_strides = w.strides(),
-                    scales_shape = scales.shape(),
-                    scales_strides = scales.strides(),
-                    biases_shape = biases.shape(),
-                    biases_strides = biases.strides(),
-                    lhs_indices_shape = lhs_indices.shape(),
-                    lhs_indices_strides = lhs_indices.strides(),
-                    rhs_indices_shape = rhs_indices.shape(),
-                    rhs_indices_strides = rhs_indices.strides(),
-                    w_els,
-                    g_els,
-                    indices_size = lhs_indices.size(),
-                    M,
-                    N,
-                    K,
-                    bits,
-                    group_size,
-                    transposed_w]() {
-    for (int i = 0; i < indices_size; i++) {
-      int x_idx = lhs_indices_ptr[elem_to_loc(
-          i, lhs_indices_shape, lhs_indices_strides)];
-      int w_idx = rhs_indices_ptr[elem_to_loc(
-          i, rhs_indices_shape, rhs_indices_strides)];
-      _qmm_dispatch_typed<T>(
-          out_ptr + i * M * N,
-          x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
-          w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
-          scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
-          biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
-          M,
-          N,
-          K,
-          bits,
-          group_size,
-          transposed_w);
-    }
-  });
+  for (int i = 0; i < lhs_indices.size(); i++) {
+    int x_idx = lhs_indices_ptr[elem_to_loc(
+        i, lhs_indices.shape(), lhs_indices.strides())];
+    int w_idx = rhs_indices_ptr[elem_to_loc(
+        i, rhs_indices.shape(), rhs_indices.strides())];
+    _qmm_dispatch_typed<T>(
+        out_ptr + i * M * N,
+        x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
+        w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
+        scales_ptr +
+            elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
+        biases_ptr +
+            elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
+        M,
+        N,
+        K,
+        bits,
+        group_size,
+        transposed_w);
+  }
 }

 void _bs_qmm_dispatch(
@@ -512,8 +442,7 @@ void _bs_qmm_dispatch(
     const array& rhs_indices,
     int bits,
     int group_size,
-    bool transposed_w,
-    Stream stream) {
+    bool transposed_w) {
   switch (x.dtype()) {
     case float32:
       _bs_qmm_dispatch_typed<float>(
@@ -526,8 +455,7 @@ void _bs_qmm_dispatch(
           rhs_indices,
           bits,
           group_size,
-          transposed_w,
-          stream);
+          transposed_w);
       break;
     case float16:
       _bs_qmm_dispatch_typed<float16_t>(
@@ -540,8 +468,7 @@ void _bs_qmm_dispatch(
           rhs_indices,
           bits,
           group_size,
-          transposed_w,
-          stream);
+          transposed_w);
       break;
     case bfloat16:
      _bs_qmm_dispatch_typed<bfloat16_t>(
@@ -554,8 +481,7 @@ void _bs_qmm_dispatch(
           rhs_indices,
           bits,
           group_size,
-          transposed_w,
-          stream);
+          transposed_w);
       break;
     default:
       throw std::invalid_argument(
@@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto biases = ensure_row_contiguous(biases_pre);

   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_dispatch(
-      out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
-  auto& enc = cpu::get_command_encoder(stream());
-  enc.add_temporaries(std::move(temps));
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.add_temporaries(std::move(temps));
+  encoder.set_input_array(x);
+  encoder.set_input_array(w);
+  encoder.set_input_array(scales);
+  encoder.set_input_array(biases);
+  encoder.set_output_array(out);
+  encoder.dispatch([out = array::unsafe_weak_copy(out),
+                    x = array::unsafe_weak_copy(x),
+                    w = array::unsafe_weak_copy(w),
+                    scales = array::unsafe_weak_copy(scales),
+                    biases = array::unsafe_weak_copy(biases),
+                    group_size_ = group_size_,
+                    bits_ = bits_,
+                    transpose_ = transpose_]() mutable {
+    _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
+  });
 }

 void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -626,20 +566,38 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto biases = ensure_row_contiguous_last_dims(biases_pre);

   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _bs_qmm_dispatch(
-      out,
-      x,
-      w,
-      scales,
-      biases,
-      lhs_indices,
-      rhs_indices,
-      group_size_,
-      bits_,
-      transpose_,
-      stream());
-  auto& enc = cpu::get_command_encoder(stream());
-  enc.add_temporaries(std::move(temps));
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.add_temporaries(std::move(temps));
+  encoder.set_input_array(x);
+  encoder.set_input_array(w);
+  encoder.set_input_array(scales);
+  encoder.set_input_array(biases);
+  encoder.set_input_array(lhs_indices);
+  encoder.set_input_array(rhs_indices);
+  encoder.set_output_array(out);
+  encoder.dispatch([out = array::unsafe_weak_copy(out),
+                    x = array::unsafe_weak_copy(x),
+                    w = array::unsafe_weak_copy(w),
+                    scales = array::unsafe_weak_copy(scales),
+                    biases = array::unsafe_weak_copy(biases),
+                    lhs_indices = array::unsafe_weak_copy(lhs_indices),
+                    rhs_indices = array::unsafe_weak_copy(rhs_indices),
+                    group_size_ = group_size_,
+                    bits_ = bits_,
+                    transpose_ = transpose_]() mutable {
+    _bs_qmm_dispatch(
+        out,
+        x,
+        w,
+        scales,
+        biases,
+        lhs_indices,
+        rhs_indices,
+        group_size_,
+        bits_,
+        transpose_);
+  });
 }

 template <typename T, typename U>
@@ -709,27 +667,13 @@ void dispatch_quantize(
     array& scales,
     array& biases,
     int bits,
-    int group_size,
-    Stream stream) {
+    int group_size) {
   auto w_ptr = w.data<T>();
   auto out_ptr = out.data<U>();
   auto scales_ptr = scales.data<T>();
   auto biases_ptr = biases.data<T>();
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(w);
-  encoder.set_input_array(scales);
-  encoder.set_input_array(biases);
-  encoder.set_output_array(out);
-  encoder.dispatch([w_ptr,
-                    out_ptr,
-                    scales_ptr,
-                    biases_ptr,
-                    bits,
-                    group_size,
-                    w_size = w.size()]() {
-    quantize<T, U>(
-        w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
-  });
+  quantize<T, U>(
+      w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
 }

 void fast::AffineQuantize::eval_cpu(
@@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu(
   auto& biases = outputs[2];
   scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
   biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
-  if (w.dtype() == float16) {
-    if (is_power_of_2(bits_)) {
-      dispatch_quantize<float16_t, uint32_t>(
-          w, out, scales, biases, bits_, group_size_, stream());
-    } else {
-      dispatch_quantize<float16_t, uint8_t>(
-          w, out, scales, biases, bits_, group_size_, stream());
-    }
-  } else if (w.dtype() == bfloat16) {
-    if (is_power_of_2(bits_)) {
-      dispatch_quantize<bfloat16_t, uint32_t>(
-          w, out, scales, biases, bits_, group_size_, stream());
-    } else {
-      dispatch_quantize<bfloat16_t, uint8_t>(
-          w, out, scales, biases, bits_, group_size_, stream());
-    }
-  } else if (w.dtype() == float32) {
-    if (is_power_of_2(bits_)) {
-      dispatch_quantize<float, uint32_t>(
-          w, out, scales, biases, bits_, group_size_, stream());
-    } else {
-      dispatch_quantize<float, uint8_t>(
-          w, out, scales, biases, bits_, group_size_, stream());
-    }
-  } else {
-    throw std::runtime_error(
-        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
-  }
+
+  auto& encoder = cpu::get_command_encoder(stream());
   if (copied) {
-    cpu::get_command_encoder(stream()).add_temporary(w);
+    encoder.add_temporary(w);
   }
+  encoder.set_input_array(w);
+  encoder.set_input_array(scales);
+  encoder.set_input_array(biases);
+  encoder.set_output_array(out);
+  encoder.dispatch([w = array::unsafe_weak_copy(w),
+                    out = array::unsafe_weak_copy(out),
+                    scales = array::unsafe_weak_copy(scales),
+                    biases = array::unsafe_weak_copy(biases),
+                    group_size_ = group_size_,
+                    bits_ = bits_]() mutable {
+    if (w.dtype() == float16) {
+      if (is_power_of_2(bits_)) {
+        dispatch_quantize<float16_t, uint32_t>(
+            w, out, scales, biases, bits_, group_size_);
+      } else {
+        dispatch_quantize<float16_t, uint8_t>(
+            w, out, scales, biases, bits_, group_size_);
+      }
+    } else if (w.dtype() == bfloat16) {
+      if (is_power_of_2(bits_)) {
+        dispatch_quantize<bfloat16_t, uint32_t>(
+            w, out, scales, biases, bits_, group_size_);
+      } else {
+        dispatch_quantize<bfloat16_t, uint8_t>(
+            w, out, scales, biases, bits_, group_size_);
+      }
+    } else if (w.dtype() == float32) {
+      if (is_power_of_2(bits_)) {
+        dispatch_quantize<float, uint32_t>(
+            w, out, scales, biases, bits_, group_size_);
+      } else {
+        dispatch_quantize<float, uint8_t>(
+            w, out, scales, biases, bits_, group_size_);
+      }
+    } else {
+      throw std::runtime_error(
+          "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
+    }
+  });
 }

 } // namespace mlx::core
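Across this file the dtype switch moves inside the single dispatched task, so the encoder and queueing machinery get instantiated once per op instead of once per (op x dtype); only the small typed kernels still multiply with the dtype list. A toy model of that layout, assuming a plain std::function queue rather than the real scheduler:

// Toy illustration of hoisting the dtype switch inside one queued task:
// only one closure type ever reaches the queue, however many dtypes exist.
#include <cstdio>
#include <functional>
#include <vector>

enum class Dtype { f32, f64 };

template <typename T>
void qmm_typed() {
  std::printf("qmm for %zu-byte type\n", sizeof(T));
}

std::vector<std::function<void()>> queue;

void eval(Dtype t) {
  // One lambda type, one captured tag; the per-type template bodies are the
  // only code that grows with the dtype list.
  queue.push_back([t]() {
    switch (t) {
      case Dtype::f32:
        qmm_typed<float>();
        break;
      case Dtype::f64:
        qmm_typed<double>();
        break;
    }
  });
}

int main() {
  eval(Dtype::f32);
  eval(Dtype::f64);
  for (auto& task : queue) {
    task();
  }
}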
@@ -140,34 +140,23 @@ void reduction_op(
     const array& x,
     array& out,
     const std::vector<int>& axes,
-    U init,
-    Stream stream) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    U init) {
   ReductionPlan plan = get_reduction_plan(x, axes);

-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(x);
-  encoder.set_output_array(out);
   auto in_ptr = x.data<T>();
   auto out_ptr = out.data<U>();

   if (plan.type == ContiguousAllReduce) {
-    encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
-      *out_ptr = init;
-      contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
-    });
+    *out_ptr = init;
+    contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
     return;
   }

   if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
     int reduction_size = plan.shape[0];
-    encoder.dispatch(
-        [in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
-          for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
-            *out_ptr = init;
-            contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
-          }
-        });
+    for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
+      *out_ptr = init;
+      contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
+    }
     return;
   }

@@ -178,40 +167,29 @@ void reduction_op(
     // Unrolling the following loop (and implementing it in order for
     // ContiguousReduce) should hold extra performance boost.
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    encoder.dispatch([in_ptr,
-                      out_ptr,
-                      init,
-                      reduction_size,
-                      size = out.size(),
-                      plan = std::move(plan),
-                      shape = std::move(shape),
-                      strides = std::move(strides)]() mutable {
-      if (plan.shape.size() == 0) {
-        for (int i = 0; i < size; i++, out_ptr++) {
-          int offset = elem_to_loc(i, shape, strides);
-          *out_ptr = init;
-          contiguous_reduce(
-              in_ptr + offset, out_ptr, reduction_size, Op{}, init);
-        }
-      } else {
-        for (int i = 0; i < size; i++, out_ptr++) {
-          int offset = elem_to_loc(i, shape, strides);
-          *out_ptr = init;
-          nd_loop(
-              [&](int extra_offset) {
-                contiguous_reduce(
-                    in_ptr + offset + extra_offset,
-                    out_ptr,
-                    reduction_size,
-                    Op{},
-                    init);
-              },
-              plan.shape,
-              plan.strides);
-        }
-      }
-    });
+    if (plan.shape.size() == 0) {
+      for (int i = 0; i < out.size(); i++, out_ptr++) {
+        int offset = elem_to_loc(i, shape, strides);
+        *out_ptr = init;
+        contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
+      }
+    } else {
+      for (int i = 0; i < out.size(); i++, out_ptr++) {
+        int offset = elem_to_loc(i, shape, strides);
+        *out_ptr = init;
+        nd_loop(
+            [&](int extra_offset) {
+              contiguous_reduce(
+                  in_ptr + offset + extra_offset,
+                  out_ptr,
+                  reduction_size,
+                  Op{},
+                  init);
+            },
+            plan.shape,
+            plan.strides);
+      }
+    }
     return;
   }

@@ -220,20 +198,12 @@ void reduction_op(
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    encoder.dispatch([in_ptr,
-                      out_ptr,
-                      init,
-                      reduction_size,
-                      reduction_stride,
-                      size = out.size()]() mutable {
-      for (int i = 0; i < size; i += reduction_stride) {
-        std::fill_n(out_ptr, reduction_stride, init);
-        strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
-        in_ptr += reduction_stride * reduction_size;
-        out_ptr += reduction_stride;
-      }
-    });
+    for (int i = 0; i < out.size(); i += reduction_stride) {
+      std::fill_n(out_ptr, reduction_stride, init);
+      strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
+      in_ptr += reduction_stride * reduction_size;
+      out_ptr += reduction_stride;
+    }
     return;
   }

@@ -245,67 +215,49 @@ void reduction_op(
     plan.strides.pop_back();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    encoder.dispatch([in_ptr,
-                      out_ptr,
-                      init,
-                      reduction_size,
-                      reduction_stride,
-                      size = out.size(),
-                      plan = std::move(plan),
-                      shape = std::move(shape),
-                      strides = std::move(strides)]() mutable {
-      if (plan.shape.size() == 0) {
-        for (int i = 0; i < size; i += reduction_stride) {
-          int offset = elem_to_loc(i, shape, strides);
-          std::fill_n(out_ptr, reduction_stride, init);
-          strided_reduce(
-              in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
-          out_ptr += reduction_stride;
-        }
-      } else {
-        for (int i = 0; i < size; i += reduction_stride) {
-          int offset = elem_to_loc(i, shape, strides);
-          std::fill_n(out_ptr, reduction_stride, init);
-          nd_loop(
-              [&](int extra_offset) {
-                strided_reduce(
-                    in_ptr + offset + extra_offset,
-                    out_ptr,
-                    reduction_size,
-                    reduction_stride,
-                    Op{});
-              },
-              plan.shape,
-              plan.strides);
-          out_ptr += reduction_stride;
-        }
-      }
-    });
+    if (plan.shape.size() == 0) {
+      for (int i = 0; i < out.size(); i += reduction_stride) {
+        int offset = elem_to_loc(i, shape, strides);
+        std::fill_n(out_ptr, reduction_stride, init);
+        strided_reduce(
+            in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
+        out_ptr += reduction_stride;
+      }
+    } else {
+      for (int i = 0; i < out.size(); i += reduction_stride) {
+        int offset = elem_to_loc(i, shape, strides);
+        std::fill_n(out_ptr, reduction_stride, init);
+        nd_loop(
+            [&](int extra_offset) {
+              strided_reduce(
+                  in_ptr + offset + extra_offset,
+                  out_ptr,
+                  reduction_size,
+                  reduction_stride,
+                  Op{});
+            },
+            plan.shape,
+            plan.strides);
+        out_ptr += reduction_stride;
+      }
+    }
     return;
   }

   if (plan.type == GeneralReduce) {
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    encoder.dispatch([in_ptr,
-                      out_ptr,
-                      init,
-                      size = out.size(),
-                      plan = std::move(plan),
-                      shape = std::move(shape),
-                      strides = std::move(strides)]() mutable {
-      for (int i = 0; i < size; i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        U val = init;
-        nd_loop(
-            [&](int extra_offset) {
-              val = Op{}(val, *(in_ptr + offset + extra_offset));
-            },
-            plan.shape,
-            plan.strides);
-        *out_ptr = val;
-      }
-    });
+    for (int i = 0; i < out.size(); i++, out_ptr++) {
+      int offset = elem_to_loc(i, shape, strides);
+      U val = init;
+      nd_loop(
+          [&](int extra_offset) {
+            val = Op{}(val, *(in_ptr + offset + extra_offset));
+          },
+          plan.shape,
+          plan.strides);
+      *out_ptr = val;
+    }
   }
 }

@@ -434,12 +386,11 @@ void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes,
-    Stream stream) {
+    const std::vector<int>& axes) {
   if (rtype == Reduce::And) {
-    reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
+    reduction_op<InT, bool, AndReduce>(in, out, axes, true);
   } else {
-    reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
+    reduction_op<InT, bool, OrReduce>(in, out, axes, false);
   }
 }

@@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes,
-    Stream stream) {
+    const std::vector<int>& axes) {
   if (rtype == Reduce::Sum) {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
+      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
     } else {
-      reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
+      reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
     }
   } else {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
+      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
     } else {
-      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
+      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
     }
   }
 }

@@ -470,162 +420,144 @@ void reduce_dispatch_min_max(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes,
-    Stream stream) {
+    const std::vector<int>& axes) {
   if (rtype == Reduce::Max) {
     auto init = Limits<InT>::min;
-    reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
+    reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
   } else {
     auto init = Limits<InT>::max;
-    reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
+    reduction_op<InT, InT, MinReduce>(in, out, axes, init);
   }
 }

 void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  switch (reduce_type_) {
-    case Reduce::And:
-    case Reduce::Or: {
-      switch (in.dtype()) {
-        case bool_:
-        case uint8:
-        case int8:
-          reduce_dispatch_and_or<int8_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int16:
-        case uint16:
-        case float16:
-        case bfloat16:
-          reduce_dispatch_and_or<int16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case uint32:
-        case int32:
-        case float32:
-          reduce_dispatch_and_or<int32_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case uint64:
-        case int64:
-        case float64:
-        case complex64:
-          reduce_dispatch_and_or<int64_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-      }
-      break;
-    }
-    case Reduce::Sum:
-    case Reduce::Prod: {
-      switch (in.dtype()) {
-        case bool_:
-        case uint8:
-        case int8:
-          reduce_dispatch_sum_prod<int8_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int16:
-        case uint16:
-          reduce_dispatch_sum_prod<int16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int32:
-        case uint32:
-          reduce_dispatch_sum_prod<int32_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int64:
-        case uint64:
-          reduce_dispatch_sum_prod<int64_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case float16:
-          reduce_dispatch_sum_prod<float16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case bfloat16:
-          reduce_dispatch_sum_prod<bfloat16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case float32:
-          reduce_dispatch_sum_prod<float>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case float64:
-          reduce_dispatch_sum_prod<double>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case complex64:
-          reduce_dispatch_sum_prod<complex64_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-      }
-      break;
-    }
-    case Reduce::Max:
-    case Reduce::Min: {
-      switch (in.dtype()) {
-        case bool_:
-          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
-          break;
-        case uint8:
-          reduce_dispatch_min_max<uint8_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case uint16:
-          reduce_dispatch_min_max<uint16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case uint32:
-          reduce_dispatch_min_max<uint32_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case uint64:
-          reduce_dispatch_min_max<uint64_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int8:
-          reduce_dispatch_min_max<uint8_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int16:
-          reduce_dispatch_min_max<uint16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int32:
-          reduce_dispatch_min_max<int32_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case int64:
-          reduce_dispatch_min_max<int64_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case float16:
-          reduce_dispatch_min_max<float16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case float32:
-          reduce_dispatch_min_max<float>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case float64:
-          reduce_dispatch_min_max<double>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case bfloat16:
-          reduce_dispatch_min_max<bfloat16_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-        case complex64:
-          reduce_dispatch_min_max<complex64_t>(
-              in, out, reduce_type_, axes_, stream());
-          break;
-      }
-      break;
-    }
-  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in = array::unsafe_weak_copy(in),
+                    out = array::unsafe_weak_copy(out),
+                    reduce_type_ = reduce_type_,
+                    axes_ = axes_]() mutable {
+    switch (reduce_type_) {
+      case Reduce::And:
+      case Reduce::Or: {
+        switch (in.dtype()) {
+          case bool_:
+          case uint8:
+          case int8:
+            reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+            break;
+          case int16:
+          case uint16:
+          case float16:
+          case bfloat16:
+            reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+            break;
+          case uint32:
+          case int32:
+          case float32:
+            reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+            break;
+          case uint64:
+          case int64:
+          case float64:
+          case complex64:
+            reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+            break;
+        }
+        break;
+      }
+      case Reduce::Sum:
+      case Reduce::Prod: {
+        switch (in.dtype()) {
+          case bool_:
+          case uint8:
+          case int8:
+            reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+            break;
+          case int16:
+          case uint16:
+            reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+            break;
+          case int32:
+          case uint32:
+            reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+            break;
+          case int64:
+          case uint64:
+            reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+            break;
+          case float16:
+            reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+            break;
+          case bfloat16:
+            reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+            break;
+          case float32:
+            reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+            break;
+          case float64:
+            reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
+            break;
+          case complex64:
+            reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+            break;
+        }
+        break;
+      }
+      case Reduce::Max:
+      case Reduce::Min: {
+        switch (in.dtype()) {
+          case bool_:
+            reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+            break;
+          case uint8:
+            reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+            break;
+          case uint16:
+            reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+            break;
+          case uint32:
+            reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+            break;
+          case uint64:
+            reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+            break;
+          case int8:
+            reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+            break;
+          case int16:
+            reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+            break;
+          case int32:
+            reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+            break;
+          case int64:
+            reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+            break;
+          case float16:
+            reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+            break;
+          case float32:
+            reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+            break;
+          case float64:
+            reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
+            break;
+          case bfloat16:
+            reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+            break;
+          case complex64:
+            reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+            break;
+        }
+        break;
+      }
+    }
+  });
 }

 } // namespace mlx::core
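The reduction plans above come down to two inner loops: contiguous_reduce folds a run of adjacent elements into one output, while strided_reduce accumulates across rows into a row of outputs that is initialized first (the std::fill_n calls). A dependency-free sketch of both shapes on a small row-major matrix, with plain loops standing in for the MLX kernels:

// Minimal sketch of contiguous vs. strided reduction, assuming row-major data.
#include <cstdio>

int main() {
  // 2 x 3 row-major matrix.
  float x[2][3] = {{1, 2, 3}, {4, 5, 6}};

  // Contiguous reduce over the last axis: one output per row of adjacent
  // elements.
  float row_sum[2] = {0, 0};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      row_sum[i] += x[i][j];
    }
  }

  // Strided reduce over the first axis: reduction_stride (= 3) outputs are
  // initialized up front, then each input row is accumulated into them.
  float col_sum[3] = {0, 0, 0};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      col_sum[j] += x[i][j];
    }
  }

  std::printf("rows: %g %g  cols: %g %g %g\n",
              row_sum[0], row_sum[1], col_sum[0], col_sum[1], col_sum[2]);
}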
@@ -160,38 +160,29 @@ void scan_op(
     bool reverse,
     bool inclusive,
     const Op& op,
-    U init,
-    Stream stream) {
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
-
+    U init) {
   if (in.flags().row_contiguous) {
     if (in.strides()[axis] == 1) {
-      encoder.dispatch([in_ptr = in.data<T>(),
-                        out_ptr = out.data<U>(),
-                        count = in.size() / in.shape(axis),
-                        stride = in.shape(axis),
-                        reverse,
-                        inclusive,
-                        op = std::move(op),
-                        init]() {
-        contiguous_scan(
-            in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
-      });
+      contiguous_scan(
+          in.data<T>(),
+          out.data<U>(),
+          in.size() / in.shape(axis),
+          in.shape(axis),
+          reverse,
+          inclusive,
+          op,
+          init);
     } else {
-      encoder.dispatch([in_ptr = in.data<T>(),
-                        out_ptr = out.data<U>(),
-                        count = in.size() / in.shape(axis) / in.strides()[axis],
-                        size = in.shape(axis),
-                        stride = in.strides()[axis],
-                        reverse,
-                        inclusive,
-                        op = std::move(op),
-                        init]() {
-        strided_scan(
-            in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
-      });
+      strided_scan(
+          in.data<T>(),
+          out.data<U>(),
+          in.size() / in.shape(axis) / in.strides()[axis],
+          in.shape(axis),
+          in.strides()[axis],
+          reverse,
+          inclusive,
+          op,
+          init);
     }
   } else {
     throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -205,19 +196,18 @@ void scan_dispatch(
     array& out,
     int axis,
     bool reverse,
-    bool inclusive,
-    Stream stream) {
+    bool inclusive) {
   switch (rtype) {
     case Scan::Sum: {
       auto op = [](U y, T x) { return y + x; };
       auto init = static_cast<U>(0);
-      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
     case Scan::Prod: {
       auto op = [](U y, T x) { return y * x; };
       auto init = static_cast<U>(1);
-      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
     case Scan::Min: {
@@ -225,7 +215,7 @@ void scan_dispatch(
       auto init = (issubdtype(in.dtype(), floating))
           ? static_cast<U>(std::numeric_limits<float>::infinity())
           : std::numeric_limits<U>::max();
-      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
     case Scan::Max: {
@@ -233,7 +223,7 @@ void scan_dispatch(
       auto init = (issubdtype(in.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::min();
-      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
   }
@@ -244,88 +234,95 @@ void scan_dispatch(
 void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);

+  auto& encoder = cpu::get_command_encoder(stream());
+
   // Ensure contiguity
   auto in = inputs[0];
-  bool copied = false;
   if (!in.flags().row_contiguous) {
     array arr_copy(in.shape(), in.dtype(), nullptr, {});
     copy(in, arr_copy, CopyType::General, stream());
     in = arr_copy;
-    copied = true;
+    encoder.add_temporary(arr_copy);
   }
   out.set_data(allocator::malloc_or_wait(out.nbytes()));

-  switch (in.dtype()) {
-    case bool_: {
-      // We could do a full dtype x dtype switch but this is the only case
-      // where we accumulate in a different type, for now.
-      //
-      // TODO: If we add the option to accumulate floats in higher precision
-      // floats perhaps we should add the full all-to-all dispatch.
-      if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
-        scan_dispatch<bool, int32_t>(
-            reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      } else {
-        scan_dispatch<bool, bool>(
-            reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      }
-      break;
-    }
-    case uint8:
-      scan_dispatch<uint8_t, uint8_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case uint16:
-      scan_dispatch<uint16_t, uint16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case uint32:
-      scan_dispatch<uint32_t, uint32_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case uint64:
-      scan_dispatch<uint64_t, uint64_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case int8:
-      scan_dispatch<int8_t, int8_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case int16:
-      scan_dispatch<int16_t, int16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case int32:
-      scan_dispatch<int32_t, int32_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case int64:
-      scan_dispatch<int64_t, int64_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case float16:
-      scan_dispatch<float16_t, float16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case float32:
-      scan_dispatch<float, float>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case float64:
-      scan_dispatch<double, double>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case bfloat16:
-      scan_dispatch<bfloat16_t, bfloat16_t>(
-          reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
-      break;
-    case complex64:
-      throw std::runtime_error("Scan ops do not support complex types yet");
-      break;
-  }
-  if (copied) {
-    cpu::get_command_encoder(stream()).add_temporary(std::move(in));
-  }
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in = array::unsafe_weak_copy(in),
+                    out = array::unsafe_weak_copy(out),
+                    axis_ = axis_,
+                    reduce_type_ = reduce_type_,
+                    reverse_ = reverse_,
+                    inclusive_ = inclusive_]() mutable {
+    switch (in.dtype()) {
+      case bool_: {
+        // We could do a full dtype x dtype switch but this is the only case
+        // where we accumulate in a different type, for now.
+        //
+        // TODO: If we add the option to accumulate floats in higher precision
+        // floats perhaps we should add the full all-to-all dispatch.
+        if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
+          scan_dispatch<bool, int32_t>(
+              reduce_type_, in, out, axis_, reverse_, inclusive_);
+        } else {
+          scan_dispatch<bool, bool>(
+              reduce_type_, in, out, axis_, reverse_, inclusive_);
+        }
+        break;
+      }
+      case uint8:
+        scan_dispatch<uint8_t, uint8_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case uint16:
+        scan_dispatch<uint16_t, uint16_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case uint32:
+        scan_dispatch<uint32_t, uint32_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case uint64:
+        scan_dispatch<uint64_t, uint64_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case int8:
+        scan_dispatch<int8_t, int8_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case int16:
+        scan_dispatch<int16_t, int16_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case int32:
+        scan_dispatch<int32_t, int32_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case int64:
+        scan_dispatch<int64_t, int64_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case float16:
+        scan_dispatch<float16_t, float16_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case float32:
+        scan_dispatch<float, float>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case float64:
+        scan_dispatch<double, double>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case bfloat16:
+        scan_dispatch<bfloat16_t, bfloat16_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
+        break;
+      case complex64:
+        throw std::runtime_error("Scan ops do not support complex types yet");
+        break;
+    }
+  });
 }

 } // namespace mlx::core
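contiguous_scan above walks each row once, writing the running accumulation; the inclusive flag decides whether out[i] includes in[i] or only everything before it. A minimal sketch of the two variants with a sum scan; plain loops, not the MLX kernel:

// Sketch of inclusive vs. exclusive scan semantics, seeded with init.
#include <cstdio>

int main() {
  float in[4] = {1, 2, 3, 4};
  float incl[4];
  float excl[4];
  float init = 0.f;

  // Inclusive: out[i] = op(running total, in[i]).
  float acc = init;
  for (int i = 0; i < 4; ++i) {
    acc += in[i];
    incl[i] = acc;
  }

  // Exclusive: out[i] = running total before in[i] is folded in.
  acc = init;
  for (int i = 0; i < 4; ++i) {
    excl[i] = acc;
    acc += in[i];
  }

  std::printf("inclusive: %g %g %g %g\n", incl[0], incl[1], incl[2], incl[3]);
  std::printf("exclusive: %g %g %g %g\n", excl[0], excl[1], excl[2], excl[3]);
}

The reverse flag in the real kernels runs the same recurrence from the other end of the axis.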
@@ -16,51 +16,70 @@ void select_op(
     const array& b,
     const array& c,
     array& out,
-    Op op) {
-  switch (out.dtype()) {
-    case bool_:
-      ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
-      break;
-    case uint8:
-      ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
-      break;
-    case uint16:
-      ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
-      break;
-    case uint32:
-      ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
-      break;
-    case uint64:
-      ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
-      break;
-    case int8:
-      ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
-      break;
-    case int16:
-      ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
-      break;
-    case int32:
-      ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
-      break;
-    case int64:
-      ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
-      break;
-    case float16:
-      ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
-      break;
-    case float32:
-      ternary_op<bool, float, float, float>(a, b, c, out, op);
-      break;
-    case float64:
-      ternary_op<bool, double, double, double>(a, b, c, out, op);
-      break;
-    case bfloat16:
-      ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
-      break;
-    case complex64:
-      ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
-      break;
-  }
+    Op op,
+    Stream stream) {
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    c = array::unsafe_weak_copy(c),
+                    out = array::unsafe_weak_copy(out),
+                    op,
+                    topt]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
+        break;
+      case uint8:
+        ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
+        break;
+      case uint16:
+        ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
+        break;
+      case uint32:
+        ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
+        break;
+      case uint64:
+        ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
+        break;
+      case int8:
+        ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
+        break;
+      case int16:
+        ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
+        break;
+      case int32:
+        ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
+        break;
+      case int64:
+        ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
+        break;
+      case float16:
+        ternary_op<bool, float16_t, float16_t, float16_t>(
+            a, b, c, out, op, topt);
+        break;
+      case float32:
+        ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
+        break;
+      case float64:
+        ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
+        break;
+      case bfloat16:
+        ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
+            a, b, c, out, op, topt);
+        break;
+      case complex64:
+        ternary_op<bool, complex64_t, complex64_t, complex64_t>(
+            a, b, c, out, op, topt);
+        break;
+    }
+  });
 }

 } // namespace

@@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
   const auto& condition = inputs[0];
   const auto& a = inputs[1];
   const auto& b = inputs[2];
-  select_op(condition, a, b, out, detail::Select());
+  select_op(condition, a, b, out, detail::Select(), stream());
 }

 } // namespace mlx::core
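Select is an elementwise ternary: for each position the condition array picks between the two value arrays, and the dtype switch above only chooses the concrete element type. A dependency-free sketch of that contract; this is the semantics of detail::Select, not the MLX kernel itself:

// Elementwise select: out[i] = cond[i] ? a[i] : b[i].
#include <cstdio>

int main() {
  bool cond[4] = {true, false, true, false};
  float a[4] = {1, 2, 3, 4};
  float b[4] = {10, 20, 30, 40};
  float out[4];
  for (int i = 0; i < 4; ++i) {
    out[i] = cond[i] ? a[i] : b[i];
  }
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
}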
@@ -105,15 +105,11 @@ struct StridedIterator {
 };

 template <typename T>
-void sort(const array& in, array& out, int axis, Stream stream) {
-  // Copy input to output
-  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream);
-
+void sort(array& out, int axis) {
   // Get axis, shape and stride info
-  axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
-  size_t n_rows = in_size / in.shape(axis);
+  axis = axis < 0 ? axis + out.ndim() : axis;
+  size_t in_size = out.size();
+  size_t n_rows = in_size / out.shape(axis);

   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -127,30 +123,20 @@ void sort(const array& in, array& out, int axis, Stream stream) {
   // Perform sorting in place
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_output_array(out);
-  encoder.dispatch([out_ptr = out.data<T>(),
-                    src_it = std::move(src_it),
-                    n_rows,
-                    axis_size,
-                    axis_stride]() mutable {
-    for (int i = 0; i < n_rows; i++) {
-      T* data_ptr = out_ptr + src_it.loc;
-      StridedIterator st(data_ptr, axis_stride, 0);
-      StridedIterator ed(data_ptr, axis_stride, axis_size);
-      std::stable_sort(st, ed);
-      src_it.step();
-    }
-  });
+  auto out_ptr = out.data<T>();
+  for (int i = 0; i < n_rows; i++) {
+    T* data_ptr = out_ptr + src_it.loc;
+    StridedIterator st(data_ptr, axis_stride, 0);
+    StridedIterator ed(data_ptr, axis_stride, axis_size);
+    std::stable_sort(st, ed);
+    src_it.step();
+  }
 }

 template <typename T, typename IdxT = uint32_t>
-void argsort(const array& in, array& out, int axis, Stream stream) {
-  // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
+void argsort(const array& in, array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
   size_t n_rows = in.size() / in.shape(axis);
@@ -176,99 +162,69 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
       in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
   ContiguousIterator out_it(
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(in);
-  encoder.set_input_array(out);
-  encoder.dispatch([in_ptr = in.data<T>(),
-                    out_ptr = out.data<IdxT>(),
-                    in_it = std::move(in_it),
-                    out_it = std::move(out_it),
-                    n_rows,
-                    axis_size,
-                    in_stride,
-                    out_stride]() mutable {
-    for (int i = 0; i < n_rows; i++) {
-      const T* data_ptr = in_ptr + in_it.loc;
-      IdxT* idx_ptr = out_ptr + out_it.loc;
-      in_it.step();
-      out_it.step();
-
-      StridedIterator st_(idx_ptr, out_stride, 0);
-      StridedIterator ed_(idx_ptr, out_stride, axis_size);
-
-      // Initialize with iota
-      std::iota(st_, ed_, IdxT(0));
-
-      // Sort according to vals
-      StridedIterator st(idx_ptr, out_stride, 0);
-      StridedIterator ed(idx_ptr, out_stride, axis_size);
-
-      std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-        auto v1 = data_ptr[a * in_stride];
-        auto v2 = data_ptr[b * in_stride];
-        return v1 < v2 || (v1 == v2 && a < b);
-      });
-    }
-  });
+  auto in_ptr = in.data<T>();
+  auto out_ptr = out.data<IdxT>();
+  for (int i = 0; i < n_rows; i++) {
+    const T* data_ptr = in_ptr + in_it.loc;
+    IdxT* idx_ptr = out_ptr + out_it.loc;
+    in_it.step();
+    out_it.step();
+
+    StridedIterator st_(idx_ptr, out_stride, 0);
+    StridedIterator ed_(idx_ptr, out_stride, axis_size);
+
+    // Initialize with iota
+    std::iota(st_, ed_, IdxT(0));
+
+    // Sort according to vals
+    StridedIterator st(idx_ptr, out_stride, 0);
+    StridedIterator ed(idx_ptr, out_stride, axis_size);
+
+    std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
+      auto v1 = data_ptr[a * in_stride];
+      auto v2 = data_ptr[b * in_stride];
+      return v1 < v2 || (v1 == v2 && a < b);
+    });
+  }
 }

 template <typename T>
-void partition(const array& in, array& out, int axis, int kth, Stream stream) {
-  // Copy input to output
-  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream);
-
+void partition(array& out, int axis, int kth) {
   // Get axis, shape and stride info
-  axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
-  size_t n_rows = in_size / in.shape(axis);
+  axis = axis < 0 ? axis + out.ndim() : axis;
+  size_t in_size = out.size();
+  size_t n_rows = in_size / out.shape(axis);

-  auto remaining_shape = in.shape();
+  auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);

-  auto remaining_strides = in.strides();
+  auto remaining_strides = out.strides();
   remaining_strides.erase(remaining_strides.begin() + axis);

-  auto axis_stride = in.strides()[axis];
-  int axis_size = in.shape(axis);
+  auto axis_stride = out.strides()[axis];
+  int axis_size = out.shape(axis);

   kth = kth < 0 ? kth + axis_size : kth;

   // Perform partition in place
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_output_array(out);
-  encoder.dispatch([out_ptr = out.data<T>(),
-                    src_it = std::move(src_it),
-                    n_rows,
-                    axis_size,
-                    axis_stride,
-                    kth]() mutable {
-    for (int i = 0; i < n_rows; i++) {
-      T* data_ptr = out_ptr + src_it.loc;
-      src_it.step();
-      StridedIterator st(data_ptr, axis_stride, 0);
-      StridedIterator md(data_ptr, axis_stride, kth);
-      StridedIterator ed(data_ptr, axis_stride, axis_size);
-      std::nth_element(st, md, ed);
-    }
-  });
+  auto out_ptr = out.data<T>();
+  for (int i = 0; i < n_rows; i++) {
+    T* data_ptr = out_ptr + src_it.loc;
+    src_it.step();
+    StridedIterator st(data_ptr, axis_stride, 0);
+    StridedIterator md(data_ptr, axis_stride, kth);
+    StridedIterator ed(data_ptr, axis_stride, axis_size);
+    std::nth_element(st, md, ed);
+  }
 }

 template <typename T, typename IdxT = uint32_t>
-void argpartition(
-    const array& in,
-    array& out,
-    int axis,
-    int kth,
-    Stream stream) {
-  // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
+void argpartition(const array& in, array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
   size_t n_rows = in.size() / in.shape(axis);
@@ -297,42 +253,32 @@ void argpartition(
   ContiguousIterator out_it(
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(in);
-  encoder.set_input_array(out);
-  encoder.dispatch([in_ptr = in.data<T>(),
-                    out_ptr = out.data<IdxT>(),
-                    in_it = std::move(in_it),
-                    out_it = std::move(out_it),
-                    n_rows,
-                    axis_size,
-                    in_stride,
-                    out_stride,
-                    kth]() mutable {
-    for (int i = 0; i < n_rows; i++) {
-      const T* data_ptr = in_ptr + in_it.loc;
-      IdxT* idx_ptr = out_ptr + out_it.loc;
-      in_it.step();
-      out_it.step();
-
-      StridedIterator st_(idx_ptr, out_stride, 0);
-      StridedIterator ed_(idx_ptr, out_stride, axis_size);
-
-      // Initialize with iota
-      std::iota(st_, ed_, IdxT(0));
-
-      // Sort according to vals
-      StridedIterator st(idx_ptr, out_stride, 0);
-      StridedIterator md(idx_ptr, out_stride, kth);
-      StridedIterator ed(idx_ptr, out_stride, axis_size);
-
-      std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-        auto v1 = data_ptr[a * in_stride];
-        auto v2 = data_ptr[b * in_stride];
-        return v1 < v2 || (v1 == v2 && a < b);
-      });
-    }
-  });
+  auto in_ptr = in.data<T>();
+  auto out_ptr = out.data<IdxT>();
+
+  for (int i = 0; i < n_rows; i++) {
+    const T* data_ptr = in_ptr + in_it.loc;
+    IdxT* idx_ptr = out_ptr + out_it.loc;
+    in_it.step();
+    out_it.step();
+
+    StridedIterator st_(idx_ptr, out_stride, 0);
+    StridedIterator ed_(idx_ptr, out_stride, axis_size);
+
+    // Initialize with iota
+    std::iota(st_, ed_, IdxT(0));
+
+    // Sort according to vals
+    StridedIterator st(idx_ptr, out_stride, 0);
+    StridedIterator md(idx_ptr, out_stride, kth);
+    StridedIterator ed(idx_ptr, out_stride, axis_size);
+
+    std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
+      auto v1 = data_ptr[a * in_stride];
+      auto v2 = data_ptr[b * in_stride];
+      return v1 < v2 || (v1 == v2 && a < b);
+    });
+  }
 }

 } // namespace
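The argsort/argpartition comparator above is worth isolating: indices are sorted by their values, with the index itself as a tiebreak so equal values keep their original relative order even under the non-stable std::nth_element. A self-contained sketch of that idea on a flat vector:

// Stable index sort via a value-then-index comparator.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> vals = {3.f, 1.f, 3.f, 2.f};
  std::vector<uint32_t> idx(vals.size());
  std::iota(idx.begin(), idx.end(), 0u);  // initialize with iota, as above

  std::stable_sort(idx.begin(), idx.end(), [&](uint32_t a, uint32_t b) {
    auto v1 = vals[a];
    auto v2 = vals[b];
    return v1 < v2 || (v1 == v2 && a < b);  // value first, index as tiebreak
  });

  for (auto i : idx) {
    std::printf("%u ", i);  // prints: 1 3 0 2
  }
  std::printf("\n");
}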
@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Allocate output
case bool_: out.set_data(allocator::malloc_or_wait(out.nbytes()));
return argsort<bool>(in, out, axis_, stream());
case uint8: auto& encoder = cpu::get_command_encoder(stream());
return argsort<uint8_t>(in, out, axis_, stream()); encoder.set_input_array(in);
case uint16: encoder.set_input_array(out);
return argsort<uint16_t>(in, out, axis_, stream()); encoder.dispatch([in = array::unsafe_weak_copy(in),
case uint32: out = array::unsafe_weak_copy(out),
return argsort<uint32_t>(in, out, axis_, stream()); axis_ = axis_]() mutable {
case uint64: switch (in.dtype()) {
return argsort<uint64_t>(in, out, axis_, stream()); case bool_:
case int8: return argsort<bool>(in, out, axis_);
return argsort<int8_t>(in, out, axis_, stream()); case uint8:
case int16: return argsort<uint8_t>(in, out, axis_);
return argsort<int16_t>(in, out, axis_, stream()); case uint16:
case int32: return argsort<uint16_t>(in, out, axis_);
return argsort<int32_t>(in, out, axis_, stream()); case uint32:
case int64: return argsort<uint32_t>(in, out, axis_);
return argsort<int64_t>(in, out, axis_, stream()); case uint64:
case float32: return argsort<uint64_t>(in, out, axis_);
return argsort<float>(in, out, axis_, stream()); case int8:
case float64: return argsort<int8_t>(in, out, axis_);
return argsort<double>(in, out, axis_, stream()); case int16:
case float16: return argsort<int16_t>(in, out, axis_);
return argsort<float16_t>(in, out, axis_, stream()); case int32:
case bfloat16: return argsort<int32_t>(in, out, axis_);
return argsort<bfloat16_t>(in, out, axis_, stream()); case int64:
case complex64: return argsort<int64_t>(in, out, axis_);
return argsort<complex64_t>(in, out, axis_, stream()); case float32:
} return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
});
} }
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) { void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { // Copy input to output
case bool_: CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
return sort<bool>(in, out, axis_, stream()); copy(in, out, ctype, stream());
case uint8:
return sort<uint8_t>(in, out, axis_, stream()); auto& encoder = cpu::get_command_encoder(stream());
case uint16: encoder.set_output_array(out);
return sort<uint16_t>(in, out, axis_, stream()); encoder.dispatch(
case uint32: [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
return sort<uint32_t>(in, out, axis_, stream()); switch (out.dtype()) {
case uint64: case bool_:
return sort<uint64_t>(in, out, axis_, stream()); return sort<bool>(out, axis_);
case int8: case uint8:
return sort<int8_t>(in, out, axis_, stream()); return sort<uint8_t>(out, axis_);
case int16: case uint16:
return sort<int16_t>(in, out, axis_, stream()); return sort<uint16_t>(out, axis_);
case int32: case uint32:
return sort<int32_t>(in, out, axis_, stream()); return sort<uint32_t>(out, axis_);
case int64: case uint64:
return sort<int64_t>(in, out, axis_, stream()); return sort<uint64_t>(out, axis_);
case float32: case int8:
return sort<float>(in, out, axis_, stream()); return sort<int8_t>(out, axis_);
case float64: case int16:
return sort<double>(in, out, axis_, stream()); return sort<int16_t>(out, axis_);
case float16: case int32:
return sort<float16_t>(in, out, axis_, stream()); return sort<int32_t>(out, axis_);
case bfloat16: case int64:
return sort<bfloat16_t>(in, out, axis_, stream()); return sort<int64_t>(out, axis_);
case complex64: case float32:
return sort<complex64_t>(in, out, axis_, stream()); return sort<float>(out, axis_);
} case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
} }
 
 void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
 
-  switch (in.dtype()) {
-    case bool_:
-      return argpartition<bool>(in, out, axis_, kth_, stream());
-    case uint8:
-      return argpartition<uint8_t>(in, out, axis_, kth_, stream());
-    case uint16:
-      return argpartition<uint16_t>(in, out, axis_, kth_, stream());
-    case uint32:
-      return argpartition<uint32_t>(in, out, axis_, kth_, stream());
-    case uint64:
-      return argpartition<uint64_t>(in, out, axis_, kth_, stream());
-    case int8:
-      return argpartition<int8_t>(in, out, axis_, kth_, stream());
-    case int16:
-      return argpartition<int16_t>(in, out, axis_, kth_, stream());
-    case int32:
-      return argpartition<int32_t>(in, out, axis_, kth_, stream());
-    case int64:
-      return argpartition<int64_t>(in, out, axis_, kth_, stream());
-    case float32:
-      return argpartition<float>(in, out, axis_, kth_, stream());
-    case float64:
-      return argpartition<double>(in, out, axis_, kth_, stream());
-    case float16:
-      return argpartition<float16_t>(in, out, axis_, kth_, stream());
-    case bfloat16:
-      return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
-    case complex64:
-      return argpartition<complex64_t>(in, out, axis_, kth_, stream());
-  }
+  // Allocate output
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in = array::unsafe_weak_copy(in),
+                    out = array::unsafe_weak_copy(out),
+                    axis_ = axis_,
+                    kth_ = kth_]() mutable {
+    switch (in.dtype()) {
+      case bool_:
+        return argpartition<bool>(in, out, axis_, kth_);
+      case uint8:
+        return argpartition<uint8_t>(in, out, axis_, kth_);
+      case uint16:
+        return argpartition<uint16_t>(in, out, axis_, kth_);
+      case uint32:
+        return argpartition<uint32_t>(in, out, axis_, kth_);
+      case uint64:
+        return argpartition<uint64_t>(in, out, axis_, kth_);
+      case int8:
+        return argpartition<int8_t>(in, out, axis_, kth_);
+      case int16:
+        return argpartition<int16_t>(in, out, axis_, kth_);
+      case int32:
+        return argpartition<int32_t>(in, out, axis_, kth_);
+      case int64:
+        return argpartition<int64_t>(in, out, axis_, kth_);
+      case float32:
+        return argpartition<float>(in, out, axis_, kth_);
+      case float64:
+        return argpartition<double>(in, out, axis_, kth_);
+      case float16:
+        return argpartition<float16_t>(in, out, axis_, kth_);
+      case bfloat16:
+        return argpartition<bfloat16_t>(in, out, axis_, kth_);
+      case complex64:
+        return argpartition<complex64_t>(in, out, axis_, kth_);
+    }
+  });
 }
 
 void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
 
-  switch (in.dtype()) {
-    case bool_:
-      return partition<bool>(in, out, axis_, kth_, stream());
-    case uint8:
-      return partition<uint8_t>(in, out, axis_, kth_, stream());
-    case uint16:
-      return partition<uint16_t>(in, out, axis_, kth_, stream());
-    case uint32:
-      return partition<uint32_t>(in, out, axis_, kth_, stream());
-    case uint64:
-      return partition<uint64_t>(in, out, axis_, kth_, stream());
-    case int8:
-      return partition<int8_t>(in, out, axis_, kth_, stream());
-    case int16:
-      return partition<int16_t>(in, out, axis_, kth_, stream());
-    case int32:
-      return partition<int32_t>(in, out, axis_, kth_, stream());
-    case int64:
-      return partition<int64_t>(in, out, axis_, kth_, stream());
-    case float32:
-      return partition<float>(in, out, axis_, kth_, stream());
-    case float64:
-      return partition<double>(in, out, axis_, kth_, stream());
-    case float16:
-      return partition<float16_t>(in, out, axis_, kth_, stream());
-    case bfloat16:
-      return partition<bfloat16_t>(in, out, axis_, kth_, stream());
-    case complex64:
-      return partition<complex64_t>(in, out, axis_, kth_, stream());
-  }
+  // Copy input to output
+  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+  copy(in, out, ctype, stream());
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_output_array(out);
+  encoder.dispatch([out = array::unsafe_weak_copy(out),
+                    axis_ = axis_,
+                    kth_ = kth_]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        return partition<bool>(out, axis_, kth_);
+      case uint8:
+        return partition<uint8_t>(out, axis_, kth_);
+      case uint16:
+        return partition<uint16_t>(out, axis_, kth_);
+      case uint32:
+        return partition<uint32_t>(out, axis_, kth_);
+      case uint64:
+        return partition<uint64_t>(out, axis_, kth_);
+      case int8:
+        return partition<int8_t>(out, axis_, kth_);
+      case int16:
+        return partition<int16_t>(out, axis_, kth_);
+      case int32:
+        return partition<int32_t>(out, axis_, kth_);
+      case int64:
+        return partition<int64_t>(out, axis_, kth_);
+      case float32:
+        return partition<float>(out, axis_, kth_);
+      case float64:
+        return partition<double>(out, axis_, kth_);
+      case float16:
+        return partition<float16_t>(out, axis_, kth_);
+      case bfloat16:
+        return partition<bfloat16_t>(out, axis_, kth_);
+      case complex64:
+        return partition<complex64_t>(out, axis_, kth_);
+    }
+  });
 }
 
 } // namespace mlx::core
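Sort and Partition share the same shape: materialize the output by copying the input on the stream, then run the templated kernel in place on the output buffer, so the kernel takes one array instead of two. A toy illustration of "copy, then sort the copy in place along an axis"; the helper name and the row-major 2-D layout are assumptions for the sketch, not MLX internals.

#include <algorithm>
#include <cstdio>
#include <vector>

// Sort each row of a row-major (rows x cols) buffer in place.
void sort_axis_inplace(std::vector<float>& buf, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    std::sort(buf.begin() + r * cols, buf.begin() + (r + 1) * cols);
  }
}

int main() {
  std::vector<float> in{3, 1, 2, 9, 7, 8};
  std::vector<float> out = in;                // "copy input to output"
  sort_axis_inplace(out, 2, 3);               // sort the last axis in place
  for (float v : out) std::printf("%g ", v);  // prints: 1 2 3 7 8 9
  std::printf("\n");
  return 0;
}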

View File

@@ -1,12 +1,10 @@
 // Copyright © 2023 Apple Inc.
 
 #pragma once
 
-#include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/ternary.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/encoder.h"
-#include "mlx/primitives.h"
 
 namespace mlx::core {
 
@@ -128,57 +126,28 @@ void ternary_op(
     const array& b,
     const array& c,
     array& out,
-    Op op) {
-  TernaryOpType topt = get_ternary_op_type(a, b, c);
-  set_ternary_op_output_data(a, b, c, out, topt);
-
-  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_input_array(c);
-  encoder.set_output_array(out);
-
+    Op op,
+    TernaryOpType topt) {
   const T1* a_ptr = a.data<T1>();
   const T2* b_ptr = b.data<T2>();
   const T3* c_ptr = c.data<T3>();
   U* out_ptr = out.data<U>();
   if (topt == TernaryOpType::ScalarScalarScalar) {
-    encoder.dispatch(
-        [a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
-          *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
-        });
+    *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
   } else if (topt == TernaryOpType::VectorVectorVector) {
-    encoder.dispatch([a_ptr,
-                      b_ptr,
-                      c_ptr,
-                      out_ptr,
-                      op = std::move(op),
-                      size = out.size()]() mutable {
-      for (size_t i = 0; i < size; ++i) {
-        *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
-        a_ptr++;
-        b_ptr++;
-        c_ptr++;
-        out_ptr++;
-      }
-    });
+    for (size_t i = 0; i < out.size(); ++i) {
+      *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
+      a_ptr++;
+      b_ptr++;
+      c_ptr++;
+      out_ptr++;
+    }
   } else {
     auto [shape, strides] = collapse_contiguous_dims(
         a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
-    encoder.dispatch(
-        [a_ptr,
-         b_ptr,
-         c_ptr,
-         out_ptr,
-         op = std::move(op),
-         size = out.size(),
-         shape = std::move(shape),
-         strides = std::move(strides)]() mutable {
-          ternary_op_dispatch_dims<T1, T2, T3, U>(
-              a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
-        });
+    ternary_op_dispatch_dims<T1, T2, T3, U>(
+        a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
   }
 }
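The restructuring above makes ternary_op synchronous: the operand classification (topt), output allocation, and the single encoder.dispatch now live in the caller, so the heavily instantiated template contains only the loops. A sketch of that idea with hypothetical names (ternary_worker, run_ternary, a plain task queue); it is not MLX API, just the shape of the refactor.

#include <cstddef>
#include <functional>
#include <vector>

enum class OpType { ScalarScalarScalar, VectorVectorVector };

// Instantiated once per (dtype, op) combination: just the loops, nothing else.
template <typename T, typename Op>
void ternary_worker(const T* a, const T* b, const T* c, T* out, std::size_t n,
                    Op op, OpType t) {
  if (t == OpType::ScalarScalarScalar) {
    out[0] = op(a[0], b[0], c[0]);
  } else {
    for (std::size_t i = 0; i < n; ++i) {
      out[i] = op(a[i], b[i], c[i]);
    }
  }
}

// Classification and queuing happen once, outside the per-dtype worker,
// mirroring how eval_cpu now wraps the dtype switch in a single dispatch.
void run_ternary(std::vector<std::function<void()>>& queue, const float* a,
                 const float* b, const float* c, float* out, std::size_t n) {
  OpType t = (n == 1) ? OpType::ScalarScalarScalar : OpType::VectorVectorVector;
  auto select = [](float cond, float x, float y) { return cond != 0 ? x : y; };
  queue.push_back([=]() { ternary_worker<float>(a, b, c, out, n, select, t); });
}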

View File

@@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
     // No-op for unsigned types
     out.copy_shared_buffer(in);
   } else {
-    auto op = detail::Abs{};
-    switch (out.dtype()) {
-      case int8:
-        unary_op<int8_t>(in, out, op);
-        break;
-      case int16:
-        unary_op<int16_t>(in, out, op);
-        break;
-      case int32:
-        unary_op<int32_t>(in, out, op);
-        break;
-      case int64:
-        unary_op<int64_t>(in, out, op);
-        break;
-      case float16:
-        unary_op<float16_t>(in, out, op);
-        break;
-      case float32:
-        unary_op<float>(in, out, op);
-        break;
-      case float64:
-        unary_op<double>(in, out, op);
-        break;
-      case bfloat16:
-        unary_op<bfloat16_t>(in, out, op);
-        break;
-      case complex64:
-        unary_op<complex64_t>(in, out, op);
-        break;
-      default:
-        throw std::runtime_error("[Abs] Called on unsigned type");
-    }
+    unary_signed(in, out, detail::Abs(), stream());
   }
 }
 
 void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::ArcCos());
+  unary_fp(in, out, detail::ArcCos(), stream());
 }
 
 void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::ArcCosh());
+  unary_fp(in, out, detail::ArcCosh(), stream());
 }
 
 void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::ArcSin());
+  unary_fp(in, out, detail::ArcSin(), stream());
 }
 
 void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::ArcSinh());
+  unary_fp(in, out, detail::ArcSinh(), stream());
 }
 
 void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::ArcTan());
+  unary_fp(in, out, detail::ArcTan(), stream());
 }
 
 void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::ArcTanh());
+  unary_fp(in, out, detail::ArcTanh(), stream());
 }
 
 void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_int(in, out, detail::BitwiseInvert());
+  unary_int(in, out, detail::BitwiseInvert(), stream());
 }
 
 void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (issubdtype(in.dtype(), inexact)) {
-    unary_fp(in, out, detail::Ceil());
+    unary_fp(in, out, detail::Ceil(), stream());
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
@@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
+  unary_complex(inputs[0], out, detail::Conjugate(), stream());
 }
 
 void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Cos());
+  unary_fp(in, out, detail::Cos(), stream());
 }
 
 void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Cosh());
+  unary_fp(in, out, detail::Cosh(), stream());
 }
 
 void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  switch (out.dtype()) {
-    case float32:
-      unary_op<float>(in, out, detail::Erf());
-      break;
-    case float16:
-      unary_op<float16_t>(in, out, detail::Erf());
-      break;
-    case float64:
-      unary_op<double>(in, out, detail::Erf());
-      break;
-    case bfloat16:
-      unary_op<bfloat16_t>(in, out, detail::Erf());
-      break;
-    default:
-      throw std::invalid_argument(
-          "[erf] Error function only defined for arrays"
-          " with real floating point type.");
-  }
+  unary_real_fp(in, out, detail::Erf(), stream());
 }
 
 void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  switch (out.dtype()) {
-    case float32:
-      unary_op<float>(in, out, detail::ErfInv());
-      break;
-    case float16:
-      unary_op<float16_t>(in, out, detail::ErfInv());
-      break;
-    case float64:
-      unary_op<double>(in, out, detail::ErfInv());
-      break;
-    case bfloat16:
-      unary_op<bfloat16_t>(in, out, detail::ErfInv());
-      break;
-    default:
-      throw std::invalid_argument(
-          "[erf_inv] Inverse error function only defined for arrays"
-          " with real floating point type.");
-  }
+  unary_real_fp(in, out, detail::ErfInv(), stream());
 }
 
 void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Exp());
+  unary_fp(in, out, detail::Exp(), stream());
 }
 
 void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Expm1());
+  unary_fp(in, out, detail::Expm1(), stream());
 }
 
 void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (issubdtype(in.dtype(), inexact)) {
-    unary_fp(in, out, detail::Floor());
+    unary_fp(in, out, detail::Floor(), stream());
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
@@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
 }
 
 void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
-  unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
+  unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
 }
 
@@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
   const auto& in = inputs[0];
   switch (base_) {
     case Base::e:
-      unary_fp(in, out, detail::Log());
+      unary_fp(in, out, detail::Log(), stream());
       break;
     case Base::two:
-      unary_fp(in, out, detail::Log2());
+      unary_fp(in, out, detail::Log2(), stream());
       break;
     case Base::ten:
-      unary_fp(in, out, detail::Log10());
+      unary_fp(in, out, detail::Log10(), stream());
       break;
   }
 }
 
@@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Log1p());
+  unary_fp(in, out, detail::Log1p(), stream());
 }
 
 void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  unary(in, out, detail::LogicalNot());
+  unary(in, out, detail::LogicalNot(), stream());
 }
 
 void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  unary(in, out, detail::Negative());
+  unary(in, out, detail::Negative(), stream());
 }
 
 void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
-  unary_op<complex64_t, float>(inputs[0], out, detail::Real());
+  unary_complex_to_float(inputs[0], out, detail::Real(), stream());
 }
 
 void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (issubdtype(in.dtype(), inexact)) {
-    unary_fp(in, out, detail::Round());
+    unary_fp(in, out, detail::Round(), stream());
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
@@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Sigmoid());
+  unary_fp(in, out, detail::Sigmoid(), stream());
 }
 
@@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (in.dtype() == bool_) {
     out.copy_shared_buffer(in);
   } else {
-    unary(in, out, detail::Sign());
+    unary(in, out, detail::Sign(), stream());
   }
 }
 
 void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Sin());
+  unary_fp(in, out, detail::Sin(), stream());
 }
 
 void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Sinh());
+  unary_fp(in, out, detail::Sinh(), stream());
 }
 
 void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  unary(in, out, detail::Square());
+  unary(in, out, detail::Square(), stream());
 }
 
 void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   if (recip_) {
-    unary_fp(in, out, detail::Rsqrt());
+    unary_fp(in, out, detail::Rsqrt(), stream());
   } else {
-    unary_fp(in, out, detail::Sqrt());
+    unary_fp(in, out, detail::Sqrt(), stream());
   }
 }
 
 void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Tan());
+  unary_fp(in, out, detail::Tan(), stream());
 }
 
 void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  unary_fp(in, out, detail::Tanh());
+  unary_fp(in, out, detail::Tanh(), stream());
 }
 
 } // namespace mlx::core
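With the eval_cpu bodies reduced to one helper call each, adding a new element-wise op is a one-liner plus a functor. A hypothetical example (there is no "Cube" op in MLX, and the include path is an assumption): pick the helper that matches the op's domain, since the helper owns the dtype switch and the single encoder dispatch.

#include <cassert>
#include <vector>

#include "mlx/backend/cpu/unary.h"  // assumed path of the header diffed below

namespace mlx::core {

struct Cube {
  // Ops are stateless: unary_op default-constructs Op internally, so the
  // functor must be default-constructible and work on simd operands too.
  template <typename T>
  T operator()(T x) {
    return x * x * x;
  }
};

void cube_eval_cpu_sketch(const std::vector<array>& inputs, array& out, Stream s) {
  assert(inputs.size() == 1);
  // unary_fp restricts instantiations to floating/complex dtypes; unary,
  // unary_int, unary_signed, or unary_real_fp would widen or narrow the set.
  unary_fp(inputs[0], out, Cube{}, s);
}

} // namespace mlx::core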

View File

@@ -7,7 +7,6 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
-#include "mlx/primitives.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -39,156 +38,263 @@ void unary_op(const T* a, U* out, size_t shape, size_t stride) {
 
 template <typename T, typename U = T, typename Op>
 void unary_op(const array& a, array& out, Op) {
-  set_unary_output_data(a, out);
   const T* src = a.data<T>();
   U* dst = out.data<U>();
-  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
-  encoder.set_input_array(a);
-  encoder.set_output_array(out);
-
-  encoder.dispatch([src,
-                    dst,
-                    contig = a.flags().contiguous,
-                    data_size = a.data_size(),
-                    size = a.size(),
-                    shapes = a.shape(),
-                    strides = a.strides()]() mutable {
-    auto ndim = shapes.size();
-    if (contig) {
-      constexpr int N = simd::max_size<T>;
-      while (data_size >= N) {
-        simd::store(dst, Op{}(simd::load<T, N>(src)));
-        data_size -= N;
-        src += N;
-        dst += N;
-      }
-      while (data_size > 0) {
-        *dst = Op{}(*src);
-        data_size--;
-        dst++;
-        src++;
-      }
-    } else {
-      size_t shape = ndim > 0 ? shapes.back() : 1;
-      size_t stride = ndim > 0 ? strides.back() : 1;
-      if (ndim <= 1) {
-        unary_op<T, U, Op>(src, dst, shape, stride);
-        return;
-      }
-      auto it = ContiguousIterator(shapes, strides, ndim - 1);
-      for (size_t elem = 0; elem < size; elem += shape) {
-        unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
-        it.step();
-      }
+  auto ndim = a.ndim();
+  if (a.flags().contiguous) {
+    auto size = a.data_size();
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, Op{}(simd::load<T, N>(src)));
+      size -= N;
+      src += N;
+      dst += N;
+    }
+    while (size > 0) {
+      *dst = Op{}(*src);
+      size--;
+      dst++;
+      src++;
+    }
+  } else {
+    size_t shape = ndim > 0 ? a.shape().back() : 1;
+    size_t stride = ndim > 0 ? a.strides().back() : 1;
+    if (ndim <= 1) {
+      unary_op<T, U, Op>(src, dst, shape, stride);
+      return;
+    }
+    auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
+    for (size_t elem = 0; elem < a.size(); elem += shape) {
+      unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
+      it.step();
     }
-  });
+  }
 }
 
 template <typename Op>
-void unary(const array& a, array& out, Op op) {
-  switch (out.dtype()) {
-    case bool_:
-      unary_op<bool>(a, out, op);
-      break;
-    case uint8:
-      unary_op<uint8_t>(a, out, op);
-      break;
-    case uint16:
-      unary_op<uint16_t>(a, out, op);
-      break;
-    case uint32:
-      unary_op<uint32_t>(a, out, op);
-      break;
-    case uint64:
-      unary_op<uint64_t>(a, out, op);
-      break;
-    case int8:
-      unary_op<int8_t>(a, out, op);
-      break;
-    case int16:
-      unary_op<int16_t>(a, out, op);
-      break;
-    case int32:
-      unary_op<int32_t>(a, out, op);
-      break;
-    case int64:
-      unary_op<int64_t>(a, out, op);
-      break;
-    case float16:
-      unary_op<float16_t>(a, out, op);
-      break;
-    case float32:
-      unary_op<float>(a, out, op);
-      break;
-    case float64:
-      unary_op<double>(a, out, op);
-      break;
-    case bfloat16:
-      unary_op<bfloat16_t>(a, out, op);
-      break;
-    case complex64:
-      unary_op<complex64_t>(a, out, op);
-      break;
-  }
+void unary(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    out = array::unsafe_weak_copy(out),
+                    op = op]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        unary_op<bool>(a, out, op);
+        break;
+      case uint8:
+        unary_op<uint8_t>(a, out, op);
+        break;
+      case uint16:
+        unary_op<uint16_t>(a, out, op);
+        break;
+      case uint32:
+        unary_op<uint32_t>(a, out, op);
+        break;
+      case uint64:
+        unary_op<uint64_t>(a, out, op);
+        break;
+      case int8:
+        unary_op<int8_t>(a, out, op);
+        break;
+      case int16:
+        unary_op<int16_t>(a, out, op);
+        break;
+      case int32:
+        unary_op<int32_t>(a, out, op);
+        break;
+      case int64:
+        unary_op<int64_t>(a, out, op);
+        break;
+      case float16:
+        unary_op<float16_t>(a, out, op);
+        break;
+      case float32:
+        unary_op<float>(a, out, op);
+        break;
+      case float64:
+        unary_op<double>(a, out, op);
+        break;
+      case bfloat16:
+        unary_op<bfloat16_t>(a, out, op);
+        break;
+      case complex64:
+        unary_op<complex64_t>(a, out, op);
+        break;
+    }
+  });
+}
+
+template <typename Op>
+void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    out = array::unsafe_weak_copy(out),
+                    op = op]() mutable {
+    switch (out.dtype()) {
+      case bfloat16:
+        unary_op<bfloat16_t>(a, out, op);
+        break;
+      case float16:
+        unary_op<float16_t>(a, out, op);
+        break;
+      case float32:
+        unary_op<float>(a, out, op);
+        break;
+      case float64:
+        unary_op<double>(a, out, op);
+        break;
+      default:
+        std::ostringstream err;
+        err << "[unary_real] Does not support " << out.dtype();
+        throw std::runtime_error(err.str());
+    }
+  });
 }
 
 template <typename Op>
-void unary_fp(const array& a, array& out, Op op) {
-  switch (out.dtype()) {
-    case bfloat16:
-      unary_op<bfloat16_t>(a, out, op);
-      break;
-    case float16:
-      unary_op<float16_t>(a, out, op);
-      break;
-    case float32:
-      unary_op<float>(a, out, op);
-      break;
-    case float64:
-      unary_op<double>(a, out, op);
-      break;
-    case complex64:
-      unary_op<complex64_t>(a, out, op);
-      break;
-    default:
-      std::ostringstream err;
-      err << "[unary_fp] Does not support " << out.dtype();
-      throw std::runtime_error(err.str());
-  }
+void unary_fp(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    out = array::unsafe_weak_copy(out),
+                    op = op]() mutable {
+    switch (out.dtype()) {
+      case bfloat16:
+        unary_op<bfloat16_t>(a, out, op);
+        break;
+      case float16:
+        unary_op<float16_t>(a, out, op);
+        break;
+      case float32:
+        unary_op<float>(a, out, op);
+        break;
+      case float64:
+        unary_op<double>(a, out, op);
+        break;
+      case complex64:
+        unary_op<complex64_t>(a, out, op);
+        break;
+      default:
+        std::ostringstream err;
+        err << "[unary_fp] Does not support " << out.dtype();
+        throw std::runtime_error(err.str());
+    }
+  });
+}
+
+template <typename Op>
+void unary_signed(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    out = array::unsafe_weak_copy(out),
+                    op = op]() mutable {
+    switch (out.dtype()) {
+      case int8:
+        unary_op<int8_t>(a, out, op);
+        break;
+      case int16:
+        unary_op<int16_t>(a, out, op);
+        break;
+      case int32:
+        unary_op<int32_t>(a, out, op);
+        break;
+      case int64:
+        unary_op<int64_t>(a, out, op);
+        break;
+      case float16:
+        unary_op<float16_t>(a, out, op);
+        break;
+      case float32:
+        unary_op<float>(a, out, op);
+        break;
+      case float64:
+        unary_op<double>(a, out, op);
+        break;
+      case bfloat16:
+        unary_op<bfloat16_t>(a, out, op);
+        break;
+      case complex64:
+        unary_op<complex64_t>(a, out, op);
+        break;
+      default:
+        throw std::runtime_error("[Abs] Called on unsigned type");
+    }
+  });
+}
+
+template <typename Op>
+void unary_complex(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    out = array::unsafe_weak_copy(out),
+                    op = op]() mutable { unary_op<complex64_t>(a, out, op); });
+}
+
+template <typename Op>
+void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch(
+      [a = array::unsafe_weak_copy(a),
+       out = array::unsafe_weak_copy(out),
+       op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
 }
 
 template <typename Op>
-void unary_int(const array& a, array& out, Op op) {
-  switch (out.dtype()) {
-    case uint8:
-      unary_op<uint8_t>(a, out, op);
-      break;
-    case uint16:
-      unary_op<uint16_t>(a, out, op);
-      break;
-    case uint32:
-      unary_op<uint32_t>(a, out, op);
-      break;
-    case uint64:
-      unary_op<uint64_t>(a, out, op);
-      break;
-    case int8:
-      unary_op<int8_t>(a, out, op);
-      break;
-    case int16:
-      unary_op<int16_t>(a, out, op);
-      break;
-    case int32:
-      unary_op<int32_t>(a, out, op);
-      break;
-    case int64:
-      unary_op<int64_t>(a, out, op);
-      break;
-    default:
-      std::ostringstream err;
-      err << "[unary_int] Does not support " << out.dtype();
-      throw std::runtime_error(err.str());
-  }
+void unary_int(const array& a, array& out, Op op, Stream stream) {
+  set_unary_output_data(a, out);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    out = array::unsafe_weak_copy(out),
+                    op = op]() mutable {
+    switch (out.dtype()) {
+      case uint8:
+        unary_op<uint8_t>(a, out, op);
+        break;
+      case uint16:
+        unary_op<uint16_t>(a, out, op);
+        break;
+      case uint32:
+        unary_op<uint32_t>(a, out, op);
+        break;
+      case uint64:
+        unary_op<uint64_t>(a, out, op);
+        break;
+      case int8:
+        unary_op<int8_t>(a, out, op);
+        break;
+      case int16:
+        unary_op<int16_t>(a, out, op);
+        break;
+      case int32:
+        unary_op<int32_t>(a, out, op);
+        break;
+      case int64:
+        unary_op<int64_t>(a, out, op);
+        break;
+      default:
+        std::ostringstream err;
+        err << "[unary_int] Does not support " << out.dtype();
+        throw std::runtime_error(err.str());
+    }
+  });
 }
 
 } // namespace mlx::core
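Across these files the motif is the same: allocation, stream lookup, and the single encoder.dispatch are hoisted out of the templated bodies, so each per-dtype instantiation is reduced to its arithmetic loops. A rough sizing argument, with illustrative counts rather than numbers measured from this commit:

\[
\underbrace{N_{\text{ops}} \times N_{\text{dtypes}}}_{\text{old: one dispatch closure per (op, dtype)}} \approx 30 \times 14 = 420
\qquad \text{vs.} \qquad
\underbrace{N_{\text{ops}}}_{\text{new: one dispatch per helper call}} \approx 30
\]

Each eliminated closure carried its own capture storage, type-erasure thunk, and scheduling code, which is plausibly where most of the binary-size saving comes from; dropping mlx/primitives.h from the header (the old code reached the stream through out.primitive()) also shrinks what every includer must compile.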