diff --git a/mlx/array.cpp b/mlx/array.cpp
index 406a491e65..d8d12e0db4 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
   return outputs;
 }
 
+array array::unsafe_weak_copy(const array& other) {
+  auto cpy = array(other.shape(), other.dtype(), nullptr, {});
+  cpy.set_data(
+      other.buffer(),
+      other.data_size(),
+      other.strides(),
+      other.flags(),
+      [](auto) {});
+  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
+  return cpy;
+}
+
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
           Shape{static_cast<int>(data.size())},
diff --git a/mlx/array.h b/mlx/array.h
index 5e980d532e..d690dcd979 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -199,6 +199,13 @@ class array {
       const std::shared_ptr<Primitive>& primitive,
       const std::vector<array>& inputs);
 
+  /**
+   * Get a new array that refers to the same data as the input but with a
+   * non-owning pointer to it. Note the array is detached from the graph and
+   * has no inputs, siblings, or primitive.
+   */
+  static array unsafe_weak_copy(const array& other);
+
   /** A unique identifier for an array. */
   std::uintptr_t id() const {
     return reinterpret_cast<std::uintptr_t>(array_desc_.get());
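A minimal sketch of the capture pattern the rest of this patch adopts (illustrative only; `eval_example` is not part of the change, and the snippet assumes the internal encoder API shown below). The encoder's `set_input_array`/`set_output_array` calls are what keep the underlying buffers alive until the task runs; the weak copies captured by the lambda are deliberately non-owning:

    void eval_example(const array& in, array& out, Stream stream) {
      auto& encoder = cpu::get_command_encoder(stream);
      encoder.set_input_array(in); // retains the buffer for the pending task
      encoder.set_output_array(out);
      encoder.dispatch([in = array::unsafe_weak_copy(in),
                        out = array::unsafe_weak_copy(out)]() mutable {
        // The weak copies carry shape, strides, dtype, and a raw data
        // pointer, but no inputs, siblings, or primitive, so capturing them
        // here cannot extend the graph lifetime of the real arrays.
      });
    }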
diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp
index 96e7f9ee97..c9bdc35b0e 100644
--- a/mlx/backend/cpu/arg_reduce.cpp
+++ b/mlx/backend/cpu/arg_reduce.cpp
@@ -11,12 +11,7 @@ namespace mlx::core {
 namespace {
 
 template <typename InT, typename OpT>
-void arg_reduce(
-    const array& in,
-    array& out,
-    const OpT& op,
-    int axis,
-    Stream stream) {
+void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
   auto axis_size = in.shape()[axis];
   auto axis_stride = in.strides()[axis];
   Strides strides = in.strides();
@@ -26,28 +21,16 @@ void arg_reduce(
   auto in_ptr = in.data<InT>();
   auto out_ptr = out.data<uint32_t>();
 
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
-  encoder.dispatch([in_ptr,
-                    out_ptr,
-                    axis_size,
-                    axis_stride,
-                    op = std::move(op),
-                    shape = std::move(shape),
-                    strides = std::move(strides),
-                    size = out.size()]() {
-    for (uint32_t i = 0; i < size; ++i) {
-      auto loc = elem_to_loc(i, shape, strides);
-      auto local_in_ptr = in_ptr + loc;
-      uint32_t ind_v = 0;
-      InT v = (*local_in_ptr);
-      for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
-        op(j, (*local_in_ptr), &ind_v, &v);
-      }
-      out_ptr[i] = ind_v;
+  for (uint32_t i = 0; i < out.size(); ++i) {
+    auto loc = elem_to_loc(i, shape, strides);
+    auto local_in_ptr = in_ptr + loc;
+    uint32_t ind_v = 0;
+    InT v = (*local_in_ptr);
+    for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
+      op(j, (*local_in_ptr), &ind_v, &v);
     }
-  });
+    out_ptr[i] = ind_v;
+  }
 }
 
 template <typename InT>
@@ -55,8 +38,7 @@ void arg_reduce_dispatch(
     const array& in,
     array& out,
     ArgReduce::ReduceType rtype,
-    int axis,
-    Stream stream) {
+    int axis) {
   switch (rtype) {
     case ArgReduce::ArgMin: {
       auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
         if (x < (*y)) {
           (*y) = x;
           (*ind_y) = ind_x;
         }
       };
-      arg_reduce<InT>(in, out, op, axis, stream);
+      arg_reduce<InT>(in, out, op, axis);
       break;
     }
     case ArgReduce::ArgMax: {
       auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
         if (x > (*y)) {
           (*y) = x;
           (*ind_y) = ind_x;
         }
       };
-      arg_reduce<InT>(in, out, op, axis, stream);
+      arg_reduce<InT>(in, out, op, axis);
       break;
     }
   }
@@ -87,51 +69,58 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  switch (in.dtype()) {
-    case bool_:
-      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
-      break;
-    case uint8:
-      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case uint16:
-      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case uint32:
-      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case uint64:
-      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case int8:
-      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case int16:
-      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case int32:
-      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case int64:
-      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case float16:
-      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case float32:
-      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
-      break;
-    case bfloat16:
-      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
-      break;
-    case float64:
-      arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
-      break;
-    case complex64:
-      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
-      break;
-  }
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  encoder.dispatch([in = array::unsafe_weak_copy(in),
+                    out = array::unsafe_weak_copy(out),
+                    reduce_type_ = reduce_type_,
+                    axis_ = axis_]() mutable {
+    switch (in.dtype()) {
+      case bool_:
+        arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
+        break;
+      case uint8:
+        arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
+        break;
+      case uint16:
+        arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
+        break;
+      case uint32:
+        arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
+        break;
+      case uint64:
+        arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
+        break;
+      case int8:
+        arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
+        break;
+      case int16:
+        arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
+        break;
+      case int32:
+        arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
+        break;
+      case int64:
+        arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
+        break;
+      case float16:
+        arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
+        break;
+      case float32:
+        arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
+        break;
+      case bfloat16:
+        arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
+        break;
+      case float64:
+        arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
+        break;
+      case complex64:
+        arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
+        break;
+    }
+  });
 }
 
 } // namespace mlx::core
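As a quick sanity check of the comparator shape used by ArgMin above (standalone sketch; the `main` and test values are hypothetical):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Mirrors the ArgMin lambda: keep the earliest index of the smallest
      // value; strict < means ties preserve the first occurrence.
      auto op = [](uint32_t ind_x, int x, uint32_t* ind_y, int* y) {
        if (x < (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      int vals[] = {3, 1, 2, 1};
      uint32_t ind = 0;
      int best = vals[0];
      for (uint32_t j = 0; j < 4; ++j) {
        op(j, vals[j], &ind, &best);
      }
      assert(ind == 1 && best == 1); // first minimum wins, not the later tie
    }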
diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp
index 55c4e69f44..dbdab6a06a 100644
--- a/mlx/backend/cpu/binary.cpp
+++ b/mlx/backend/cpu/binary.cpp
@@ -8,6 +8,7 @@
 #include "mlx/backend/cpu/binary.h"
 #include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/binary_two.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
@@ -16,51 +17,218 @@ namespace mlx::core {
 namespace {
 
 template <typename Op>
-void comparison_op(const array& a, const array& b, array& out) {
-  switch (a.dtype()) {
-    case bool_:
-      binary_op<bool, bool, Op>(a, b, out);
-      break;
-    case uint8:
-      binary_op<uint8_t, bool, Op>(a, b, out);
-      break;
-    case uint16:
-      binary_op<uint16_t, bool, Op>(a, b, out);
-      break;
-    case uint32:
-      binary_op<uint32_t, bool, Op>(a, b, out);
-      break;
-    case uint64:
-      binary_op<uint64_t, bool, Op>(a, b, out);
-      break;
-    case int8:
-      binary_op<int8_t, bool, Op>(a, b, out);
-      break;
-    case int16:
-      binary_op<int16_t, bool, Op>(a, b, out);
-      break;
-    case int32:
-      binary_op<int32_t, bool, Op>(a, b, out);
-      break;
-    case int64:
-      binary_op<int64_t, bool, Op>(a, b, out);
-      break;
-    case float16:
-      binary_op<float16_t, bool, Op>(a, b, out);
-      break;
-    case float32:
-      binary_op<float, bool, Op>(a, b, out);
-      break;
-    case float64:
-      binary_op<double, bool, Op>(a, b, out);
-      break;
-    case bfloat16:
-      binary_op<bfloat16_t, bool, Op>(a, b, out);
-      break;
-    case complex64:
-      binary_op<complex64_t, bool, Op>(a, b, out);
-      break;
-  }
+void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        binary_op<bool, Op>(a, b, out, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t, Op>(a, b, out, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t, Op>(a, b, out, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t, Op>(a, b, out, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t, Op>(a, b, out, bopt);
+        break;
+      case int8:
+        binary_op<int8_t, Op>(a, b, out, bopt);
+        break;
+      case int16:
+        binary_op<int16_t, Op>(a, b, out, bopt);
+        break;
+      case int32:
+        binary_op<int32_t, Op>(a, b, out, bopt);
+        break;
+      case int64:
+        binary_op<int64_t, Op>(a, b, out, bopt);
+        break;
+      case float16:
+        binary_op<float16_t, Op>(a, b, out, bopt);
+        break;
+      case float32:
+        binary_op<float, Op>(a, b, out, bopt);
+        break;
+      case float64:
+        binary_op<double, Op>(a, b, out, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, Op>(a, b, out, bopt);
+        break;
+      case complex64:
+        binary_op<complex64_t, Op>(a, b, out, bopt);
+        break;
+    }
+  });
+}
+
+template <typename Op>
+void comparison_op(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (a.dtype()) {
+      case bool_:
+        binary_op<bool, bool, Op>(a, b, out, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t, bool, Op>(a, b, out, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t, bool, Op>(a, b, out, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int8:
+        binary_op<int8_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int16:
+        binary_op<int16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int32:
+        binary_op<int32_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int64:
+        binary_op<int64_t, bool, Op>(a, b, out, bopt);
+        break;
+      case float16:
+        binary_op<float16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case float32:
+        binary_op<float, bool, Op>(a, b, out, bopt);
+        break;
+      case float64:
+        binary_op<double, bool, Op>(a, b, out, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case complex64:
+        binary_op<complex64_t, bool, Op>(a, b, out, bopt);
+        break;
+    }
+  });
+}
+
+template <typename Op>
+void binary_float(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (out.dtype()) {
+      case float16:
+        binary_op<float16_t, Op>(a, b, out, bopt);
+        break;
+      case float32:
+        binary_op<float, Op>(a, b, out, bopt);
+        break;
+      case float64:
+        binary_op<double, Op>(a, b, out, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, Op>(a, b, out, bopt);
+        break;
+      default:
+        throw std::runtime_error(
+            "[binary_float] Only supports non-complex floating point types.");
+    }
+  });
+}
+
+template <typename Op>
+void binary_int(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        binary_op<bool, Op>(a, b, out, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t, Op>(a, b, out, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t, Op>(a, b, out, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t, Op>(a, b, out, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t, Op>(a, b, out, bopt);
+        break;
+      case int8:
+        binary_op<int8_t, Op>(a, b, out, bopt);
+        break;
+      case int16:
+        binary_op<int16_t, Op>(a, b, out, bopt);
+        break;
+      case int32:
+        binary_op<int32_t, Op>(a, b, out, bopt);
+        break;
+      case int64:
+        binary_op<int64_t, Op>(a, b, out, bopt);
+        break;
+      default:
+        throw std::runtime_error("[binary_int] Type not supported");
+        break;
+    }
+  });
 }
 
 } // namespace
 
@@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Add());
+  binary(a, b, out, detail::Add(), stream());
 }
 
 void DivMod::eval_cpu(
@@ -78,70 +246,89 @@ void DivMod::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  auto integral_op = [](auto x, auto y) {
-    return std::make_pair(x / y, x % y);
-  };
-  auto float_op = [](auto x, auto y) {
-    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
-  };
-  switch (outputs[0].dtype()) {
-    case bool_:
-      binary_op<bool>(a, b, outputs, integral_op);
-    case uint8:
-      binary_op<uint8_t>(a, b, outputs, integral_op);
-      break;
-    case uint16:
-      binary_op<uint16_t>(a, b, outputs, integral_op);
-      break;
-    case uint32:
-      binary_op<uint32_t>(a, b, outputs, integral_op);
-      break;
-    case uint64:
-      binary_op<uint64_t>(a, b, outputs, integral_op);
-      break;
-    case int8:
-      binary_op<int8_t>(a, b, outputs, integral_op);
-      break;
-    case int16:
-      binary_op<int16_t>(a, b, outputs, integral_op);
-      break;
-    case int32:
-      binary_op<int32_t>(a, b, outputs, integral_op);
-      break;
-    case int64:
-      binary_op<int64_t>(a, b, outputs, integral_op);
-      break;
-    case float16:
-      binary_op<float16_t>(a, b, outputs, float_op);
-      break;
-    case float32:
-      binary_op<float>(a, b, outputs, float_op);
-      break;
-    case float64:
-      binary_op<double>(a, b, outputs, float_op);
-      break;
-    case bfloat16:
-      binary_op<bfloat16_t>(a, b, outputs, float_op);
-      break;
-    case complex64:
-      // Should never get here
-      throw std::runtime_error("[DivMod] Complex type not supported");
-      break;
-  }
+  auto bopt = get_binary_op_type(a, b);
+  auto& out_a = outputs[0];
+  auto& out_b = outputs[1];
+  set_binary_op_output_data(a, b, out_a, bopt);
+  set_binary_op_output_data(a, b, out_b, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out_a);
+  encoder.set_output_array(out_b);
+
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out_a = array::unsafe_weak_copy(out_a),
+                    out_b = array::unsafe_weak_copy(out_b),
+                    bopt]() mutable {
+    auto integral_op = [](auto x, auto y) {
+      return std::make_pair(x / y, x % y);
+    };
+    auto float_op = [](auto x, auto y) {
+      return std::make_pair(std::trunc(x / y), std::fmod(x, y));
+    };
+
+    switch (out_a.dtype()) {
+      case bool_:
+        binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case int8:
+        binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case int16:
+        binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case int32:
+        binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case int64:
+        binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
+        break;
+      case float16:
+        binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
+        break;
+      case float32:
+        binary_op<float>(a, b, out_a, out_b, float_op, bopt);
+        break;
+      case float64:
+        binary_op<double>(a, b, out_a, out_b, float_op, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
+        break;
+      case complex64:
+        // Should never get here
+        throw std::runtime_error("[DivMod] Complex type not supported");
+        break;
+    }
+  });
 }
 
 void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Divide());
+  binary(a, b, out, detail::Divide(), stream());
 }
 
 void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Remainder());
+  binary(a, b, out, detail::Remainder(), stream());
 }
 
 void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& a = inputs[0];
   auto& b = inputs[1];
   if (equal_nan_) {
-    switch (a.dtype()) {
-      case float16:
-        binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
-        break;
-      case float32:
-        binary_op<float, bool, detail::NaNEqual>(a, b, out);
-        break;
-      case float64:
-        binary_op<double, bool, detail::NaNEqual>(a, b, out);
-        break;
-      case bfloat16:
-        binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
-        break;
-      case complex64:
-        binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
-        break;
-      default:
-        throw std::runtime_error(
-            "[NanEqual::eval_cpu] Only for floating point types.");
-    }
+    auto bopt = get_binary_op_type(a, b);
+    set_binary_op_output_data(a, b, out, bopt);
+
+    auto& encoder = cpu::get_command_encoder(stream());
+    encoder.set_input_array(a);
+    encoder.set_input_array(b);
+    encoder.set_output_array(out);
+    encoder.dispatch([a = array::unsafe_weak_copy(a),
+                      b = array::unsafe_weak_copy(b),
+                      out = array::unsafe_weak_copy(out),
+                      bopt]() mutable {
+      switch (a.dtype()) {
+        case float16:
+          binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
+          break;
+        case float32:
+          binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
+          break;
+        case float64:
+          binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
+          break;
+        case bfloat16:
+          binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
+          break;
+        case complex64:
+          binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
+          break;
+        default:
+          throw std::runtime_error(
+              "[NanEqual::eval_cpu] Only for floating point types.");
+      }
+    });
   } else {
-    comparison_op<detail::Equal>(a, b, out);
+    comparison_op(a, b, out, detail::Equal(), stream());
   }
 }
 
 void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op<detail::Greater>(inputs[0], inputs[1], out);
+  comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
 }
 
 void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
+  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
 }
 
 void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op<detail::Less>(inputs[0], inputs[1], out);
+  comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
 }
 
 void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
+  comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
 }
 
 void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  switch (out.dtype()) {
-    case float16:
-      binary_op<float16_t>(a, b, out, detail::LogAddExp());
-      break;
-    case float32:
-      binary_op<float>(a, b, out, detail::LogAddExp());
-      break;
-    case float64:
-      binary_op<double>(a, b, out, detail::LogAddExp());
-      break;
-    case bfloat16:
-      binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
-      break;
-    default:
-      throw std::runtime_error(
-          "[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
-  }
+  binary_float(a, b, out, detail::LogAddExp(), stream());
 }
 
 void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2); // LogicalAnd requires two input arrays
   auto& in1 = inputs[0];
   auto& in2 = inputs[1];
-  binary(in1, in2, out, detail::LogicalAnd());
+  binary(in1, in2, out, detail::LogicalAnd(), stream());
 }
 
 void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2); // LogicalOr requires two input arrays
   auto& in1 = inputs[0];
   auto& in2 = inputs[1];
-  binary(in1, in2, out, detail::LogicalOr());
+  binary(in1, in2, out, detail::LogicalOr(), stream());
 }
 
 void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Maximum());
+  binary(a, b, out, detail::Maximum(), stream());
 }
 
 void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Minimum());
+  binary(a, b, out, detail::Minimum(), stream());
 }
 
 void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Multiply());
+  binary(a, b, out, detail::Multiply(), stream());
 }
 
 void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
+  comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
 }
 
 void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Power());
+  binary(a, b, out, detail::Power(), stream());
 }
 
 void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Subtract());
+  binary(a, b, out, detail::Subtract(), stream());
 }
 
 void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  auto dispatch_type = [&a, &b, &out](auto op) {
-    switch (out.dtype()) {
-      case bool_:
-        binary_op<bool>(a, b, out, op);
-      case uint8:
-        binary_op<uint8_t>(a, b, out, op);
-        break;
-      case uint16:
-        binary_op<uint16_t>(a, b, out, op);
-        break;
-      case uint32:
-        binary_op<uint32_t>(a, b, out, op);
-        break;
-      case uint64:
-        binary_op<uint64_t>(a, b, out, op);
-        break;
-      case int8:
-        binary_op<int8_t>(a, b, out, op);
-        break;
-      case int16:
-        binary_op<int16_t>(a, b, out, op);
-        break;
-      case int32:
-        binary_op<int32_t>(a, b, out, op);
-        break;
-      case int64:
-        binary_op<int64_t>(a, b, out, op);
-        break;
-      default:
-        throw std::runtime_error(
-            "[BitwiseBinary::eval_cpu] Type not supported");
-        break;
-    }
-  };
   switch (op_) {
     case BitwiseBinary::And:
-      dispatch_type(detail::BitwiseAnd());
+      binary_int(a, b, out, detail::BitwiseAnd(), stream());
       break;
     case BitwiseBinary::Or:
-      dispatch_type(detail::BitwiseOr());
+      binary_int(a, b, out, detail::BitwiseOr(), stream());
      break;
    case BitwiseBinary::Xor:
-      dispatch_type(detail::BitwiseXor());
+      binary_int(a, b, out, detail::BitwiseXor(), stream());
      break;
    case BitwiseBinary::LeftShift:
-      dispatch_type(detail::LeftShift());
+      binary_int(a, b, out, detail::LeftShift(), stream());
      break;
    case BitwiseBinary::RightShift:
-      dispatch_type(detail::RightShift());
+      binary_int(a, b, out, detail::RightShift(), stream());
      break;
  }
 }
 
@@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   const auto& a = inputs[0];
   const auto& b = inputs[1];
-  switch (out.dtype()) {
-    case float16:
-      binary_op<float16_t>(a, b, out, detail::ArcTan2());
-      break;
-    case float32:
-      binary_op<float>(a, b, out, detail::ArcTan2());
-      break;
-    case float64:
-      binary_op<double>(a, b, out, detail::ArcTan2());
-      break;
-    case bfloat16:
-      binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
-      break;
-    default:
-      throw std::runtime_error(
-          "[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
-  }
+  binary_float(a, b, out, detail::ArcTan2(), stream());
 }
 
 } // namespace mlx::core
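The `detail::` ops passed to `binary`, `binary_int`, and `binary_float` above are stateless functors; a sketch of their assumed shape (the real definitions live in mlx/backend/cpu/binary_ops.h, and `detail_sketch` here is a hypothetical stand-in):

    #include <cassert>

    namespace detail_sketch {
    struct Add {
      template <typename T>
      T operator()(T x, T y) {
        return x + y;
      }
    };
    } // namespace detail_sketch

    int main() {
      // Stateless functors like this can be passed by value (the `Op op`
      // parameter above) or re-instantiated as Op{} inside the typed kernels.
      assert(detail_sketch::Add{}(2, 3) == 5);
    }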
diff --git a/mlx/backend/cpu/binary.h b/mlx/backend/cpu/binary.h
index 623f1910af..31966c0ea3 100644
--- a/mlx/backend/cpu/binary.h
+++ b/mlx/backend/cpu/binary.h
@@ -3,12 +3,9 @@
 #pragma once
 
 #include <cassert>
 
-#include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/cpu/encoder.h"
-#include "mlx/primitives.h"
 
 #include "mlx/backend/cpu/simd/simd.h"
 
@@ -152,218 +149,145 @@ void binary_op_dispatch_dims(
 }
 
 template <typename T, typename U, typename Op>
-void binary_op(const array& a, const array& b, array& out) {
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
-
+void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
   // The full computation is scalar scalar so call the base op once
   auto a_ptr = a.data<T>();
   auto b_ptr = b.data<T>();
   auto out_ptr = out.data<U>();
-  auto& encoder = cpu::get_command_encoder(out.primitive().stream());
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  encoder.dispatch([bopt,
-                    a_ptr,
-                    b_ptr,
-                    out_ptr,
-                    a_data_size = a.data_size(),
-                    b_data_size = b.data_size(),
-                    size = a.size(),
-                    shape = a.shape(),
-                    a_strides = a.strides(),
-                    b_strides = b.strides(),
-                    strides = out.strides()]() mutable {
-    if (bopt == BinaryOpType::ScalarScalar) {
-      *out_ptr = Op{}(*a_ptr, *b_ptr);
-      return;
-    }
-
-    // The full computation is scalar vector so delegate to the op
-    if (bopt == BinaryOpType::ScalarVector) {
-      ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
-      return;
-    }
-
-    // The full computation is vector scalar so delegate to the op
-    if (bopt == BinaryOpType::VectorScalar) {
-      VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
-      return;
-    }
-
-    // The full computation is vector vector so delegate to the op
-    if (bopt == BinaryOpType::VectorVector) {
-      VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
-      return;
-    }
-
-    // General computation so let's try to optimize
-    auto [new_shape, new_strides] = collapse_contiguous_dims(
-        shape,
-        {std::move(a_strides), std::move(b_strides), std::move(strides)});
-    a_strides = new_strides[0];
-    b_strides = new_strides[1];
-    strides = new_strides[2];
-
-    // Get the left-most dim such that the array is row contiguous after
-    auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
-      int d = arr_strides.size() - 1;
-      for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
-      }
-      return d + 1;
-    };
-    auto a_rc_dim = leftmost_rc_dim(a_strides);
-    auto b_rc_dim = leftmost_rc_dim(b_strides);
-
-    // Get the left-most dim such that the array is a broadcasted "scalar" after
-    auto leftmost_s_dim = [](const auto& arr_strides) {
-      int d = arr_strides.size() - 1;
-      for (; d >= 0 && arr_strides[d] == 0; d--) {
-      }
-      return d + 1;
-    };
-    auto a_s_dim = leftmost_s_dim(a_strides);
-    auto b_s_dim = leftmost_s_dim(b_strides);
-
-    auto ndim = new_shape.size();
-
-    // Case 1: LxM and FxM where L and F are broadcastable and M is row
-    // contiguous
-    int dim = ndim;
-    if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-      bopt = BinaryOpType::VectorVector;
-      dim = d;
-      // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
-      // contiguous
-    } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-      bopt = BinaryOpType::VectorScalar;
-      dim = d;
-      // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
-      // contiguous
-    } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-      bopt = BinaryOpType::ScalarVector;
-      dim = d;
-    }
-
-    // Can be sure dim > 0 since otherwise we would have used one of the fully
-    // contiguous methods above. Except for the case that the flags do not
-    // correspond to the underlying contiguity.
-    if (dim == 0 || strides[dim - 1] < 16) {
-      bopt = BinaryOpType::General;
-      dim = ndim;
-    }
-
-    switch (bopt) {
-      case BinaryOpType::VectorVector:
-        binary_op_dispatch_dims<T, U, VectorVector<Op>>(
-            a_ptr,
-            b_ptr,
-            out_ptr,
-            dim,
-            size,
-            new_shape,
-            a_strides,
-            b_strides,
-            strides);
-        break;
-      case BinaryOpType::VectorScalar:
-        binary_op_dispatch_dims<T, U, VectorScalar<Op>>(
-            a_ptr,
-            b_ptr,
-            out_ptr,
-            dim,
-            size,
-            new_shape,
-            a_strides,
-            b_strides,
-            strides);
-        break;
-      case BinaryOpType::ScalarVector:
-        binary_op_dispatch_dims<T, U, ScalarVector<Op>>(
-            a_ptr,
-            b_ptr,
-            out_ptr,
-            dim,
-            size,
-            new_shape,
-            a_strides,
-            b_strides,
-            strides);
-        break;
-      default:
-        binary_op_dispatch_dims<T, U, Op>(
-            a_ptr,
-            b_ptr,
-            out_ptr,
-            dim,
-            size,
-            new_shape,
-            a_strides,
-            b_strides,
-            strides);
-        break;
-    }
-  });
-}
-
-template <typename T, typename Op>
-void binary_op(const array& a, const array& b, array& out) {
-  binary_op<T, T, Op>(a, b, out);
-}
-
-template <typename T, typename Op>
-void binary_op(const array& a, const array& b, array& out, Op op) {
-  binary_op<T, T, Op>(a, b, out);
-}
-
-template <typename Op>
-void binary(const array& a, const array& b, array& out, Op op) {
-  switch (out.dtype()) {
-    case bool_:
-      binary_op<bool, Op>(a, b, out);
-      break;
-    case uint8:
-      binary_op<uint8_t, Op>(a, b, out);
-      break;
-    case uint16:
-      binary_op<uint16_t, Op>(a, b, out);
-      break;
-    case uint32:
-      binary_op<uint32_t, Op>(a, b, out);
-      break;
-    case uint64:
-      binary_op<uint64_t, Op>(a, b, out);
-      break;
-    case int8:
-      binary_op<int8_t, Op>(a, b, out);
-      break;
-    case int16:
-      binary_op<int16_t, Op>(a, b, out);
-      break;
-    case int32:
-      binary_op<int32_t, Op>(a, b, out);
-      break;
-    case int64:
-      binary_op<int64_t, Op>(a, b, out);
-      break;
-    case float16:
-      binary_op<float16_t, Op>(a, b, out);
-      break;
-    case float32:
-      binary_op<float, Op>(a, b, out);
-      break;
-    case float64:
-      binary_op<double, Op>(a, b, out);
-      break;
-    case bfloat16:
-      binary_op<bfloat16_t, Op>(a, b, out);
-      break;
-    case complex64:
-      binary_op<complex64_t, Op>(a, b, out);
-      break;
-  }
-}
+  if (bopt == BinaryOpType::ScalarScalar) {
+    *out_ptr = Op{}(*a_ptr, *b_ptr);
+    return;
+  }
+
+  // The full computation is scalar vector so delegate to the op
+  if (bopt == BinaryOpType::ScalarVector) {
+    ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
+    return;
+  }
+
+  // The full computation is vector scalar so delegate to the op
+  if (bopt == BinaryOpType::VectorScalar) {
+    VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
+    return;
+  }
+
+  // The full computation is vector vector so delegate to the op
+  if (bopt == BinaryOpType::VectorVector) {
+    VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
+    return;
+  }
+
+  // General computation so let's try to optimize
+  auto [new_shape, new_strides] = collapse_contiguous_dims(
+      a.shape(), {a.strides(), b.strides(), out.strides()});
+  auto& a_strides = new_strides[0];
+  auto& b_strides = new_strides[1];
+  auto& strides = new_strides[2];
+
+  // Get the left-most dim such that the array is row contiguous after
+  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
+    int d = arr_strides.size() - 1;
+    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
+    }
+    return d + 1;
+  };
+  auto a_rc_dim = leftmost_rc_dim(a_strides);
+  auto b_rc_dim = leftmost_rc_dim(b_strides);
+
+  // Get the left-most dim such that the array is a broadcasted "scalar" after
+  auto leftmost_s_dim = [](const auto& arr_strides) {
+    int d = arr_strides.size() - 1;
+    for (; d >= 0 && arr_strides[d] == 0; d--) {
+    }
+    return d + 1;
+  };
+  auto a_s_dim = leftmost_s_dim(a_strides);
+  auto b_s_dim = leftmost_s_dim(b_strides);
+
+  auto ndim = new_shape.size();
+
+  // Case 1: LxM and FxM where L and F are broadcastable and M is row
+  // contiguous
+  int dim = ndim;
+  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
+    bopt = BinaryOpType::VectorVector;
+    dim = d;
+    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
+    // contiguous
+  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
+    bopt = BinaryOpType::VectorScalar;
+    dim = d;
+    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
+    // contiguous
+  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
+    bopt = BinaryOpType::ScalarVector;
+    dim = d;
+  }
+
+  // Can be sure dim > 0 since otherwise we would have used one of the fully
+  // contiguous methods above. Except for the case that the flags do not
+  // correspond to the underlying contiguity.
+  if (dim == 0 || strides[dim - 1] < 16) {
+    bopt = BinaryOpType::General;
+    dim = ndim;
+  }
+
+  switch (bopt) {
+    case BinaryOpType::VectorVector:
+      binary_op_dispatch_dims<T, U, VectorVector<Op>>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          dim,
+          a.size(),
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
+      break;
+    case BinaryOpType::VectorScalar:
+      binary_op_dispatch_dims<T, U, VectorScalar<Op>>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          dim,
+          a.size(),
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
+      break;
+    case BinaryOpType::ScalarVector:
+      binary_op_dispatch_dims<T, U, ScalarVector<Op>>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          dim,
+          a.size(),
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
+      break;
+    default:
+      binary_op_dispatch_dims<T, U, Op>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          dim,
+          a.size(),
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
+      break;
+  }
+}
+
+template <typename T, typename Op>
+void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
+  binary_op<T, T, Op>(a, b, out, bopt);
+}
+
 } // namespace mlx::core
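The `leftmost_rc_dim` scan above is the crux of the case selection; here it is in isolation with a worked input (standalone sketch; `main` and the sample strides are hypothetical, and `Strides` in MLX is a vector of 64-bit integers):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Walk from the innermost dimension outward while the argument's strides
    // match the output's; every dim at or to the right of the returned index
    // can be treated as one row-contiguous vector segment.
    int leftmost_rc_dim(
        const std::vector<int64_t>& arr,
        const std::vector<int64_t>& out) {
      int d = static_cast<int>(arr.size()) - 1;
      for (; d >= 0 && arr[d] == out[d]; d--) {
      }
      return d + 1;
    }

    int main() {
      // A broadcast dim (stride 0) stops the scan: only the last two dims of
      // arr line up with the output, so a vectorized op can cover dims 1..2.
      assert(leftmost_rc_dim({0, 4, 1}, {8, 4, 1}) == 1);
      // Fully matching strides return 0: the whole array is one vector.
      assert(leftmost_rc_dim({8, 4, 1}, {8, 4, 1}) == 0);
    }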
diff --git a/mlx/backend/cpu/binary_two.h b/mlx/backend/cpu/binary_two.h
index a89f5aa115..fa0ca7996e 100644
--- a/mlx/backend/cpu/binary_two.h
+++ b/mlx/backend/cpu/binary_two.h
@@ -4,8 +4,6 @@
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/binary.h"
-#include "mlx/backend/cpu/encoder.h"
-#include "mlx/primitives.h"
 
 namespace mlx::core {
 
@@ -57,14 +55,7 @@ void binary_op_dispatch_dims(
     const array& b,
     array& out_a,
     array& out_b,
-    Stream stream,
     Op op) {
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out_a);
-  encoder.set_output_array(out_b);
-
   auto [shape, strides] = collapse_contiguous_dims(
       a.shape(), {a.strides(), b.strides(), out_a.strides()});
   const T* a_ptr = a.data<T>();
@@ -72,197 +63,101 @@ void binary_op_dispatch_dims(
   const T* b_ptr = b.data<T>();
   U* out_a_ptr = out_a.data<U>();
   U* out_b_ptr = out_b.data<U>();
 
-  encoder.dispatch([a_ptr,
-                    b_ptr,
-                    out_a_ptr,
-                    out_b_ptr,
-                    size = a.size(),
-                    shape = std::move(shape),
-                    strides = std::move(strides),
-                    op = std::move(op)]() {
-    const auto& a_strides = strides[0];
-    const auto& b_strides = strides[1];
-    const auto& out_strides = strides[2];
-    int ndim = shape.size();
-    switch (ndim) {
-      case 1:
-        binary_op_dims<T, U, Op, 1>(
-            a_ptr,
-            b_ptr,
-            out_a_ptr,
-            out_b_ptr,
-            op,
-            shape,
-            a_strides,
-            b_strides,
-            out_strides,
-            0);
-        return;
-      case 2:
-        binary_op_dims<T, U, Op, 2>(
-            a_ptr,
-            b_ptr,
-            out_a_ptr,
-            out_b_ptr,
-            op,
-            shape,
-            a_strides,
-            b_strides,
-            out_strides,
-            0);
-        return;
-    }
-
-    ContiguousIterator a_it(shape, a_strides, ndim - 2);
-    ContiguousIterator b_it(shape, b_strides, ndim - 2);
-    auto stride = out_strides[ndim - 3];
-    for (size_t elem = 0; elem < size; elem += stride) {
-      binary_op_dims<T, U, Op, 2>(
-          a_ptr + a_it.loc,
-          b_ptr + b_it.loc,
-          out_a_ptr + elem,
-          out_b_ptr + elem,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          out_strides,
-          ndim - 2);
-      a_it.step();
-      b_it.step();
-    }
-  });
+  const auto& a_strides = strides[0];
+  const auto& b_strides = strides[1];
+  const auto& out_strides = strides[2];
+  int ndim = shape.size();
+  switch (ndim) {
+    case 1:
+      binary_op_dims<T, U, Op, 1>(
+          a_ptr,
+          b_ptr,
+          out_a_ptr,
+          out_b_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+    case 2:
+      binary_op_dims<T, U, Op, 2>(
+          a_ptr,
+          b_ptr,
+          out_a_ptr,
+          out_b_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+  }
+
+  ContiguousIterator a_it(shape, a_strides, ndim - 2);
+  ContiguousIterator b_it(shape, b_strides, ndim - 2);
+  auto stride = out_strides[ndim - 3];
+  for (size_t elem = 0; elem < a.size(); elem += stride) {
+    binary_op_dims<T, U, Op, 2>(
+        a_ptr + a_it.loc,
+        b_ptr + b_it.loc,
+        out_a_ptr + elem,
+        out_b_ptr + elem,
+        op,
+        shape,
+        a_strides,
+        b_strides,
+        out_strides,
+        ndim - 2);
+    a_it.step();
+    b_it.step();
+  }
 }
 
 template <typename T, typename U = T, typename Op>
 void binary_op(
     const array& a,
     const array& b,
-    std::vector<array>& outputs,
-    Op op) {
-  auto bopt = get_binary_op_type(a, b);
-  auto& out_a = outputs[0];
-  auto& out_b = outputs[1];
-  set_binary_op_output_data(a, b, out_a, bopt);
-  set_binary_op_output_data(a, b, out_b, bopt);
-
-  auto stream = out_a.primitive().stream();
+    array& out_a,
+    array& out_b,
+    Op op,
+    BinaryOpType bopt) {
   // The full computation is scalar scalar so call the base op once
   if (bopt == BinaryOpType::General) {
-    binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, stream, op);
+    binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
     return;
   }
 
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out_a);
-  encoder.set_output_array(out_b);
-
   auto a_ptr = a.data<T>();
   auto b_ptr = b.data<T>();
   auto out_a_ptr = out_a.data<U>();
   auto out_b_ptr = out_b.data<U>();
   if (bopt == BinaryOpType::ScalarScalar) {
-    encoder.dispatch(
-        [a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
-          std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-        });
+    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
   } else if (bopt == BinaryOpType::ScalarVector) {
-    encoder.dispatch([a_ptr,
-                      b_ptr,
-                      out_a_ptr,
-                      out_b_ptr,
-                      size = b.size(),
-                      op = std::move(op)]() mutable {
-      for (size_t i = 0; i < size; ++i) {
-        std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-        out_a_ptr++;
-        out_b_ptr++;
-        b_ptr++;
-      }
-    });
+    for (size_t i = 0; i < b.data_size(); ++i) {
+      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+      out_a_ptr++;
+      out_b_ptr++;
+      b_ptr++;
+    }
   } else if (bopt == BinaryOpType::VectorScalar) {
-    encoder.dispatch([a_ptr,
-                      b_ptr,
-                      out_a_ptr,
-                      out_b_ptr,
-                      size = a.size(),
-                      op = std::move(op)]() mutable {
-      for (size_t i = 0; i < size; ++i) {
-        std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-        out_a_ptr++;
-        out_b_ptr++;
-        a_ptr++;
-      }
-    });
+    for (size_t i = 0; i < a.data_size(); ++i) {
+      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+      out_a_ptr++;
+      out_b_ptr++;
+      a_ptr++;
+    }
   } else { // VectorVector
-    encoder.dispatch([a_ptr,
-                      b_ptr,
-                      out_a_ptr,
-                      out_b_ptr,
-                      size = a.size(),
-                      op = std::move(op)]() mutable {
-      for (size_t i = 0; i < size; ++i) {
-        std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
-        out_a_ptr++;
-        out_b_ptr++;
-        a_ptr++;
-        b_ptr++;
-      }
-    });
+    for (size_t i = 0; i < a.size(); ++i) {
+      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+      out_a_ptr++;
+      out_b_ptr++;
+      a_ptr++;
+      b_ptr++;
+    }
   }
 }
 
-template <typename T, typename Op>
-void binary(
-    const array& a,
-    const array& b,
-    std::vector<array>& outputs,
-    Op op) {
-  switch (outputs[0].dtype()) {
-    case bool_:
-      binary_op<bool>(a, b, outputs, op);
-      break;
-    case uint8:
-      binary_op<uint8_t>(a, b, outputs, op);
-      break;
-    case uint16:
-      binary_op<uint16_t>(a, b, outputs, op);
-      break;
-    case uint32:
-      binary_op<uint32_t>(a, b, outputs, op);
-      break;
-    case uint64:
-      binary_op<uint64_t>(a, b, outputs, op);
-      break;
-    case int8:
-      binary_op<int8_t>(a, b, outputs, op);
-      break;
-    case int16:
-      binary_op<int16_t>(a, b, outputs, op);
-      break;
-    case int32:
-      binary_op<int32_t>(a, b, outputs, op);
-      break;
-    case int64:
-      binary_op<int64_t>(a, b, outputs, op);
-      break;
-    case float16:
-      binary_op<float16_t>(a, b, outputs, op);
-      break;
-    case float32:
-      binary_op<float>(a, b, outputs, op);
-      break;
-    case float64:
-      binary_op<double>(a, b, outputs, op);
-      break;
-    case bfloat16:
-      binary_op<bfloat16_t>(a, b, outputs, op);
-      break;
-    case complex64:
-      binary_op<complex64_t>(a, b, outputs, op);
-      break;
-  }
-}
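A numeric sanity check for the `(std::trunc(x / y), std::fmod(x, y))` pair that DivMod feeds through this two-output machinery (standalone sketch; the `main` and test values are hypothetical):

    #include <cassert>
    #include <cmath>

    int main() {
      // DivMod pairs a truncated quotient with an fmod remainder, so the
      // usual identity x == q * y + r holds on the float path.
      float x = 7.5f, y = 2.0f;
      auto q = std::trunc(x / y); // 3.0f
      auto r = std::fmod(x, y);   // 1.5f
      assert(q == 3.0f && r == 1.5f);
      assert(x == q * y + r);
      // The remainder's sign follows the dividend, matching C++ integer
      // division semantics used by the integral path (x / y, x % y).
      assert(std::fmod(-7.5f, 2.0f) == -1.5f);
    }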
diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp
index bb2b3a2274..f9b8595dd5 100644
--- a/mlx/backend/cpu/copy.cpp
+++ b/mlx/backend/cpu/copy.cpp
@@ -13,29 +13,20 @@ namespace mlx::core {
 namespace {
 
 template <typename SrcT, typename DstT>
-void copy_single(const array& src, array& dst, Stream stream) {
+void copy_single(const array& src, array& dst) {
   auto src_ptr = src.data<SrcT>();
   auto dst_ptr = dst.data<DstT>();
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  encoder.set_output_array(dst);
-  encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
-    auto val = static_cast<DstT>(src_ptr[0]);
-    std::fill_n(dst_ptr, size, val);
-  });
+  auto size = dst.size();
+  auto val = static_cast<DstT>(src_ptr[0]);
+  std::fill_n(dst_ptr, size, val);
 }
 
 template <typename SrcT, typename DstT>
-void copy_vector(const array& src, array& dst, Stream stream) {
+void copy_vector(const array& src, array& dst) {
   auto src_ptr = src.data<SrcT>();
   auto dst_ptr = dst.data<DstT>();
-  size_t size = src.data_size();
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  encoder.set_output_array(dst);
-  encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
-    std::copy(src_ptr, src_ptr + size, dst_ptr);
-  });
+  auto size = src.data_size();
+  std::copy(src_ptr, src_ptr + size, dst_ptr);
 }
 
 template <typename SrcT, typename DstT, int D>
@@ -66,7 +57,6 @@ template <typename SrcT, typename DstT>
 void copy_general_general(
     const array& src,
     array& dst,
-    Stream stream,
     const Shape& data_shape,
     const Strides& i_strides,
     const Strides& o_strides,
@@ -80,47 +70,17 @@ void copy_general_general(
   auto i_offset_ptr =
      dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
   auto o_offset_ptr =
      dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
+  auto size = src.size();
+  if (data_shape.empty()) {
+    auto val = static_cast<DstT>(*src_ptr);
+    *dst_ptr = val;
+    return;
+  }
+  auto [shape, strides] =
+      collapse_contiguous_dims(data_shape, {i_strides, o_strides});
 
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  encoder.set_output_array(dst);
-  encoder.dispatch([src_ptr,
-                    dst_ptr,
-                    size = src.size(),
-                    data_shape = data_shape,
-                    i_strides = i_strides,
-                    o_strides = o_strides,
-                    i_offset_ptr,
-                    o_offset_ptr]() mutable {
-    if (data_shape.empty()) {
-      auto val = static_cast<DstT>(*src_ptr);
-      *dst_ptr = val;
-      return;
-    }
-    auto [shape, strides] =
-        collapse_contiguous_dims(data_shape, {i_strides, o_strides});
-
-    int ndim = shape.size();
-    if (ndim < 3) {
-      if (i_offset_ptr) {
-        src_ptr += i_offset_ptr[0];
-      }
-      if (o_offset_ptr) {
-        dst_ptr += o_offset_ptr[0];
-      }
-
-      if (ndim == 1) {
-        copy_dims<SrcT, DstT, 1>(
-            src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
-      } else if (ndim == 2) {
-        copy_dims<SrcT, DstT, 2>(
-            src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
-      } else if (ndim == 3) {
-        copy_dims<SrcT, DstT, 3>(
-            src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
-      }
-      return;
-    }
+  int ndim = shape.size();
+  if (ndim < 3) {
     if (i_offset_ptr) {
       src_ptr += i_offset_ptr[0];
     }
@@ -128,30 +88,47 @@ void copy_general_general(
     if (o_offset_ptr) {
       dst_ptr += o_offset_ptr[0];
     }
 
-    ContiguousIterator in(shape, strides[0], ndim - 3);
-    ContiguousIterator out(shape, strides[1], ndim - 3);
-    auto stride = std::accumulate(
-        shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
-    for (int64_t elem = 0; elem < size; elem += stride) {
+    if (ndim == 1) {
+      copy_dims<SrcT, DstT, 1>(
+          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
+    } else if (ndim == 2) {
+      copy_dims<SrcT, DstT, 2>(
+          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
+    } else if (ndim == 3) {
       copy_dims<SrcT, DstT, 3>(
-          src_ptr + in.loc,
-          dst_ptr + out.loc,
-          shape,
-          strides[0],
-          strides[1],
-          ndim - 3);
-      in.step();
-      out.step();
+          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
     }
-  });
+    return;
+  }
+  if (i_offset_ptr) {
+    src_ptr += i_offset_ptr[0];
+  }
+  if (o_offset_ptr) {
+    dst_ptr += o_offset_ptr[0];
+  }
+
+  ContiguousIterator in(shape, strides[0], ndim - 3);
+  ContiguousIterator out(shape, strides[1], ndim - 3);
+  auto stride = std::accumulate(
+      shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
+  for (int64_t elem = 0; elem < size; elem += stride) {
+    copy_dims<SrcT, DstT, 3>(
+        src_ptr + in.loc,
+        dst_ptr + out.loc,
+        shape,
+        strides[0],
+        strides[1],
+        ndim - 3);
+    in.step();
+    out.step();
+  }
 }
 
 template <typename SrcT, typename DstT>
-inline void copy_general_general(const array& src, array& dst, Stream stream) {
+inline void copy_general_general(const array& src, array& dst) {
   copy_general_general<SrcT, DstT>(
       src,
       dst,
-      stream,
       src.shape(),
       src.strides(),
       dst.strides(),
@@ -165,7 +142,6 @@ template <typename SrcT, typename DstT>
 void copy_general(
     const array& src,
     array& dst,
-    Stream stream,
     const Shape& data_shape,
     const Strides& i_strides,
     const Strides&,
@@ -176,7 +152,6 @@ void copy_general(
   copy_general_general<SrcT, DstT>(
       src,
       dst,
-      stream,
       data_shape,
       i_strides,
       make_contiguous_strides(data_shape),
@@ -187,11 +162,10 @@ void copy_general(
 }
 
 template <typename SrcT, typename DstT>
-inline void copy_general(const array& src, array& dst, Stream stream) {
+inline void copy_general(const array& src, array& dst) {
   copy_general_general<SrcT, DstT>(
       src,
       dst,
-      stream,
       src.shape(),
       src.strides(),
       make_contiguous_strides(src.shape()),
@@ -202,84 +176,67 @@ inline void copy_general(const array& src, array& dst, Stream stream) {
 }
 
 template <typename SrcT, typename DstT, typename... Args>
-void copy(
-    const array& src,
-    array& dst,
-    CopyType ctype,
-    Stream stream,
-    Args&&... args) {
+void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
   switch (ctype) {
     case CopyType::Scalar:
-      copy_single<SrcT, DstT>(src, dst, stream);
+      copy_single<SrcT, DstT>(src, dst);
      return;
    case CopyType::Vector:
-      copy_vector<SrcT, DstT>(src, dst, stream);
+      copy_vector<SrcT, DstT>(src, dst);
      return;
    case CopyType::General:
-      copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
+      copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
      return;
    case CopyType::GeneralGeneral:
-      copy_general_general<SrcT, DstT>(
-          src, dst, stream, std::forward<Args>(args)...);
+      copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
      return;
  }
 }
 
 template <typename SrcT, typename... Args>
-void copy(
-    const array& src,
-    array& dst,
-    CopyType ctype,
-    Stream stream,
-    Args&&... args) {
+void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
   switch (dst.dtype()) {
     case bool_:
-      copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint8:
-      copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint16:
-      copy<SrcT, uint16_t>(
-          src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint32:
-      copy<SrcT, uint32_t>(
-          src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint64:
-      copy<SrcT, uint64_t>(
-          src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int8:
-      copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int16:
-      copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int32:
-      copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int64:
-      copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float16:
-      copy<SrcT, float16_t>(
-          src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float32:
-      copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float64:
-      copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
-      copy<SrcT, bfloat16_t>(
-          src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case complex64:
-      copy<SrcT, complex64_t>(
-          src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
  }
 }
@@ -289,50 +246,49 @@ template <typename... Args>
 inline void copy_inplace_dispatch(
     const array& src,
     array& dst,
     CopyType ctype,
-    Stream stream,
     Args&&... args) {
   switch (src.dtype()) {
     case bool_:
-      copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint8:
-      copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint16:
-      copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint32:
-      copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint64:
-      copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int8:
-      copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int16:
-      copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int32:
-      copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int64:
-      copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float16:
-      copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float32:
-      copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float64:
-      copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<double>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
-      copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case complex64:
-      copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
+      copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
  }
 }
 
@@ -340,7 +296,13 @@ inline void copy_inplace_dispatch(
 } // namespace
 
 void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
-  copy_inplace_dispatch(src, dst, ctype, stream);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(src);
+  encoder.set_output_array(dst);
+  encoder.dispatch(
+      [src = array::unsafe_weak_copy(src),
+       dst = array::unsafe_weak_copy(dst),
+       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
 }
 
 void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
@@ -368,26 +330,47 @@ void copy_inplace(
     Stream stream,
     const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
-  switch (ctype) {
-    case CopyType::General:
-    case CopyType::GeneralGeneral:
-      copy_inplace_dispatch(
-          src,
-          dst,
-          ctype,
-          stream,
-          data_shape,
-          i_strides,
-          o_strides,
-          i_offset,
-          o_offset,
-          dynamic_i_offset,
-          dynamic_o_offset);
-      break;
-    case CopyType::Scalar:
-    case CopyType::Vector:
-      copy_inplace_dispatch(src, dst, ctype, stream);
-  }
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(src);
+  encoder.set_output_array(dst);
+  auto weak_copy_if_set = [](auto x) -> std::optional<array> {
+    if (x) {
+      return array::unsafe_weak_copy(*x);
+    } else {
+      return std::nullopt;
+    }
+  };
+  encoder.dispatch(
+      [src = array::unsafe_weak_copy(src),
+       dst = array::unsafe_weak_copy(dst),
+       data_shape,
+       i_strides,
+       o_strides,
+       i_offset,
+       o_offset,
+       ctype,
+       dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
+       dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
+        switch (ctype) {
+          case CopyType::General:
+          case CopyType::GeneralGeneral:
+            copy_inplace_dispatch(
+                src,
+                dst,
+                ctype,
+                data_shape,
+                i_strides,
+                o_strides,
+                i_offset,
+                o_offset,
+                dynamic_i_offset,
+                dynamic_o_offset);
+            break;
+          case CopyType::Scalar:
+          case CopyType::Vector:
+            copy_inplace_dispatch(src, dst, ctype);
+        }
+      });
 }
 
 } // namespace mlx::core
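The `weak_copy_if_set` lambda above is just an optional-map; the same shape in isolation (standalone sketch; `map_optional` and the `main` are hypothetical names, not part of the patch):

    #include <cassert>
    #include <optional>

    // Map an optional<T> through f, propagating nullopt; mirrors how
    // copy_inplace turns optional dynamic offsets into optional weak copies.
    template <typename T, typename F>
    auto map_optional(const std::optional<T>& x, F&& f)
        -> std::optional<decltype(f(*x))> {
      if (x) {
        return f(*x);
      }
      return std::nullopt;
    }

    int main() {
      std::optional<int> some = 21, none;
      auto twice = [](int v) { return v * 2; };
      assert(map_optional(some, twice) == std::optional<int>(42));
      assert(!map_optional(none, twice).has_value());
    }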
diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp
index 0f4c082cc2..6a32dc1d40 100644
--- a/mlx/backend/cpu/indexing.cpp
+++ b/mlx/backend/cpu/indexing.cpp
@@ -22,14 +22,47 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
   return idx;
 }
 
+struct None {
+  template <typename T>
+  void operator()(T x, T* y) {
+    (*y) = x;
+  }
+};
+
+struct Sum {
+  template <typename T>
+  void operator()(T x, T* y) {
+    (*y) += x;
+  }
+};
+
+struct Prod {
+  template <typename T>
+  void operator()(T x, T* y) {
+    (*y) *= x;
+  }
+};
+
+struct Max {
+  template <typename T>
+  void operator()(T x, T* y) {
+    (*y) = (*y > x) ? *y : x;
+  }
+};
+
+struct Min {
+  template <typename T>
+  void operator()(T x, T* y) {
+    (*y) = (*y < x) ? *y : x;
+  }
+};
+
 template <typename T, typename IdxT>
 void gather(
     const array& src,
     const std::vector<array>& inds,
     array& out,
     const std::vector<int>& axes,
-    const Shape& slice_sizes,
-    Stream stream) {
+    const Shape& slice_sizes) {
   // If the array is row contiguous then we can do a contiguous copy given
   // two conditions on the slice size:
   // - Any number of leading ones in the slice sizes are allowed
   // - The slice sizes should match the array shape for the remaining dims
@@ -82,53 +115,32 @@ void gather(
     src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
   }
 
-  std::vector<const IdxT*> ind_ptrs;
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  for (auto& idx : inds) {
-    ind_ptrs.push_back(idx.data<IdxT>());
-    encoder.set_input_array(idx);
-  }
-  encoder.set_output_array(out);
-  encoder.dispatch([src_ptr,
-                    dst_ptr,
-                    ind_ptrs = std::move(ind_ptrs),
-                    axes,
-                    ind_size,
-                    slice_size,
-                    src_shape = src.shape(),
-                    src_strides = src.strides(),
-                    src_it = std::move(src_it),
-                    its = std::move(its),
-                    can_copy]() mutable {
-    size_t out_idx = 0;
-    for (int idx = 0; idx < ind_size; idx++) {
-      size_t src_idx = 0;
-      for (int ii = 0; ii < ind_ptrs.size(); ++ii) {
-        auto ax = axes[ii];
-        auto idx_loc = its[ii].loc;
-        its[ii].step();
-        auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]);
-        src_idx += (idx_val * src_strides[ax]);
-      }
-
-      if (slice_size == 1) {
-        dst_ptr[out_idx++] = src_ptr[src_idx];
-      } else if (can_copy) {
-        std::copy(
-            src_ptr + src_idx,
-            src_ptr + src_idx + slice_size,
-            dst_ptr + out_idx);
-        out_idx += slice_size;
-      } else {
-        for (int jj = 0; jj < slice_size; jj++) {
-          dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
-          src_it.step();
-        }
-        src_it.reset();
-      }
+  size_t out_idx = 0;
+  for (int idx = 0; idx < ind_size; idx++) {
+    size_t src_idx = 0;
+    for (int ii = 0; ii < inds.size(); ++ii) {
+      auto ax = axes[ii];
+      auto idx_loc = its[ii].loc;
+      its[ii].step();
+      auto idx_val =
+          offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
+      src_idx += (idx_val * src.strides()[ax]);
     }
-  });
+
+    if (slice_size == 1) {
+      dst_ptr[out_idx++] = src_ptr[src_idx];
+    } else if (can_copy) {
+      std::copy(
+          src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
+      out_idx += slice_size;
+    } else {
+      for (int jj = 0; jj < slice_size; jj++) {
+        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
+        src_it.step();
+      }
+      src_it.reset();
+    }
+  }
 }
 
 template <typename IdxT>
@@ -137,50 +149,49 @@ void dispatch_gather(
     const array& src,
     const std::vector<array>& inds,
     array& out,
     const std::vector<int>& axes,
-    const Shape& size,
-    Stream stream) {
+    const Shape& size) {
   switch (out.dtype()) {
     case bool_:
-      gather<bool, IdxT>(src, inds, out, axes, size, stream);
+      gather<bool, IdxT>(src, inds, out, axes, size);
      break;
    case uint8:
-      gather<uint8_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<uint8_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint16:
-      gather<uint16_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<uint16_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint32:
-      gather<uint32_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<uint32_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint64:
-      gather<uint64_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<uint64_t, IdxT>(src, inds, out, axes, size);
      break;
    case int8:
-      gather<int8_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<int8_t, IdxT>(src, inds, out, axes, size);
      break;
    case int16:
-      gather<int16_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<int16_t, IdxT>(src, inds, out, axes, size);
      break;
    case int32:
-      gather<int32_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<int32_t, IdxT>(src, inds, out, axes, size);
      break;
    case int64:
-      gather<int64_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<int64_t, IdxT>(src, inds, out, axes, size);
      break;
    case float16:
-      gather<float16_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<float16_t, IdxT>(src, inds, out, axes, size);
      break;
    case float32:
-      gather<float, IdxT>(src, inds, out, axes, size, stream);
+      gather<float, IdxT>(src, inds, out, axes, size);
      break;
    case float64:
-      gather<double, IdxT>(src, inds, out, axes, size, stream);
+      gather<double, IdxT>(src, inds, out, axes, size);
      break;
    case bfloat16:
-      gather<bfloat16_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
      break;
    case complex64:
-      gather<complex64_t, IdxT>(src, inds, out, axes, size, stream);
+      gather<complex64_t, IdxT>(src, inds, out, axes, size);
      break;
  }
 }
 
@@ -189,51 +200,63 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   auto& src = inputs[0];
-  std::vector<array> inds(inputs.begin() + 1, inputs.end());
-
-  if (inds.empty()) {
-    dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_, stream());
-    return;
+  std::vector<array> inds;
+  for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
+    inds.push_back(array::unsafe_weak_copy(*it));
   }
-
-  switch (inds[0].dtype()) {
-    case uint8:
-      dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case uint16:
-      dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case uint32:
-      dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case uint64:
-      dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case int8:
-      dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case int16:
-      dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case int32:
-      dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    case int64:
-      dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_, stream());
-      break;
-    default:
-      throw std::runtime_error(
-          "[Gather::eval_cpu] Cannot gather with indices type.");
-      break;
+  auto& encoder = cpu::get_command_encoder(stream());
+  for (auto& in : inputs) {
+    encoder.set_input_array(in);
   }
+  encoder.set_output_array(out);
+  encoder.dispatch([axes_ = axes_,
+                    slice_sizes_ = slice_sizes_,
+                    src = array::unsafe_weak_copy(src),
+                    inds = std::move(inds),
+                    out = array::unsafe_weak_copy(out)]() mutable {
+    if (inds.empty()) {
+      dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
+      return;
+    }
+
+    switch (inds[0].dtype()) {
+      case uint8:
+        dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case uint16:
+        dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case uint32:
+        dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case uint64:
+        dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case int8:
+        dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case int16:
+        dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case int32:
+        dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      case int64:
+        dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
+        break;
+      default:
+        throw std::runtime_error(
+            "[Gather::eval_cpu] Cannot gather with indices type.");
+        break;
+    }
+  });
 }
 
 template <typename T, typename IdxT>
 void gather_axis(
     const array& src,
     const array& ind,
     array& out,
-    const int axis,
-    Stream stream) {
+    const int axis) {
   auto strides = ind.strides();
   strides.erase(strides.begin() + axis);
   auto shape = ind.shape();
@@ -262,38 +285,20 @@ void gather_axis(
     size_post *= ind.shape(i);
   }
 
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  encoder.set_input_array(ind);
-  encoder.set_output_array(out);
-
-  encoder.dispatch([ind_ptr,
-                    src_ptr,
-                    dst_ptr,
-                    size_pre,
-                    size_post,
-                    ind_ax_size,
-                    src_ax_size,
-                    ind_ax_stride,
-                    src_ax_stride,
-                    dst_ax_stride,
-                    ind_it = std::move(ind_it),
-                    src_it = std::move(src_it)]() mutable {
-    size_t stride_pre = size_post * ind_ax_size;
-    for (size_t i = 0; i < size_pre; i++) {
-      for (size_t k = 0; k < size_post; k++) {
-        for (int j = 0; j < ind_ax_size; ++j) {
-          auto ind_val = offset_neg_idx(
-              ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
-          dst_ptr[k + j * dst_ax_stride] =
-              src_ptr[src_it.loc + ind_val * src_ax_stride];
-        }
-        ind_it.step();
-        src_it.step();
+  size_t stride_pre = size_post * ind_ax_size;
+  for (size_t i = 0; i < size_pre; i++) {
+    for (size_t k = 0; k < size_post; k++) {
+      for (int j = 0; j < ind_ax_size; ++j) {
+        auto ind_val = offset_neg_idx(
+            ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
+        dst_ptr[k + j * dst_ax_stride] =
+            src_ptr[src_it.loc + ind_val * src_ax_stride];
      }
-      dst_ptr += stride_pre;
+      ind_it.step();
+      src_it.step();
    }
-  });
+    dst_ptr += stride_pre;
+  }
 }
 
 template <typename IdxT>
@@ -301,88 +306,97 @@ void dispatch_gather_axis(
     const array& src,
     const array& inds,
     array& out,
-    const int axis,
-    Stream stream) {
+    const int axis) {
   switch (out.dtype()) {
     case bool_:
-      gather_axis<bool, IdxT>(src, inds, out, axis, stream);
+      gather_axis<bool, IdxT>(src, inds, out, axis);
      break;
    case uint8:
-      gather_axis<uint8_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<uint8_t, IdxT>(src, inds, out, axis);
      break;
    case uint16:
-      gather_axis<uint16_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<uint16_t, IdxT>(src, inds, out, axis);
      break;
    case uint32:
-      gather_axis<uint32_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<uint32_t, IdxT>(src, inds, out, axis);
      break;
    case uint64:
-      gather_axis<uint64_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<uint64_t, IdxT>(src, inds, out, axis);
      break;
    case int8:
-      gather_axis<int8_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<int8_t, IdxT>(src, inds, out, axis);
      break;
    case int16:
-      gather_axis<int16_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<int16_t, IdxT>(src, inds, out, axis);
      break;
    case int32:
-      gather_axis<int32_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<int32_t, IdxT>(src, inds, out, axis);
      break;
    case int64:
-      gather_axis<int64_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<int64_t, IdxT>(src, inds, out, axis);
      break;
    case float16:
-      gather_axis<float16_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<float16_t, IdxT>(src, inds, out, axis);
      break;
    case float32:
-      gather_axis<float, IdxT>(src, inds, out, axis, stream);
+      gather_axis<float, IdxT>(src, inds, out, axis);
      break;
    case float64:
-      gather_axis<double, IdxT>(src, inds, out, axis, stream);
+      gather_axis<double, IdxT>(src, inds, out, axis);
      break;
    case bfloat16:
-      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
      break;
    case complex64:
-      gather_axis<complex64_t, IdxT>(src, inds, out, axis, stream);
+      gather_axis<complex64_t, IdxT>(src, inds, out, axis);
      break;
  }
 }
 
 void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
   auto& src = inputs[0];
   auto& inds = inputs[1];
-  switch (inds.dtype()) {
-    case uint8:
-      dispatch_gather_axis<uint8_t>(src, inds, out, axis_, stream());
-      break;
-    case uint16:
-      dispatch_gather_axis<uint16_t>(src, inds, out, axis_, stream());
-      break;
-    case uint32:
-      dispatch_gather_axis<uint32_t>(src, inds, out, axis_, stream());
-      break;
-    case uint64:
-      dispatch_gather_axis<uint64_t>(src, inds, out, axis_, stream());
-      break;
-    case int8:
-      dispatch_gather_axis<int8_t>(src, inds, out, axis_, stream());
-      break;
-    case int16:
-      dispatch_gather_axis<int16_t>(src, inds, out, axis_, stream());
-      break;
-    case int32:
-      dispatch_gather_axis<int32_t>(src, inds, out, axis_, stream());
-      break;
-    case int64:
-      dispatch_gather_axis<int64_t>(src, inds, out, axis_, stream());
-      break;
-    default:
-      throw std::runtime_error(
-          "[GatherAxis::eval_cpu] Cannot gather with indices type.");
-      break;
-  }
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(src);
+  encoder.set_input_array(inds);
+  encoder.set_output_array(out);
+  encoder.dispatch([axis_ = axis_,
+                    src = array::unsafe_weak_copy(src),
+                    inds = array::unsafe_weak_copy(inds),
+                    out = array::unsafe_weak_copy(out)]() mutable {
+    switch (inds.dtype()) {
+      case uint8:
+        dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
+        break;
+      case uint16:
+        dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
+        break;
+      case uint32:
+        dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
+        break;
+      case uint64:
+        dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
+        break;
+      case int8:
+        dispatch_gather_axis<int8_t>(src, inds, out, axis_);
+        break;
+      case int16:
+        dispatch_gather_axis<int16_t>(src, inds, out, axis_);
+        break;
+      case int32:
+        dispatch_gather_axis<int32_t>(src, inds, out, axis_);
+        break;
+      case int64:
+        dispatch_gather_axis<int64_t>(src, inds, out, axis_);
+        break;
+      default:
+        throw std::runtime_error(
+            "[GatherAxis::eval_cpu] Cannot gather with indices type.");
+        break;
+    }
+  });
 }
 
 template <typename InT, typename IdxT, typename OpT>
@@ -390,9 +404,7 @@ void scatter(
     const array& updates,
     array& out,
     const std::vector<array>& inds,
-    const std::vector<int>& axes,
-    const OpT& op,
-    Stream stream) {
+    const std::vector<int>& axes) {
   int nind = inds.size();
   auto inds_ndim = updates.ndim() - out.ndim();
   size_t n_updates = nind ? inds[0].size() : 1;
@@ -408,45 +420,27 @@ void scatter(
   ContiguousIterator update_it(updates);
   ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
 
-  std::vector<const IdxT*> ind_ptrs;
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(updates);
-  for (auto& idx : inds) {
-    ind_ptrs.push_back(idx.data<IdxT>());
-    encoder.set_input_array(idx);
-  }
-  encoder.set_output_array(out);
-  encoder.dispatch([out_ptr = out.data<InT>(),
-                    upd_ptr = updates.data<InT>(),
-                    ind_ptrs = std::move(ind_ptrs),
-                    axes,
-                    n_updates,
-                    update_size,
-                    op = std::move(op),
-                    out_shape = out.shape(),
-                    out_strides = out.strides(),
-                    out_it = std::move(out_it),
-                    update_it = std::move(update_it),
-                    its = std::move(its)]() mutable {
-    for (int i = 0; i < n_updates; ++i) {
-      size_t out_offset = 0;
-      for (int j = 0; j < ind_ptrs.size(); ++j) {
-        auto ax = axes[j];
-        auto idx_loc = its[j].loc;
-        its[j].step();
-        auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]);
-        out_offset += (idx_val * out_strides[ax]);
-      }
-      update_it.seek(i * update_size);
-      for (int j = 0; j < update_size; ++j) {
-        op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
-        update_it.step();
-        out_it.step();
-      }
-      out_it.reset();
-      update_it.reset();
+  auto out_ptr = out.data<InT>();
+  auto upd_ptr = updates.data<InT>();
+  for (int i = 0; i < n_updates; ++i) {
+    size_t out_offset = 0;
+    for (int j = 0; j < inds.size(); ++j) {
+      auto ax = axes[j];
+      auto idx_loc = its[j].loc;
+      its[j].step();
+      auto idx_val =
+          offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
+      out_offset += (idx_val * out.strides()[ax]);
    }
-  });
+    update_it.seek(i * update_size);
+    for (int j = 0; j < update_size; ++j) {
+      OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
+      update_it.step();
+      out_it.step();
+    }
+    out_it.reset();
+    update_it.reset();
+  }
 }
 
 template <typename InT, typename IdxT>
@@ -455,53 +449,22 @@ void dispatch_scatter_inds(
     array& out,
     const std::vector<array>& indices,
     const array& updates,
     const std::vector<int>& axes,
-    Scatter::ReduceType rtype,
-    Stream stream) {
+    Scatter::ReduceType rtype) {
   switch (rtype) {
     case Scatter::None:
-      scatter<InT, IdxT>(
-          updates,
-          out,
-          indices,
-          axes,
-          [](auto x, auto* y) { (*y) = x; },
-          stream);
+      scatter<InT, IdxT, None>(updates, out, indices, axes);
      break;
    case Scatter::Sum:
-      scatter<InT, IdxT>(
-          updates,
-          out,
-          indices,
-          axes,
-          [](auto x, auto* y) { (*y) += x; },
-          stream);
+      scatter<InT, IdxT, Sum>(updates, out, indices, axes);
      break;
    case Scatter::Prod:
-      scatter<InT, IdxT>(
-          updates,
-          out,
-          indices,
-          axes,
-          [](auto x, auto* y) { (*y) *= x; },
-          stream);
+      scatter<InT, IdxT, Prod>(updates, out, indices, axes);
      break;
    case Scatter::Max:
-      scatter<InT, IdxT>(
-          updates,
-          out,
-          indices,
-          axes,
-          [](auto x, auto* y) { (*y) = (*y > x) ? *y : x; },
-          stream);
+      scatter<InT, IdxT, Max>(updates, out, indices, axes);
      break;
    case Scatter::Min:
-      scatter<InT, IdxT>(
-          updates,
-          out,
-          indices,
-          axes,
-          [](auto x, auto* y) { (*y) = (*y < x) ?
*y : x; }, - stream); + scatter(updates, out, indices, axes); break; } } @@ -512,46 +475,36 @@ void dispatch_scatter( const std::vector& inds, const array& updates, const std::vector& axes, - Scatter::ReduceType rtype, - Stream stream) { + Scatter::ReduceType rtype) { if (inds.empty()) { - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); return; } switch (inds[0].dtype()) { case uint8: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case uint16: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case uint32: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case uint64: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int8: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int16: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int32: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; case int64: - dispatch_scatter_inds( - out, inds, updates, axes, rtype, stream); + dispatch_scatter_inds(out, inds, updates, axes, rtype); break; default: throw std::runtime_error( @@ -563,7 +516,6 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() >= 2); auto& src = inputs[0]; - std::vector inds(inputs.begin() + 1, inputs.end() - 1); auto& updates = inputs.back(); // Copy src into out (copy allocates memory for out) @@ -571,73 +523,68 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { src.flags().row_contiguous ? 
CopyType::Vector : CopyType::General; copy(src, out, ctype, stream()); - switch (src.dtype()) { - case bool_: - dispatch_scatter(out, inds, updates, axes_, reduce_type_, stream()); - break; - case uint8: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case uint16: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case uint32: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case uint64: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case int8: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case int16: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case int32: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case int64: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case float16: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case float32: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case float64: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case bfloat16: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; - case complex64: - dispatch_scatter( - out, inds, updates, axes_, reduce_type_, stream()); - break; + auto& encoder = cpu::get_command_encoder(stream()); + std::vector inds; + for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) { + encoder.set_input_array(*it); + inds.push_back(array::unsafe_weak_copy(*it)); } + encoder.set_input_array(updates); + encoder.set_output_array(out); + encoder.dispatch([axes_ = axes_, + reduce_type_ = reduce_type_, + updates = array::unsafe_weak_copy(updates), + inds = std::move(inds), + out = array::unsafe_weak_copy(out)]() mutable { + switch (out.dtype()) { + case bool_: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint8: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint32: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int8: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int32: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case float16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case float32: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case float64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case bfloat16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case complex64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + } + }); } template -void scatter_axis( - array& out, - const array idx, - const array& upd, - int axis, - const OpT& op, - Stream stream) { +void scatter_axis(array& out, const array idx, const array& upd, int axis) { auto strides = idx.strides(); strides.erase(strides.begin() + axis); auto shape = idx.shape(); @@ -657,11 +604,6 @@ void scatter_axis( auto idx_ax_size = idx.shape(axis); 
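// Note on the OpT{} calls below (illustrative): the per-reduce-type lambdas
// that scatter()/scatter_axis() previously received as arguments, e.g.
//   [](auto x, auto* y) { (*y) += x; }   // Scatter::Sum
// are replaced by default-constructible functor types instantiated in the
// inner loop. The concrete functor definitions live outside this hunk, so
// the following is an assumed shape, not an excerpt:
//
//   struct Sum {
//     template <typename T>
//     void operator()(T x, T* y) const {
//       *y += x; // fold the update into the destination element
//     }
//   };
//
// Passing the op as a template parameter keeps the dispatched closures
// small and avoids threading a callable through every dispatch layer.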
auto dst_ax_size = out.shape(axis); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(idx); - encoder.set_input_array(upd); - encoder.set_output_array(out); - size_t size_pre = 1; size_t size_post = 1; for (int i = 0; i < axis; ++i) { @@ -670,34 +612,21 @@ void scatter_axis( for (int i = axis + 1; i < idx.ndim(); ++i) { size_post *= idx.shape(i); } - encoder.dispatch([idx_ptr, - upd_ptr, - dst_ptr, - size_pre, - size_post, - idx_ax_size, - dst_ax_size, - idx_ax_stride, - upd_ax_stride, - dst_ax_stride, - idx_it = std::move(idx_it), - upd_it = std::move(upd_it), - op = std::move(op)]() mutable { - size_t stride_pre = size_post * dst_ax_size; - for (size_t i = 0; i < size_pre; i++) { - for (size_t k = 0; k < size_post; k++) { - for (int j = 0; j < idx_ax_size; ++j) { - auto ind_val = offset_neg_idx( - idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); - op(upd_ptr[upd_it.loc + j * upd_ax_stride], - dst_ptr + k + ind_val * dst_ax_stride); - } - idx_it.step(); - upd_it.step(); + size_t stride_pre = size_post * dst_ax_size; + for (size_t i = 0; i < size_pre; i++) { + for (size_t k = 0; k < size_post; k++) { + for (int j = 0; j < idx_ax_size; ++j) { + auto ind_val = offset_neg_idx( + idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size); + OpT{}( + upd_ptr[upd_it.loc + j * upd_ax_stride], + dst_ptr + k + ind_val * dst_ax_stride); } - dst_ptr += stride_pre; + idx_it.step(); + upd_it.step(); } - }); + dst_ptr += stride_pre; + } } template @@ -706,16 +635,13 @@ void dispatch_scatter_axis_op( const array& idx, const array& updates, int axis, - ScatterAxis::ReduceType rtype, - Stream stream) { + ScatterAxis::ReduceType rtype) { switch (rtype) { case ScatterAxis::None: - scatter_axis( - out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream); + scatter_axis(out, idx, updates, axis); break; case ScatterAxis::Sum: - scatter_axis( - out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream); + scatter_axis(out, idx, updates, axis); break; } } @@ -726,40 +652,31 @@ void dispatch_scatter_axis( const array& idx, const array& updates, int axis, - ScatterAxis::ReduceType rtype, - Stream stream) { + ScatterAxis::ReduceType rtype) { switch (idx.dtype()) { case uint8: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case uint16: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case uint32: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case uint64: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int8: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int16: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int32: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; case int64: - dispatch_scatter_axis_op( - out, idx, updates, axis, rtype, stream); + dispatch_scatter_axis_op(out, idx, updates, axis, rtype); break; default: throw std::runtime_error( @@ -779,64 +696,63 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { 
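// Shape of the rewrite below (illustrative): as with Scatter::eval_cpu
// above, the dtype switch moves inside the dispatched task and keys on
// out.dtype() — equivalent to src's dtype once src has been copied into
// out — because the closure only captures weak copies of idx, updates,
// and out. With the template arguments written out (the exact parameter
// lists are an assumption here):
//
//   encoder.dispatch([/* weak copies + axis_, reduce_type_ */]() mutable {
//     switch (out.dtype()) {   // runs on the stream's worker thread
//       case float32:
//         dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
//         break;
//       // ... one case per supported dtype ...
//     }
//   });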
src.flags().row_contiguous ? CopyType::Vector : CopyType::General; copy(src, out, ctype, stream()); - switch (src.dtype()) { - case bool_: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case uint8: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case uint16: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case uint32: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case uint64: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case int8: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case int16: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case int32: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case int64: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case float16: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case float32: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case float64: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case bfloat16: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - case complex64: - dispatch_scatter_axis( - out, idx, updates, axis_, reduce_type_, stream()); - break; - } + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(idx); + encoder.set_input_array(updates); + encoder.set_output_array(out); + encoder.dispatch([axis_ = axis_, + reduce_type_ = reduce_type_, + idx = array::unsafe_weak_copy(idx), + updates = array::unsafe_weak_copy(updates), + out = array::unsafe_weak_copy(out)]() mutable { + switch (out.dtype()) { + case bool_: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint8: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint16: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint32: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case uint64: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int8: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int16: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int32: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case int64: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case float16: + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_); + break; + case float32: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case float64: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; + case bfloat16: + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_); + break; + case complex64: + dispatch_scatter_axis( + out, idx, updates, axis_, reduce_type_); + break; + } + }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 526f0b7168..02bddab2ff 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -326,8 +326,7 @@ void _qmm_dispatch_typed( const array& biases, int bits, int group_size, - bool transposed_w, - Stream stream) { + bool 
transposed_w) { int K = x.shape(-1); int M = x.ndim() > 1 ? x.shape(-2) : 1; int N = out.shape(-1); @@ -335,56 +334,25 @@ void _qmm_dispatch_typed( int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; int batch_size = x.size() / (K * M); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(x); - encoder.set_input_array(w); - encoder.set_input_array(scales); - encoder.set_input_array(biases); - encoder.set_output_array(out); - auto out_ptr = out.data(); auto x_ptr = x.data(); auto w_ptr = w.data(); auto scales_ptr = scales.data(); auto biases_ptr = biases.data(); - - encoder.dispatch([out_ptr, - x_ptr, - w_ptr, - scales_ptr, - biases_ptr, - x_shape = x.shape(), - x_strides = x.strides(), - w_shape = w.shape(), - w_strides = w.strides(), - scales_shape = scales.shape(), - scales_strides = scales.strides(), - biases_shape = biases.shape(), - biases_strides = biases.strides(), - w_els, - g_els, - batch_size, - M, - N, - K, - bits, - group_size, - transposed_w] { - for (int i = 0; i < batch_size; i++) { - _qmm_dispatch_typed( - out_ptr + i * M * N, - x_ptr + elem_to_loc(i * M * K, x_shape, x_strides), - w_ptr + elem_to_loc(i * w_els, w_shape, w_strides), - scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides), - biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides), - M, - N, - K, - bits, - group_size, - transposed_w); - } - }); + for (int i = 0; i < batch_size; i++) { + _qmm_dispatch_typed( + out_ptr + i * M * N, + x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), + w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), + scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()), + biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()), + M, + N, + K, + bits, + group_size, + transposed_w); + } } void _qmm_dispatch( @@ -395,20 +363,19 @@ void _qmm_dispatch( const array& biases, int bits, int group_size, - bool transposed_w, - Stream stream) { + bool transposed_w) { switch (x.dtype()) { case float32: _qmm_dispatch_typed( - out, x, w, scales, biases, bits, group_size, transposed_w, stream); + out, x, w, scales, biases, bits, group_size, transposed_w); break; case float16: _qmm_dispatch_typed( - out, x, w, scales, biases, bits, group_size, transposed_w, stream); + out, x, w, scales, biases, bits, group_size, transposed_w); break; case bfloat16: _qmm_dispatch_typed( - out, x, w, scales, biases, bits, group_size, transposed_w, stream); + out, x, w, scales, biases, bits, group_size, transposed_w); break; default: throw std::invalid_argument( @@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed( const array& rhs_indices, int bits, int group_size, - bool transposed_w, - Stream stream) { + bool transposed_w) { int K = x.shape(-1); int M = x.shape(-2); int N = out.shape(-1); @@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed( int w_els = w.shape(-1) * w.shape(-2); int g_els = scales.shape(-1) * scales.shape(-2); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(x); - encoder.set_input_array(w); - encoder.set_input_array(scales); - encoder.set_input_array(biases); - encoder.set_input_array(lhs_indices); - encoder.set_input_array(rhs_indices); - encoder.set_output_array(out); - auto out_ptr = out.data(); auto x_ptr = x.data(); auto w_ptr = w.data(); @@ -453,53 +410,26 @@ void _bs_qmm_dispatch_typed( auto lhs_indices_ptr = lhs_indices.data(); auto rhs_indices_ptr = rhs_indices.data(); - encoder.dispatch([out_ptr, - x_ptr, - w_ptr, - scales_ptr, - biases_ptr, - lhs_indices_ptr, - 
rhs_indices_ptr, - x_shape = x.shape(), - x_strides = x.strides(), - w_shape = w.shape(), - w_strides = w.strides(), - scales_shape = scales.shape(), - scales_strides = scales.strides(), - biases_shape = biases.shape(), - biases_strides = biases.strides(), - lhs_indices_shape = lhs_indices.shape(), - lhs_indices_strides = lhs_indices.strides(), - rhs_indices_shape = rhs_indices.shape(), - rhs_indices_strides = rhs_indices.strides(), - w_els, - g_els, - indices_size = lhs_indices.size(), - M, - N, - K, - bits, - group_size, - transposed_w]() { - for (int i = 0; i < indices_size; i++) { - int x_idx = lhs_indices_ptr[elem_to_loc( - i, lhs_indices_shape, lhs_indices_strides)]; - int w_idx = rhs_indices_ptr[elem_to_loc( - i, rhs_indices_shape, rhs_indices_strides)]; - _qmm_dispatch_typed( - out_ptr + i * M * N, - x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides), - w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides), - scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides), - biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides), - M, - N, - K, - bits, - group_size, - transposed_w); - } - }); + for (int i = 0; i < lhs_indices.size(); i++) { + int x_idx = lhs_indices_ptr[elem_to_loc( + i, lhs_indices.shape(), lhs_indices.strides())]; + int w_idx = rhs_indices_ptr[elem_to_loc( + i, rhs_indices.shape(), rhs_indices.strides())]; + _qmm_dispatch_typed( + out_ptr + i * M * N, + x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()), + w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()), + scales_ptr + + elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()), + biases_ptr + + elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()), + M, + N, + K, + bits, + group_size, + transposed_w); + } } void _bs_qmm_dispatch( @@ -512,8 +442,7 @@ void _bs_qmm_dispatch( const array& rhs_indices, int bits, int group_size, - bool transposed_w, - Stream stream) { + bool transposed_w) { switch (x.dtype()) { case float32: _bs_qmm_dispatch_typed( @@ -526,8 +455,7 @@ void _bs_qmm_dispatch( rhs_indices, bits, group_size, - transposed_w, - stream); + transposed_w); break; case float16: _bs_qmm_dispatch_typed( @@ -540,8 +468,7 @@ void _bs_qmm_dispatch( rhs_indices, bits, group_size, - transposed_w, - stream); + transposed_w); break; case bfloat16: _bs_qmm_dispatch_typed( @@ -554,8 +481,7 @@ void _bs_qmm_dispatch( rhs_indices, bits, group_size, - transposed_w, - stream); + transposed_w); break; default: throw std::invalid_argument( @@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { auto biases = ensure_row_contiguous(biases_pre); out.set_data(allocator::malloc_or_wait(out.nbytes())); - _qmm_dispatch( - out, x, w, scales, biases, group_size_, bits_, transpose_, stream()); - auto& enc = cpu::get_command_encoder(stream()); - enc.add_temporaries(std::move(temps)); + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.add_temporaries(std::move(temps)); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_output_array(out); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + group_size_ = group_size_, + bits_ = bits_, + transpose_ = transpose_]() mutable { + _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); + }); } void GatherQMM::eval_cpu(const 
std::vector& inputs, array& out) { @@ -626,20 +566,38 @@ void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { auto biases = ensure_row_contiguous_last_dims(biases_pre); out.set_data(allocator::malloc_or_wait(out.nbytes())); - _bs_qmm_dispatch( - out, - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - group_size_, - bits_, - transpose_, - stream()); - auto& enc = cpu::get_command_encoder(stream()); - enc.add_temporaries(std::move(temps)); + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.add_temporaries(std::move(temps)); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + lhs_indices = array::unsafe_weak_copy(lhs_indices), + rhs_indices = array::unsafe_weak_copy(rhs_indices), + group_size_ = group_size_, + bits_ = bits_, + transpose_ = transpose_]() mutable { + _bs_qmm_dispatch( + out, + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + group_size_, + bits_, + transpose_); + }); } template @@ -709,27 +667,13 @@ void dispatch_quantize( array& scales, array& biases, int bits, - int group_size, - Stream stream) { + int group_size) { auto w_ptr = w.data(); auto out_ptr = out.data(); auto scales_ptr = scales.data(); auto biases_ptr = biases.data(); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(w); - encoder.set_input_array(scales); - encoder.set_input_array(biases); - encoder.set_output_array(out); - encoder.dispatch([w_ptr, - out_ptr, - scales_ptr, - biases_ptr, - bits, - group_size, - w_size = w.size()]() { - quantize( - w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size); - }); + quantize( + w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); } void fast::AffineQuantize::eval_cpu( @@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu( auto& biases = outputs[2]; scales.set_data(allocator::malloc_or_wait(scales.nbytes())); biases.set_data(allocator::malloc_or_wait(biases.nbytes())); - if (w.dtype() == float16) { - if (is_power_of_2(bits_)) { - dispatch_quantize( - w, out, scales, biases, bits_, group_size_, stream()); - } else { - dispatch_quantize( - w, out, scales, biases, bits_, group_size_, stream()); - } - } else if (w.dtype() == bfloat16) { - if (is_power_of_2(bits_)) { - dispatch_quantize( - w, out, scales, biases, bits_, group_size_, stream()); - } else { - dispatch_quantize( - w, out, scales, biases, bits_, group_size_, stream()); - } - } else if (w.dtype() == float32) { - if (is_power_of_2(bits_)) { - dispatch_quantize( - w, out, scales, biases, bits_, group_size_, stream()); - } else { - dispatch_quantize( - w, out, scales, biases, bits_, group_size_, stream()); - } - } else { - throw std::runtime_error( - "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); - } + auto& encoder = cpu::get_command_encoder(stream()); if (copied) { - cpu::get_command_encoder(stream()).add_temporary(w); + encoder.add_temporary(w); } + encoder.set_input_array(w); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_output_array(out); + encoder.dispatch([w = array::unsafe_weak_copy(w), + out = array::unsafe_weak_copy(out), + scales = 
array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + group_size_ = group_size_, + bits_ = bits_]() mutable { + if (w.dtype() == float16) { + if (is_power_of_2(bits_)) { + dispatch_quantize( + w, out, scales, biases, bits_, group_size_); + } else { + dispatch_quantize( + w, out, scales, biases, bits_, group_size_); + } + } else if (w.dtype() == bfloat16) { + if (is_power_of_2(bits_)) { + dispatch_quantize( + w, out, scales, biases, bits_, group_size_); + } else { + dispatch_quantize( + w, out, scales, biases, bits_, group_size_); + } + } else if (w.dtype() == float32) { + if (is_power_of_2(bits_)) { + dispatch_quantize( + w, out, scales, biases, bits_, group_size_); + } else { + dispatch_quantize( + w, out, scales, biases, bits_, group_size_); + } + } else { + throw std::runtime_error( + "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); + } + }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index 424894cfdd..3f0c3b2aec 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -140,34 +140,23 @@ void reduction_op( const array& x, array& out, const std::vector& axes, - U init, - Stream stream) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + U init) { ReductionPlan plan = get_reduction_plan(x, axes); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(x); - encoder.set_output_array(out); - auto in_ptr = x.data(); auto out_ptr = out.data(); if (plan.type == ContiguousAllReduce) { - encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() { - *out_ptr = init; - contiguous_reduce(in_ptr, out_ptr, size, Op{}, init); - }); + *out_ptr = init; + contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init); return; } if (plan.type == ContiguousReduce && plan.shape.size() == 1) { int reduction_size = plan.shape[0]; - encoder.dispatch( - [in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable { - for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) { - *out_ptr = init; - contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init); - } - }); + for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) { + *out_ptr = init; + contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init); + } return; } @@ -178,40 +167,29 @@ void reduction_op( // Unrolling the following loop (and implementing it in order for // ContiguousReduce) should hold extra performance boost. 
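// For reference, the index math these reduction loops rely on (a minimal
// sketch assuming the conventional row-major elem_to_loc): each output
// element i is mapped to a flat input offset using the shape/strides pair
// with the reduction axes removed:
//
//   size_t elem_to_loc_sketch(int i, const Shape& shape,
//                             const Strides& strides) {
//     size_t loc = 0;
//     for (int d = shape.size() - 1; d >= 0; --d) {
//       loc += (i % shape[d]) * strides[d]; // position along axis d
//       i /= shape[d];
//     }
//     return loc; // flat offset of output element i within the input
//   }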
auto [shape, strides] = shapes_without_reduction_axes(x, axes); - - encoder.dispatch([in_ptr, - out_ptr, - init, - reduction_size, - size = out.size(), - plan = std::move(plan), - shape = std::move(shape), - strides = std::move(strides)]() mutable { - if (plan.shape.size() == 0) { - for (int i = 0; i < size; i++, out_ptr++) { - int offset = elem_to_loc(i, shape, strides); - *out_ptr = init; - contiguous_reduce( - in_ptr + offset, out_ptr, reduction_size, Op{}, init); - } - } else { - for (int i = 0; i < size; i++, out_ptr++) { - int offset = elem_to_loc(i, shape, strides); - *out_ptr = init; - nd_loop( - [&](int extra_offset) { - contiguous_reduce( - in_ptr + offset + extra_offset, - out_ptr, - reduction_size, - Op{}, - init); - }, - plan.shape, - plan.strides); - } + if (plan.shape.size() == 0) { + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + *out_ptr = init; + contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init); } - }); + } else { + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + *out_ptr = init; + nd_loop( + [&](int extra_offset) { + contiguous_reduce( + in_ptr + offset + extra_offset, + out_ptr, + reduction_size, + Op{}, + init); + }, + plan.shape, + plan.strides); + } + } return; } @@ -220,20 +198,12 @@ void reduction_op( size_t reduction_stride = plan.strides.back(); plan.shape.pop_back(); plan.strides.pop_back(); - - encoder.dispatch([in_ptr, - out_ptr, - init, - reduction_size, - reduction_stride, - size = out.size()]() mutable { - for (int i = 0; i < size; i += reduction_stride) { - std::fill_n(out_ptr, reduction_stride, init); - strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{}); - in_ptr += reduction_stride * reduction_size; - out_ptr += reduction_stride; - } - }); + for (int i = 0; i < out.size(); i += reduction_stride) { + std::fill_n(out_ptr, reduction_stride, init); + strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{}); + in_ptr += reduction_stride * reduction_size; + out_ptr += reduction_stride; + } return; } @@ -245,67 +215,49 @@ void reduction_op( plan.strides.pop_back(); auto [shape, strides] = shapes_without_reduction_axes(x, axes); - encoder.dispatch([in_ptr, - out_ptr, - init, - reduction_size, - reduction_stride, - size = out.size(), - plan = std::move(plan), - shape = std::move(shape), - strides = std::move(strides)]() mutable { - if (plan.shape.size() == 0) { - for (int i = 0; i < size; i += reduction_stride) { - int offset = elem_to_loc(i, shape, strides); - std::fill_n(out_ptr, reduction_stride, init); - strided_reduce( - in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{}); - out_ptr += reduction_stride; - } - } else { - for (int i = 0; i < size; i += reduction_stride) { - int offset = elem_to_loc(i, shape, strides); - std::fill_n(out_ptr, reduction_stride, init); - nd_loop( - [&](int extra_offset) { - strided_reduce( - in_ptr + offset + extra_offset, - out_ptr, - reduction_size, - reduction_stride, - Op{}); - }, - plan.shape, - plan.strides); - out_ptr += reduction_stride; - } + if (plan.shape.size() == 0) { + for (int i = 0; i < out.size(); i += reduction_stride) { + int offset = elem_to_loc(i, shape, strides); + std::fill_n(out_ptr, reduction_stride, init); + strided_reduce( + in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{}); + out_ptr += reduction_stride; } - }); + } else { + for (int i = 0; i < out.size(); i += reduction_stride) { + int offset = elem_to_loc(i, 
shape, strides); + std::fill_n(out_ptr, reduction_stride, init); + nd_loop( + [&](int extra_offset) { + strided_reduce( + in_ptr + offset + extra_offset, + out_ptr, + reduction_size, + reduction_stride, + Op{}); + }, + plan.shape, + plan.strides); + out_ptr += reduction_stride; + } + } return; } if (plan.type == GeneralReduce) { auto [shape, strides] = shapes_without_reduction_axes(x, axes); - encoder.dispatch([in_ptr, - out_ptr, - init, - size = out.size(), - plan = std::move(plan), - shape = std::move(shape), - strides = std::move(strides)]() mutable { - for (int i = 0; i < size; i++, out_ptr++) { - int offset = elem_to_loc(i, shape, strides); - U val = init; - nd_loop( - [&](int extra_offset) { - val = Op{}(val, *(in_ptr + offset + extra_offset)); - }, - plan.shape, - plan.strides); - *out_ptr = val; - } - }); + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + U val = init; + nd_loop( + [&](int extra_offset) { + val = Op{}(val, *(in_ptr + offset + extra_offset)); + }, + plan.shape, + plan.strides); + *out_ptr = val; + } } } @@ -434,12 +386,11 @@ void reduce_dispatch_and_or( const array& in, array& out, Reduce::ReduceType rtype, - const std::vector& axes, - Stream stream) { + const std::vector& axes) { if (rtype == Reduce::And) { - reduction_op(in, out, axes, true, stream); + reduction_op(in, out, axes, true); } else { - reduction_op(in, out, axes, false, stream); + reduction_op(in, out, axes, false); } } @@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod( const array& in, array& out, Reduce::ReduceType rtype, - const std::vector& axes, - Stream stream) { + const std::vector& axes) { if (rtype == Reduce::Sum) { if constexpr (std::is_integral_v && sizeof(InT) <= 4) { - reduction_op(in, out, axes, 0, stream); + reduction_op(in, out, axes, 0); } else { - reduction_op(in, out, axes, 0, stream); + reduction_op(in, out, axes, 0); } } else { if constexpr (std::is_integral_v && sizeof(InT) <= 4) { - reduction_op(in, out, axes, 1, stream); + reduction_op(in, out, axes, 1); } else { - reduction_op(in, out, axes, 1, stream); + reduction_op(in, out, axes, 1); } } } @@ -470,162 +420,144 @@ void reduce_dispatch_min_max( const array& in, array& out, Reduce::ReduceType rtype, - const std::vector& axes, - Stream stream) { + const std::vector& axes) { if (rtype == Reduce::Max) { auto init = Limits::min; - reduction_op(in, out, axes, init, stream); + reduction_op(in, out, axes, init); } else { auto init = Limits::max; - reduction_op(in, out, axes, init, stream); + reduction_op(in, out, axes, init); } } void Reduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - switch (reduce_type_) { - case Reduce::And: - case Reduce::Or: { - switch (in.dtype()) { - case bool_: - case uint8: - case int8: - reduce_dispatch_and_or( - in, out, reduce_type_, axes_, stream()); - break; - case int16: - case uint16: - case float16: - case bfloat16: - reduce_dispatch_and_or( - in, out, reduce_type_, axes_, stream()); - break; - case uint32: - case int32: - case float32: - reduce_dispatch_and_or( - in, out, reduce_type_, axes_, stream()); - break; - case uint64: - case int64: - case float64: - case complex64: - reduce_dispatch_and_or( - in, out, reduce_type_, axes_, stream()); - break; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.dispatch([in = array::unsafe_weak_copy(in), + out = 
array::unsafe_weak_copy(out), + reduce_type_ = reduce_type_, + axes_ = axes_]() mutable { + switch (reduce_type_) { + case Reduce::And: + case Reduce::Or: { + switch (in.dtype()) { + case bool_: + case uint8: + case int8: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + case int16: + case uint16: + case float16: + case bfloat16: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + case uint32: + case int32: + case float32: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + case uint64: + case int64: + case float64: + case complex64: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + } + break; } - break; - } - case Reduce::Sum: - case Reduce::Prod: { - switch (in.dtype()) { - case bool_: - case uint8: - case int8: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case int16: - case uint16: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case int32: - case uint32: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case int64: - case uint64: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case float16: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case bfloat16: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case float32: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case float64: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; - case complex64: - reduce_dispatch_sum_prod( - in, out, reduce_type_, axes_, stream()); - break; + case Reduce::Sum: + case Reduce::Prod: { + switch (in.dtype()) { + case bool_: + case uint8: + case int8: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case int16: + case uint16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case int32: + case uint32: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case int64: + case uint64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case float16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case bfloat16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case float32: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case float64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case complex64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + } + break; } - break; - } - case Reduce::Max: - case Reduce::Min: { - switch (in.dtype()) { - case bool_: - reduce_dispatch_min_max(in, out, reduce_type_, axes_, stream()); - break; - case uint8: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case uint16: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case uint32: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case uint64: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case int8: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case int16: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case int32: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case int64: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case float16: - reduce_dispatch_min_max( - in, out, 
reduce_type_, axes_, stream()); - break; - case float32: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case float64: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case bfloat16: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; - case complex64: - reduce_dispatch_min_max( - in, out, reduce_type_, axes_, stream()); - break; + case Reduce::Max: + case Reduce::Min: { + switch (in.dtype()) { + case bool_: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint8: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint32: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int8: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int32: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case float16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case float32: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case float64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case bfloat16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case complex64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + } + break; } - break; } - } + }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 93f4387183..205ae414da 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -160,38 +160,29 @@ void scan_op( bool reverse, bool inclusive, const Op& op, - U init, - Stream stream) { - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(in); - encoder.set_output_array(out); - + U init) { if (in.flags().row_contiguous) { if (in.strides()[axis] == 1) { - encoder.dispatch([in_ptr = in.data(), - out_ptr = out.data(), - count = in.size() / in.shape(axis), - stride = in.shape(axis), - reverse, - inclusive, - op = std::move(op), - init]() { - contiguous_scan( - in_ptr, out_ptr, count, stride, reverse, inclusive, op, init); - }); + contiguous_scan( + in.data(), + out.data(), + in.size() / in.shape(axis), + in.shape(axis), + reverse, + inclusive, + op, + init); } else { - encoder.dispatch([in_ptr = in.data(), - out_ptr = out.data(), - count = in.size() / in.shape(axis) / in.strides()[axis], - size = in.shape(axis), - stride = in.strides()[axis], - reverse, - inclusive, - op = std::move(op), - init]() { - strided_scan( - in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init); - }); + strided_scan( + in.data(), + out.data(), + in.size() / in.shape(axis) / in.strides()[axis], + in.shape(axis), + in.strides()[axis], + reverse, + inclusive, + op, + init); } } else { throw std::runtime_error("Scan op supports only contiguous inputs"); @@ -205,19 +196,18 @@ void scan_dispatch( array& out, int axis, bool reverse, - bool inclusive, - Stream stream) { + bool inclusive) { switch (rtype) { case Scan::Sum: { auto op = [](U y, T x) { return y + x; }; auto init = static_cast(0); - scan_op(in, out, axis, reverse, inclusive, op, init, stream); + scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::Prod: { auto op = 
[](U y, T x) { return y * x; }; auto init = static_cast(1); - scan_op(in, out, axis, reverse, inclusive, op, init, stream); + scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::Min: { @@ -225,7 +215,7 @@ void scan_dispatch( auto init = (issubdtype(in.dtype(), floating)) ? static_cast(std::numeric_limits::infinity()) : std::numeric_limits::max(); - scan_op(in, out, axis, reverse, inclusive, op, init, stream); + scan_op(in, out, axis, reverse, inclusive, op, init); break; } case Scan::Max: { @@ -233,7 +223,7 @@ void scan_dispatch( auto init = (issubdtype(in.dtype(), floating)) ? static_cast(-std::numeric_limits::infinity()) : std::numeric_limits::min(); - scan_op(in, out, axis, reverse, inclusive, op, init, stream); + scan_op(in, out, axis, reverse, inclusive, op, init); break; } } @@ -244,88 +234,95 @@ void scan_dispatch( void Scan::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); + auto& encoder = cpu::get_command_encoder(stream()); + // Ensure contiguity auto in = inputs[0]; - bool copied = false; if (!in.flags().row_contiguous) { array arr_copy(in.shape(), in.dtype(), nullptr, {}); copy(in, arr_copy, CopyType::General, stream()); in = arr_copy; - copied = true; + encoder.add_temporary(arr_copy); } out.set_data(allocator::malloc_or_wait(out.nbytes())); - switch (in.dtype()) { - case bool_: { - // We could do a full dtype x dtype switch but this is the only case - // where we accumulate in a different type, for now. - // - // TODO: If we add the option to accumulate floats in higher precision - // floats perhaps we should add the full all-to-all dispatch. - if (reduce_type_ == Scan::Sum && out.dtype() == int32) { - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - } else { - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); + encoder.set_input_array(in); + encoder.set_output_array(out); + encoder.dispatch([in = array::unsafe_weak_copy(in), + out = array::unsafe_weak_copy(out), + axis_ = axis_, + reduce_type_ = reduce_type_, + reverse_ = reverse_, + inclusive_ = inclusive_]() mutable { + switch (in.dtype()) { + case bool_: { + // We could do a full dtype x dtype switch but this is the only case + // where we accumulate in a different type, for now. + // + // TODO: If we add the option to accumulate floats in higher precision + // floats perhaps we should add the full all-to-all dispatch. 
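// Worked example of the special case just below: a cumulative sum over
// booleans accumulates in int32, so scanning [true, true, true] yields
// [1, 2, 3] rather than a saturated bool sequence; every other dtype/op
// pair accumulates in the input type. The branch then dispatches as
// (template arguments assumed from the out.dtype() == int32 check):
//
//   scan_dispatch<bool, int32_t>(
//       reduce_type_, in, out, axis_, reverse_, inclusive_);
//   // otherwise: scan_dispatch<bool, bool>(...)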
+ if (reduce_type_ == Scan::Sum && out.dtype() == int32) { + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + } else { + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + } + break; } - break; + case uint8: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case uint16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case uint32: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case uint64: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int8: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int32: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int64: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case float16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case float32: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case float64: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case bfloat16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case complex64: + throw std::runtime_error("Scan ops do not support complex types yet"); + break; } - case uint8: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case uint16: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case uint32: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case uint64: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case int8: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case int16: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case int32: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case int64: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case float16: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case float32: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case float64: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case bfloat16: - scan_dispatch( - reduce_type_, in, out, axis_, reverse_, inclusive_, stream()); - break; - case complex64: - throw std::runtime_error("Scan ops do not support complex types yet"); - break; - } - if (copied) { - cpu::get_command_encoder(stream()).add_temporary(std::move(in)); - } + }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/select.cpp b/mlx/backend/cpu/select.cpp index 1382a8ff64..bf6a9b8259 100644 --- a/mlx/backend/cpu/select.cpp +++ b/mlx/backend/cpu/select.cpp @@ -16,51 +16,70 @@ void select_op( const array& b, const array& c, array& out, - Op op) { - switch (out.dtype()) { - case bool_: - ternary_op(a, b, c, out, op); - break; - case uint8: - ternary_op(a, b, c, out, op); - break; - case uint16: - ternary_op(a, b, c, out, op); - break; - case uint32: - ternary_op(a, b, c, out, op); - break; - case uint64: - ternary_op(a, b, c, out, op); - break; - 
case int8: - ternary_op(a, b, c, out, op); - break; - case int16: - ternary_op(a, b, c, out, op); - break; - case int32: - ternary_op(a, b, c, out, op); - break; - case int64: - ternary_op(a, b, c, out, op); - break; - case float16: - ternary_op(a, b, c, out, op); - break; - case float32: - ternary_op(a, b, c, out, op); - break; - case float64: - ternary_op(a, b, c, out, op); - break; - case bfloat16: - ternary_op(a, b, c, out, op); - break; - case complex64: - ternary_op(a, b, c, out, op); - break; - } + Op op, + Stream stream) { + TernaryOpType topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(c); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + c = array::unsafe_weak_copy(c), + out = array::unsafe_weak_copy(out), + op, + topt]() mutable { + switch (out.dtype()) { + case bool_: + ternary_op(a, b, c, out, op, topt); + break; + case uint8: + ternary_op(a, b, c, out, op, topt); + break; + case uint16: + ternary_op(a, b, c, out, op, topt); + break; + case uint32: + ternary_op(a, b, c, out, op, topt); + break; + case uint64: + ternary_op(a, b, c, out, op, topt); + break; + case int8: + ternary_op(a, b, c, out, op, topt); + break; + case int16: + ternary_op(a, b, c, out, op, topt); + break; + case int32: + ternary_op(a, b, c, out, op, topt); + break; + case int64: + ternary_op(a, b, c, out, op, topt); + break; + case float16: + ternary_op( + a, b, c, out, op, topt); + break; + case float32: + ternary_op(a, b, c, out, op, topt); + break; + case float64: + ternary_op(a, b, c, out, op, topt); + break; + case bfloat16: + ternary_op( + a, b, c, out, op, topt); + break; + case complex64: + ternary_op( + a, b, c, out, op, topt); + break; + } + }); } } // namespace @@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector& inputs, array& out) { const auto& condition = inputs[0]; const auto& a = inputs[1]; const auto& b = inputs[2]; - select_op(condition, a, b, out, detail::Select()); + select_op(condition, a, b, out, detail::Select(), stream()); } } // namespace mlx::core diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index f66e7362a6..4439df61bc 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -105,15 +105,11 @@ struct StridedIterator { }; template -void sort(const array& in, array& out, int axis, Stream stream) { - // Copy input to output - CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; - copy(in, out, ctype, stream); - +void sort(array& out, int axis) { // Get axis, shape and stride info - axis = axis < 0 ? axis + in.ndim() : axis; - size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); - size_t n_rows = in_size / in.shape(axis); + axis = axis < 0 ? 
axis + out.ndim() : axis; + size_t in_size = out.size(); + size_t n_rows = in_size / out.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); @@ -127,30 +123,20 @@ void sort(const array& in, array& out, int axis, Stream stream) { // Perform sorting in place ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_output_array(out); - encoder.dispatch([out_ptr = out.data(), - src_it = std::move(src_it), - n_rows, - axis_size, - axis_stride]() mutable { - for (int i = 0; i < n_rows; i++) { - T* data_ptr = out_ptr + src_it.loc; + auto out_ptr = out.data(); + for (int i = 0; i < n_rows; i++) { + T* data_ptr = out_ptr + src_it.loc; - StridedIterator st(data_ptr, axis_stride, 0); - StridedIterator ed(data_ptr, axis_stride, axis_size); + StridedIterator st(data_ptr, axis_stride, 0); + StridedIterator ed(data_ptr, axis_stride, axis_size); - std::stable_sort(st, ed); - src_it.step(); - } - }); + std::stable_sort(st, ed); + src_it.step(); + } } template -void argsort(const array& in, array& out, int axis, Stream stream) { - // Allocate output - out.set_data(allocator::malloc_or_wait(out.nbytes())); - +void argsort(const array& in, array& out, int axis) { // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); @@ -176,99 +162,69 @@ void argsort(const array& in, array& out, int axis, Stream stream) { in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(in); - encoder.set_input_array(out); - encoder.dispatch([in_ptr = in.data(), - out_ptr = out.data(), - in_it = std::move(in_it), - out_it = std::move(out_it), - n_rows, - axis_size, - in_stride, - out_stride]() mutable { - for (int i = 0; i < n_rows; i++) { - const T* data_ptr = in_ptr + in_it.loc; - IdxT* idx_ptr = out_ptr + out_it.loc; + auto in_ptr = in.data(); + auto out_ptr = out.data(); + for (int i = 0; i < n_rows; i++) { + const T* data_ptr = in_ptr + in_it.loc; + IdxT* idx_ptr = out_ptr + out_it.loc; - in_it.step(); - out_it.step(); + in_it.step(); + out_it.step(); - StridedIterator st_(idx_ptr, out_stride, 0); - StridedIterator ed_(idx_ptr, out_stride, axis_size); + StridedIterator st_(idx_ptr, out_stride, 0); + StridedIterator ed_(idx_ptr, out_stride, axis_size); - // Initialize with iota - std::iota(st_, ed_, IdxT(0)); + // Initialize with iota + std::iota(st_, ed_, IdxT(0)); - // Sort according to vals - StridedIterator st(idx_ptr, out_stride, 0); - StridedIterator ed(idx_ptr, out_stride, axis_size); + // Sort according to vals + StridedIterator st(idx_ptr, out_stride, 0); + StridedIterator ed(idx_ptr, out_stride, axis_size); - std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { - auto v1 = data_ptr[a * in_stride]; - auto v2 = data_ptr[b * in_stride]; - return v1 < v2 || (v1 == v2 && a < b); - }); - } - }); + std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { + auto v1 = data_ptr[a * in_stride]; + auto v2 = data_ptr[b * in_stride]; + return v1 < v2 || (v1 == v2 && a < b); + }); + } } template -void partition(const array& in, array& out, int axis, int kth, Stream stream) { - // Copy input to output - CopyType ctype = in.flags().contiguous ? 
CopyType::Vector : CopyType::General; - copy(in, out, ctype, stream); - +void partition(array& out, int axis, int kth) { // Get axis, shape and stride info - axis = axis < 0 ? axis + in.ndim() : axis; - size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); - size_t n_rows = in_size / in.shape(axis); + axis = axis < 0 ? axis + out.ndim() : axis; + size_t in_size = out.size(); + size_t n_rows = in_size / out.shape(axis); - auto remaining_shape = in.shape(); + auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); - auto remaining_strides = in.strides(); + auto remaining_strides = out.strides(); remaining_strides.erase(remaining_strides.begin() + axis); - auto axis_stride = in.strides()[axis]; - int axis_size = in.shape(axis); + auto axis_stride = out.strides()[axis]; + int axis_size = out.shape(axis); kth = kth < 0 ? kth + axis_size : kth; // Perform partition in place ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_output_array(out); - encoder.dispatch([out_ptr = out.data(), - src_it = std::move(src_it), - n_rows, - axis_size, - axis_stride, - kth]() mutable { - for (int i = 0; i < n_rows; i++) { - T* data_ptr = out_ptr + src_it.loc; - src_it.step(); + auto out_ptr = out.data(); + for (int i = 0; i < n_rows; i++) { + T* data_ptr = out_ptr + src_it.loc; + src_it.step(); - StridedIterator st(data_ptr, axis_stride, 0); - StridedIterator md(data_ptr, axis_stride, kth); - StridedIterator ed(data_ptr, axis_stride, axis_size); + StridedIterator st(data_ptr, axis_stride, 0); + StridedIterator md(data_ptr, axis_stride, kth); + StridedIterator ed(data_ptr, axis_stride, axis_size); - std::nth_element(st, md, ed); - } - }); + std::nth_element(st, md, ed); + } } template -void argpartition( - const array& in, - array& out, - int axis, - int kth, - Stream stream) { - // Allocate output - out.set_data(allocator::malloc_or_wait(out.nbytes())); - +void argpartition(const array& in, array& out, int axis, int kth) { // Get axis, shape and stride info axis = axis < 0 ? 
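// --- Illustrative sketch (not MLX code) of the in-place partition: a
// negative kth wraps around like a Python index, then std::nth_element
// places the kth-smallest value at position kth within each row (one
// contiguous row shown; the diff walks strided rows via StridedIterator).
#include <algorithm>
#include <cstdio>

int main() {
  float row[5] = {4.f, 0.f, 3.f, 1.f, 2.f};
  const int axis_size = 5;
  int kth = -2;                          // i.e. the second-largest
  kth = kth < 0 ? kth + axis_size : kth; // -> 3
  std::nth_element(row, row + kth, row + axis_size);
  std::printf("%g\n", row[kth]); // 3
}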
axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); @@ -297,42 +253,32 @@ void argpartition( ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); - auto& encoder = cpu::get_command_encoder(stream); - encoder.set_input_array(in); - encoder.set_input_array(out); - encoder.dispatch([in_ptr = in.data(), - out_ptr = out.data(), - in_it = std::move(in_it), - out_it = std::move(out_it), - n_rows, - axis_size, - in_stride, - out_stride, - kth]() mutable { - for (int i = 0; i < n_rows; i++) { - const T* data_ptr = in_ptr + in_it.loc; - IdxT* idx_ptr = out_ptr + out_it.loc; - in_it.step(); - out_it.step(); + auto in_ptr = in.data(); + auto out_ptr = out.data(); - StridedIterator st_(idx_ptr, out_stride, 0); - StridedIterator ed_(idx_ptr, out_stride, axis_size); + for (int i = 0; i < n_rows; i++) { + const T* data_ptr = in_ptr + in_it.loc; + IdxT* idx_ptr = out_ptr + out_it.loc; + in_it.step(); + out_it.step(); - // Initialize with iota - std::iota(st_, ed_, IdxT(0)); + StridedIterator st_(idx_ptr, out_stride, 0); + StridedIterator ed_(idx_ptr, out_stride, axis_size); - // Sort according to vals - StridedIterator st(idx_ptr, out_stride, 0); - StridedIterator md(idx_ptr, out_stride, kth); - StridedIterator ed(idx_ptr, out_stride, axis_size); + // Initialize with iota + std::iota(st_, ed_, IdxT(0)); - std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { - auto v1 = data_ptr[a * in_stride]; - auto v2 = data_ptr[b * in_stride]; - return v1 < v2 || (v1 == v2 && a < b); - }); - } - }); + // Sort according to vals + StridedIterator st(idx_ptr, out_stride, 0); + StridedIterator md(idx_ptr, out_stride, kth); + StridedIterator ed(idx_ptr, out_stride, axis_size); + + std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { + auto v1 = data_ptr[a * in_stride]; + auto v2 = data_ptr[b * in_stride]; + return v1 < v2 || (v1 == v2 && a < b); + }); + } } } // namespace @@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - switch (in.dtype()) { - case bool_: - return argsort(in, out, axis_, stream()); - case uint8: - return argsort(in, out, axis_, stream()); - case uint16: - return argsort(in, out, axis_, stream()); - case uint32: - return argsort(in, out, axis_, stream()); - case uint64: - return argsort(in, out, axis_, stream()); - case int8: - return argsort(in, out, axis_, stream()); - case int16: - return argsort(in, out, axis_, stream()); - case int32: - return argsort(in, out, axis_, stream()); - case int64: - return argsort(in, out, axis_, stream()); - case float32: - return argsort(in, out, axis_, stream()); - case float64: - return argsort(in, out, axis_, stream()); - case float16: - return argsort(in, out, axis_, stream()); - case bfloat16: - return argsort(in, out, axis_, stream()); - case complex64: - return argsort(in, out, axis_, stream()); - } + // Allocate output + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_input_array(in); + encoder.set_input_array(out); + encoder.dispatch([in = array::unsafe_weak_copy(in), + out = array::unsafe_weak_copy(out), + axis_ = axis_]() mutable { + switch (in.dtype()) { + case bool_: + return argsort(in, out, axis_); + case uint8: + return argsort(in, out, axis_); + case uint16: + return argsort(in, out, axis_); + case uint32: + return argsort(in, out, axis_); + case uint64: + return argsort(in, out, axis_); + case 
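// --- Illustrative sketch (not MLX code) of argpartition: as with argsort,
// the indices start as iota, but std::nth_element only guarantees that the
// index of the kth-smallest value lands at position kth; ties are again
// broken by original position.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>

int main() {
  const float vals[5] = {4.f, 0.f, 3.f, 0.f, 2.f};
  uint32_t idx[5];
  std::iota(idx, idx + 5, 0u);
  const int kth = 1;
  std::nth_element(idx, idx + kth, idx + 5, [&](uint32_t a, uint32_t b) {
    auto v1 = vals[a], v2 = vals[b];
    return v1 < v2 || (v1 == v2 && a < b);
  });
  std::printf("%u\n", idx[kth]); // 3: the second zero, by original order
}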
int8: + return argsort(in, out, axis_); + case int16: + return argsort(in, out, axis_); + case int32: + return argsort(in, out, axis_); + case int64: + return argsort(in, out, axis_); + case float32: + return argsort(in, out, axis_); + case float64: + return argsort(in, out, axis_); + case float16: + return argsort(in, out, axis_); + case bfloat16: + return argsort(in, out, axis_); + case complex64: + return argsort(in, out, axis_); + } + }); } void Sort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - switch (in.dtype()) { - case bool_: - return sort(in, out, axis_, stream()); - case uint8: - return sort(in, out, axis_, stream()); - case uint16: - return sort(in, out, axis_, stream()); - case uint32: - return sort(in, out, axis_, stream()); - case uint64: - return sort(in, out, axis_, stream()); - case int8: - return sort(in, out, axis_, stream()); - case int16: - return sort(in, out, axis_, stream()); - case int32: - return sort(in, out, axis_, stream()); - case int64: - return sort(in, out, axis_, stream()); - case float32: - return sort(in, out, axis_, stream()); - case float64: - return sort(in, out, axis_, stream()); - case float16: - return sort(in, out, axis_, stream()); - case bfloat16: - return sort(in, out, axis_, stream()); - case complex64: - return sort(in, out, axis_, stream()); - } + // Copy input to output + CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy(in, out, ctype, stream()); + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_output_array(out); + encoder.dispatch( + [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable { + switch (out.dtype()) { + case bool_: + return sort(out, axis_); + case uint8: + return sort(out, axis_); + case uint16: + return sort(out, axis_); + case uint32: + return sort(out, axis_); + case uint64: + return sort(out, axis_); + case int8: + return sort(out, axis_); + case int16: + return sort(out, axis_); + case int32: + return sort(out, axis_); + case int64: + return sort(out, axis_); + case float32: + return sort(out, axis_); + case float64: + return sort(out, axis_); + case float16: + return sort(out, axis_); + case bfloat16: + return sort(out, axis_); + case complex64: + return sort(out, axis_); + } + }); } void ArgPartition::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - switch (in.dtype()) { - case bool_: - return argpartition(in, out, axis_, kth_, stream()); - case uint8: - return argpartition(in, out, axis_, kth_, stream()); - case uint16: - return argpartition(in, out, axis_, kth_, stream()); - case uint32: - return argpartition(in, out, axis_, kth_, stream()); - case uint64: - return argpartition(in, out, axis_, kth_, stream()); - case int8: - return argpartition(in, out, axis_, kth_, stream()); - case int16: - return argpartition(in, out, axis_, kth_, stream()); - case int32: - return argpartition(in, out, axis_, kth_, stream()); - case int64: - return argpartition(in, out, axis_, kth_, stream()); - case float32: - return argpartition(in, out, axis_, kth_, stream()); - case float64: - return argpartition(in, out, axis_, kth_, stream()); - case float16: - return argpartition(in, out, axis_, kth_, stream()); - case bfloat16: - return argpartition(in, out, axis_, kth_, stream()); - case complex64: - return argpartition(in, out, axis_, kth_, stream()); - } + // Allocate output + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& encoder = 
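// --- Illustrative sketch (not MLX code) of the dtype-switch pattern above:
// a single runtime switch picks the template instantiation once, so the
// inner loops run fully typed. Dtype and dispatch_sort are invented names
// for exposition.
#include <algorithm>
#include <cstdint>
#include <cstdio>

enum class Dtype { int32, float32 };

template <typename T>
void sort_row(void* p, int n) {
  T* row = static_cast<T*>(p);
  std::stable_sort(row, row + n);
}

void dispatch_sort(Dtype dt, void* p, int n) {
  switch (dt) {
    case Dtype::int32:
      return sort_row<int32_t>(p, n);
    case Dtype::float32:
      return sort_row<float>(p, n);
  }
}

int main() {
  float v[3] = {2.f, 0.f, 1.f};
  dispatch_sort(Dtype::float32, v, 3);
  std::printf("%g %g %g\n", v[0], v[1], v[2]); // 0 1 2
}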
cpu::get_command_encoder(stream()); + encoder.set_input_array(in); + encoder.set_input_array(out); + encoder.dispatch([in = array::unsafe_weak_copy(in), + out = array::unsafe_weak_copy(out), + axis_ = axis_, + kth_ = kth_]() mutable { + switch (in.dtype()) { + case bool_: + return argpartition(in, out, axis_, kth_); + case uint8: + return argpartition(in, out, axis_, kth_); + case uint16: + return argpartition(in, out, axis_, kth_); + case uint32: + return argpartition(in, out, axis_, kth_); + case uint64: + return argpartition(in, out, axis_, kth_); + case int8: + return argpartition(in, out, axis_, kth_); + case int16: + return argpartition(in, out, axis_, kth_); + case int32: + return argpartition(in, out, axis_, kth_); + case int64: + return argpartition(in, out, axis_, kth_); + case float32: + return argpartition(in, out, axis_, kth_); + case float64: + return argpartition(in, out, axis_, kth_); + case float16: + return argpartition(in, out, axis_, kth_); + case bfloat16: + return argpartition(in, out, axis_, kth_); + case complex64: + return argpartition(in, out, axis_, kth_); + } + }); } void Partition::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - switch (in.dtype()) { - case bool_: - return partition(in, out, axis_, kth_, stream()); - case uint8: - return partition(in, out, axis_, kth_, stream()); - case uint16: - return partition(in, out, axis_, kth_, stream()); - case uint32: - return partition(in, out, axis_, kth_, stream()); - case uint64: - return partition(in, out, axis_, kth_, stream()); - case int8: - return partition(in, out, axis_, kth_, stream()); - case int16: - return partition(in, out, axis_, kth_, stream()); - case int32: - return partition(in, out, axis_, kth_, stream()); - case int64: - return partition(in, out, axis_, kth_, stream()); - case float32: - return partition(in, out, axis_, kth_, stream()); - case float64: - return partition(in, out, axis_, kth_, stream()); - case float16: - return partition(in, out, axis_, kth_, stream()); - case bfloat16: - return partition(in, out, axis_, kth_, stream()); - case complex64: - return partition(in, out, axis_, kth_, stream()); - } + // Copy input to output + CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy(in, out, ctype, stream()); + + auto& encoder = cpu::get_command_encoder(stream()); + encoder.set_output_array(out); + encoder.dispatch([out = array::unsafe_weak_copy(out), + axis_ = axis_, + kth_ = kth_]() mutable { + switch (out.dtype()) { + case bool_: + return partition(out, axis_, kth_); + case uint8: + return partition(out, axis_, kth_); + case uint16: + return partition(out, axis_, kth_); + case uint32: + return partition(out, axis_, kth_); + case uint64: + return partition(out, axis_, kth_); + case int8: + return partition(out, axis_, kth_); + case int16: + return partition(out, axis_, kth_); + case int32: + return partition(out, axis_, kth_); + case int64: + return partition(out, axis_, kth_); + case float32: + return partition(out, axis_, kth_); + case float64: + return partition(out, axis_, kth_); + case float16: + return partition(out, axis_, kth_); + case bfloat16: + return partition(out, axis_, kth_); + case complex64: + return partition(out, axis_, kth_); + } + }); } } // namespace mlx::core diff --git a/mlx/backend/cpu/ternary.h b/mlx/backend/cpu/ternary.h index 7b89c7d748..a27a7f2a9f 100644 --- a/mlx/backend/cpu/ternary.h +++ b/mlx/backend/cpu/ternary.h @@ -1,12 +1,10 @@ // Copyright © 2023 Apple Inc. 
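// --- Illustrative sketch (not MLX code) of the copy-then-mutate flow in
// Sort/Partition::eval_cpu: the input is first copied into the output
// (CopyType::Vector when contiguous), and the kernel afterwards only
// touches `out`, so `in` is never modified.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> in = {3.f, 1.f, 2.f};
  std::vector<float> out(in);               // copy(in, out, ctype, stream())
  std::stable_sort(out.begin(), out.end()); // sort in place on the copy
  std::printf("in[0]=%g out[0]=%g\n", in[0], out[0]); // 3 1
}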
#pragma once -#include "mlx/allocator.h" #include "mlx/array.h" #include "mlx/backend/common/ternary.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" -#include "mlx/primitives.h" namespace mlx::core { @@ -128,57 +126,28 @@ void ternary_op( const array& b, const array& c, array& out, - Op op) { - TernaryOpType topt = get_ternary_op_type(a, b, c); - set_ternary_op_output_data(a, b, c, out, topt); - - auto& encoder = cpu::get_command_encoder(out.primitive().stream()); - encoder.set_input_array(a); - encoder.set_input_array(b); - encoder.set_input_array(c); - encoder.set_output_array(out); - + Op op, + TernaryOpType topt) { const T1* a_ptr = a.data(); const T2* b_ptr = b.data(); const T3* c_ptr = c.data(); U* out_ptr = out.data(); if (topt == TernaryOpType::ScalarScalarScalar) { - encoder.dispatch( - [a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable { - *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); - }); + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); } else if (topt == TernaryOpType::VectorVectorVector) { - encoder.dispatch([a_ptr, - b_ptr, - c_ptr, - out_ptr, - op = std::move(op), - size = out.size()]() mutable { - for (size_t i = 0; i < size; ++i) { - *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); - a_ptr++; - b_ptr++; - c_ptr++; - out_ptr++; - } - }); + for (size_t i = 0; i < out.size(); ++i) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + a_ptr++; + b_ptr++; + c_ptr++; + out_ptr++; + } } else { auto [shape, strides] = collapse_contiguous_dims( a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); - encoder.dispatch( - - [a_ptr, - b_ptr, - c_ptr, - out_ptr, - op = std::move(op), - size = out.size(), - shape = std::move(shape), - strides = std::move(strides)]() mutable { - ternary_op_dispatch_dims( - a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides); - }); + ternary_op_dispatch_dims( + a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides); } } diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp index 1b16be3950..89d1cafb3d 100644 --- a/mlx/backend/cpu/unary.cpp +++ b/mlx/backend/cpu/unary.cpp @@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector& inputs, array& out) { // No-op for unsigned types out.copy_shared_buffer(in); } else { - auto op = detail::Abs{}; - switch (out.dtype()) { - case int8: - unary_op(in, out, op); - break; - case int16: - unary_op(in, out, op); - break; - case int32: - unary_op(in, out, op); - break; - case int64: - unary_op(in, out, op); - break; - case float16: - unary_op(in, out, op); - break; - case float32: - unary_op(in, out, op); - break; - case float64: - unary_op(in, out, op); - break; - case bfloat16: - unary_op(in, out, op); - break; - case complex64: - unary_op(in, out, op); - break; - default: - throw std::runtime_error("[Abs] Called on unsigned type"); - } + unary_signed(in, out, detail::Abs(), stream()); } } void ArcCos::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::ArcCos()); + unary_fp(in, out, detail::ArcCos(), stream()); } void ArcCosh::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::ArcCosh()); + unary_fp(in, out, detail::ArcCosh(), stream()); } void ArcSin::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::ArcSin()); + unary_fp(in, out, detail::ArcSin(), stream()); } void ArcSinh::eval_cpu(const std::vector& inputs, array& out) { 
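// --- Illustrative sketch (not MLX code) of the two ternary fast paths
// above: when every operand is a scalar exactly one element is written;
// when all are contiguous vectors a flat loop runs; anything else falls
// back to the strided path built on collapse_contiguous_dims in the diff.
#include <cstddef>
#include <cstdio>

enum class TOpt { ScalarScalarScalar, VectorVectorVector };

template <typename Op>
void ternary(const bool* a, const float* b, const float* c, float* out,
             size_t size, TOpt topt, Op op) {
  if (topt == TOpt::ScalarScalarScalar) {
    *out = op(*a, *b, *c);
  } else { // VectorVectorVector
    for (size_t i = 0; i < size; ++i) {
      out[i] = op(a[i], b[i], c[i]);
    }
  }
}

int main() {
  bool cond[3] = {true, false, true};
  float x[3] = {1, 2, 3}, y[3] = {9, 8, 7}, out[3];
  ternary(cond, x, y, out, 3, TOpt::VectorVectorVector,
          [](bool c, float a, float b) { return c ? a : b; }); // like detail::Select
  std::printf("%g %g %g\n", out[0], out[1], out[2]); // 1 8 3
}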
assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::ArcSinh()); + unary_fp(in, out, detail::ArcSinh(), stream()); } void ArcTan::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::ArcTan()); + unary_fp(in, out, detail::ArcTan(), stream()); } void ArcTanh::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::ArcTanh()); + unary_fp(in, out, detail::ArcTanh(), stream()); } void BitwiseInvert::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_int(in, out, detail::BitwiseInvert()); + unary_int(in, out, detail::BitwiseInvert(), stream()); } void Ceil::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (issubdtype(in.dtype(), inexact)) { - unary_fp(in, out, detail::Ceil()); + unary_fp(in, out, detail::Ceil(), stream()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector& inputs, array& out) { void Conjugate::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - unary_op(inputs[0], out, detail::Conjugate()); + unary_complex(inputs[0], out, detail::Conjugate(), stream()); } void Cos::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Cos()); + unary_fp(in, out, detail::Cos(), stream()); } void Cosh::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Cosh()); + unary_fp(in, out, detail::Cosh(), stream()); } void Erf::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - switch (out.dtype()) { - case float32: - unary_op(in, out, detail::Erf()); - break; - case float16: - unary_op(in, out, detail::Erf()); - break; - case float64: - unary_op(in, out, detail::Erf()); - break; - case bfloat16: - unary_op(in, out, detail::Erf()); - break; - default: - throw std::invalid_argument( - "[erf] Error function only defined for arrays" - " with real floating point type."); - } + unary_real_fp(in, out, detail::Erf(), stream()); } void ErfInv::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - switch (out.dtype()) { - case float32: - unary_op(in, out, detail::ErfInv()); - break; - case float16: - unary_op(in, out, detail::ErfInv()); - break; - case float64: - unary_op(in, out, detail::ErfInv()); - break; - case bfloat16: - unary_op(in, out, detail::ErfInv()); - break; - default: - throw std::invalid_argument( - "[erf_inv] Inverse error function only defined for arrays" - " with real floating point type."); - } + unary_real_fp(in, out, detail::ErfInv(), stream()); } void Exp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Exp()); + unary_fp(in, out, detail::Exp(), stream()); } void Expm1::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Expm1()); + unary_fp(in, out, detail::Expm1(), stream()); } void Floor::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (issubdtype(in.dtype(), inexact)) { - unary_fp(in, out, 
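// --- Illustrative sketch (not MLX code) of the no-op fast path used by
// Abs-on-unsigned and Ceil/Floor/Round-on-integers: integral dtypes skip
// the kernel entirely and alias the input buffer, analogous to
// out.copy_shared_buffer(in); only inexact dtypes run the op.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <type_traits>
#include <vector>

template <typename T>
std::shared_ptr<std::vector<T>> ceil_eval(
    const std::shared_ptr<std::vector<T>>& in) {
  if constexpr (std::is_integral_v<T>) {
    return in; // share the buffer, no work
  } else {
    auto out = std::make_shared<std::vector<T>>(in->size());
    for (size_t i = 0; i < in->size(); ++i) {
      (*out)[i] = std::ceil((*in)[i]);
    }
    return out;
  }
}

int main() {
  auto f = std::make_shared<std::vector<float>>(1, 1.2f);
  auto i = std::make_shared<std::vector<int>>(1, 7);
  std::printf("%g %d %d\n", (*ceil_eval(f))[0], (*ceil_eval(i))[0],
              int(ceil_eval(i) == i)); // 2 7 1
}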
detail::Floor()); + unary_fp(in, out, detail::Floor(), stream()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector& inputs, array& out) { } void Imag::eval_cpu(const std::vector& inputs, array& out) { - unary_op(inputs[0], out, detail::Imag()); + unary_complex_to_float(inputs[0], out, detail::Imag(), stream()); } void Log::eval_cpu(const std::vector& inputs, array& out) { @@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (base_) { case Base::e: - unary_fp(in, out, detail::Log()); + unary_fp(in, out, detail::Log(), stream()); break; case Base::two: - unary_fp(in, out, detail::Log2()); + unary_fp(in, out, detail::Log2(), stream()); break; case Base::ten: - unary_fp(in, out, detail::Log10()); + unary_fp(in, out, detail::Log10(), stream()); break; } } @@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector& inputs, array& out) { void Log1p::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Log1p()); + unary_fp(in, out, detail::Log1p(), stream()); } void LogicalNot::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, detail::LogicalNot()); + unary(in, out, detail::LogicalNot(), stream()); } void Negative::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, detail::Negative()); + unary(in, out, detail::Negative(), stream()); } void Real::eval_cpu(const std::vector& inputs, array& out) { - unary_op(inputs[0], out, detail::Real()); + unary_complex_to_float(inputs[0], out, detail::Real(), stream()); } void Round::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (issubdtype(in.dtype(), inexact)) { - unary_fp(in, out, detail::Round()); + unary_fp(in, out, detail::Round(), stream()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector& inputs, array& out) { void Sigmoid::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Sigmoid()); + unary_fp(in, out, detail::Sigmoid(), stream()); } void Sign::eval_cpu(const std::vector& inputs, array& out) { @@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector& inputs, array& out) { if (in.dtype() == bool_) { out.copy_shared_buffer(in); } else { - unary(in, out, detail::Sign()); + unary(in, out, detail::Sign(), stream()); } } void Sin::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Sin()); + unary_fp(in, out, detail::Sin(), stream()); } void Sinh::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Sinh()); + unary_fp(in, out, detail::Sinh(), stream()); } void Square::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, detail::Square()); + unary(in, out, detail::Square(), stream()); } void Sqrt::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (recip_) { - unary_fp(in, out, detail::Rsqrt()); + unary_fp(in, out, detail::Rsqrt(), stream()); } else { - unary_fp(in, out, detail::Sqrt()); + unary_fp(in, out, detail::Sqrt(), stream()); 
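// --- Illustrative sketch (not MLX code) of unary_complex_to_float, the
// helper this diff introduces for Real/Imag: a complex64 input maps
// elementwise to a float32 output.
#include <complex>
#include <cstddef>
#include <cstdio>

template <typename Op>
void unary_c2f(const std::complex<float>* in, float* out, size_t n, Op op) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = op(in[i]);
  }
}

int main() {
  const std::complex<float> in[2] = {{1.f, 2.f}, {3.f, 4.f}};
  float out[2];
  unary_c2f(in, out, 2, [](const std::complex<float>& z) {
    return z.imag(); // like detail::Imag; detail::Real would use z.real()
  });
  std::printf("%g %g\n", out[0], out[1]); // 2 4
}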
} } void Tan::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Tan()); + unary_fp(in, out, detail::Tan(), stream()); } void Tanh::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - unary_fp(in, out, detail::Tanh()); + unary_fp(in, out, detail::Tanh(), stream()); } } // namespace mlx::core diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index 4769ecd06c..a12bbbf009 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -7,7 +7,6 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" -#include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { @@ -39,156 +38,263 @@ void unary_op(const T* a, U* out, size_t shape, size_t stride) { template void unary_op(const array& a, array& out, Op) { - set_unary_output_data(a, out); const T* src = a.data(); U* dst = out.data(); - auto& encoder = cpu::get_command_encoder(out.primitive().stream()); + auto ndim = a.ndim(); + if (a.flags().contiguous) { + auto size = a.data_size(); + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::load(src))); + size -= N; + src += N; + dst += N; + } + while (size > 0) { + *dst = Op{}(*src); + size--; + dst++; + src++; + } + } else { + size_t shape = ndim > 0 ? a.shape().back() : 1; + size_t stride = ndim > 0 ? a.strides().back() : 1; + if (ndim <= 1) { + unary_op(src, dst, shape, stride); + return; + } + auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1); + for (size_t elem = 0; elem < a.size(); elem += shape) { + unary_op(src + it.loc, dst + elem, shape, stride); + it.step(); + } + } +} + +template +void unary(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); encoder.set_input_array(a); encoder.set_output_array(out); - - encoder.dispatch([src, - dst, - contig = a.flags().contiguous, - data_size = a.data_size(), - size = a.size(), - shapes = a.shape(), - strides = a.strides()]() mutable { - auto ndim = shapes.size(); - if (contig) { - constexpr int N = simd::max_size; - while (data_size >= N) { - simd::store(dst, Op{}(simd::load(src))); - data_size -= N; - src += N; - dst += N; - } - while (data_size > 0) { - *dst = Op{}(*src); - data_size--; - dst++; - src++; - } - } else { - size_t shape = ndim > 0 ? shapes.back() : 1; - size_t stride = ndim > 0 ? 
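// --- Illustrative sketch (not MLX code) of the contiguous fast path below:
// the array is consumed in fixed-width chunks (a scalar loop here stands in
// for simd::load/simd::store at the SIMD register width), and a scalar tail
// loop finishes the remainder.
#include <cstddef>
#include <cstdio>

template <typename T, typename Op>
void unary_contig(const T* src, T* dst, size_t size, Op op) {
  constexpr size_t N = 8; // stand-in for the SIMD width
  while (size >= N) {
    for (size_t i = 0; i < N; ++i) {
      dst[i] = op(src[i]); // simd::store(dst, Op{}(simd::load(src)))
    }
    size -= N;
    src += N;
    dst += N;
  }
  while (size > 0) { // scalar tail
    *dst++ = op(*src++);
    size--;
  }
}

int main() {
  float in[10], out[10];
  for (int i = 0; i < 10; ++i) in[i] = float(i) - 5.f;
  unary_contig(in, out, 10, [](float x) { return x < 0 ? -x : x; });
  std::printf("%g %g\n", out[0], out[9]); // 5 4
}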
strides.back() : 1; - if (ndim <= 1) { - unary_op(src, dst, shape, stride); - return; - } - auto it = ContiguousIterator(shapes, strides, ndim - 1); - for (size_t elem = 0; elem < size; elem += shape) { - unary_op(src + it.loc, dst + elem, shape, stride); - it.step(); - } + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bool_: + unary_op(a, out, op); + break; + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; } }); } template -void unary(const array& a, array& out, Op op) { - switch (out.dtype()) { - case bool_: - unary_op(a, out, op); - break; - case uint8: - unary_op(a, out, op); - break; - case uint16: - unary_op(a, out, op); - break; - case uint32: - unary_op(a, out, op); - break; - case uint64: - unary_op(a, out, op); - break; - case int8: - unary_op(a, out, op); - break; - case int16: - unary_op(a, out, op); - break; - case int32: - unary_op(a, out, op); - break; - case int64: - unary_op(a, out, op); - break; - case float16: - unary_op(a, out, op); - break; - case float32: - unary_op(a, out, op); - break; - case float64: - unary_op(a, out, op); - break; - case bfloat16: - unary_op(a, out, op); - break; - case complex64: - unary_op(a, out, op); - break; - } +void unary_real_fp(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_real] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} +template +void unary_fp(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_fp] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); } template -void unary_fp(const array& a, array& out, Op op) { - switch (out.dtype()) { - case bfloat16: - unary_op(a, out, op); - break; - case float16: - unary_op(a, out, op); - break; - case float32: - unary_op(a, out, op); - break; - case 
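// --- Illustrative sketch (not MLX code) of the strided fallback above: the
// inner loop walks the last axis with its element stride while an outer
// iterator (ContiguousIterator in the diff) yields each row's offset; a
// plain row-stride computation stands in for that iterator here.
#include <cstddef>
#include <cstdio>

template <typename T, typename Op>
void unary_strided(const T* src, T* dst, size_t rows, size_t shape,
                   size_t row_stride, size_t stride, Op op) {
  for (size_t r = 0; r < rows; ++r) {
    const T* s = src + r * row_stride; // src + it.loc in the diff
    T* d = dst + r * shape;            // the output is written densely
    for (size_t i = 0; i < shape; ++i) {
      d[i] = op(s[i * stride]);
    }
  }
}

int main() {
  // A 2x3 view over a 2x6 buffer: row stride 6, element stride 2.
  float buf[12];
  for (int i = 0; i < 12; ++i) buf[i] = float(i);
  float out[6];
  unary_strided(buf, out, 2, 3, 6, 2, [](float x) { return x + 1.f; });
  std::printf("%g %g\n", out[0], out[5]); // 1 11
}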
float64: - unary_op(a, out, op); - break; - case complex64: - unary_op(a, out, op); - break; - default: - std::ostringstream err; - err << "[unary_fp] Does not support " << out.dtype(); - throw std::runtime_error(err.str()); - } +void unary_signed(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + throw std::runtime_error("[Abs] Called on unsigned type"); + } + }); } template -void unary_int(const array& a, array& out, Op op) { - switch (out.dtype()) { - case uint8: - unary_op(a, out, op); - break; - case uint16: - unary_op(a, out, op); - break; - case uint32: - unary_op(a, out, op); - break; - case uint64: - unary_op(a, out, op); - break; - case int8: - unary_op(a, out, op); - break; - case int16: - unary_op(a, out, op); - break; - case int32: - unary_op(a, out, op); - break; - case int64: - unary_op(a, out, op); - break; - default: - std::ostringstream err; - err << "[unary_int] Does not support " << out.dtype(); - throw std::runtime_error(err.str()); - } +void unary_complex(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { unary_op(a, out, op); }); +} + +template +void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch( + [a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { unary_op(a, out, op); }); +} + +template +void unary_int(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_int] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); } } // namespace mlx::core
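// --- Illustrative sketch (not MLX code) of the dtype-restricted dispatch
// that closes this diff (unary_int): only integer tags are instantiated,
// and any other dtype raises a runtime_error naming the helper, matching
// the default case above. Dtype and the tag values are invented names.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <sstream>
#include <stdexcept>

enum class Dtype { uint8, int32, float32 };

template <typename T>
void bitwise_invert(T* p, size_t n) {
  for (size_t i = 0; i < n; ++i) p[i] = ~p[i]; // like detail::BitwiseInvert
}

void unary_int(Dtype dt, void* p, size_t n) {
  switch (dt) {
    case Dtype::uint8:
      return bitwise_invert(static_cast<uint8_t*>(p), n);
    case Dtype::int32:
      return bitwise_invert(static_cast<int32_t*>(p), n);
    default: {
      std::ostringstream err;
      err << "[unary_int] Does not support dtype tag " << static_cast<int>(dt);
      throw std::runtime_error(err.str());
    }
  }
}

int main() {
  int32_t v[2] = {0, -1};
  unary_int(Dtype::int32, v, 2);
  std::printf("%d %d\n", v[0], v[1]); // -1 0
}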