mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 18:56:39 +08:00
reduce binary size (#1952)
This commit is contained in:
parent
117e1355a2
commit
736a340478
@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array array::unsafe_weak_copy(const array& other) {
|
||||
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
|
||||
cpy.set_data(
|
||||
other.buffer(),
|
||||
other.data_size(),
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
|
@ -199,6 +199,13 @@ class array {
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/**
|
||||
* Get a new array that refers to the same data as the input but with a
|
||||
* non-owning pointer to it. Note the array is detached from the graph and has
|
||||
* no inputs, siblings or primitive.
|
||||
*/
|
||||
static array unsafe_weak_copy(const array& other);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
std::uintptr_t id() const {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||
|
@ -11,12 +11,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename InT, typename OpT>
|
||||
void arg_reduce(
|
||||
const array& in,
|
||||
array& out,
|
||||
const OpT& op,
|
||||
int axis,
|
||||
Stream stream) {
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = in.strides();
|
||||
@ -26,18 +21,7 @@ void arg_reduce(
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
axis_size,
|
||||
axis_stride,
|
||||
op = std::move(op),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides),
|
||||
size = out.size()]() {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
uint32_t ind_v = 0;
|
||||
@ -47,7 +31,6 @@ void arg_reduce(
|
||||
}
|
||||
out_ptr[i] = ind_v;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
@ -55,8 +38,7 @@ void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
ArgReduce::ReduceType rtype,
|
||||
int axis,
|
||||
Stream stream) {
|
||||
int axis) {
|
||||
switch (rtype) {
|
||||
case ArgReduce::ArgMin: {
|
||||
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
|
||||
@ -65,7 +47,7 @@ void arg_reduce_dispatch(
|
||||
(*ind_y) = ind_x;
|
||||
}
|
||||
};
|
||||
arg_reduce<InT>(in, out, op, axis, stream);
|
||||
arg_reduce<InT>(in, out, op, axis);
|
||||
break;
|
||||
}
|
||||
case ArgReduce::ArgMax: {
|
||||
@ -75,7 +57,7 @@ void arg_reduce_dispatch(
|
||||
(*ind_y) = ind_x;
|
||||
}
|
||||
};
|
||||
arg_reduce<InT>(in, out, op, axis, stream);
|
||||
arg_reduce<InT>(in, out, op, axis);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -87,51 +69,58 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
reduce_type_ = reduce_type_,
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint8:
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int8:
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int16:
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int32:
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int64:
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float16:
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float32:
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case bfloat16:
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float64:
|
||||
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case complex64:
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/binary_two.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -16,51 +17,218 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out) {
|
||||
switch (a.dtype()) {
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out);
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out);
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out);
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out);
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out);
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out);
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out);
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out);
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out);
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out);
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out);
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out);
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out);
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out);
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports non-complex floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add());
|
||||
binary(a, b, out, detail::Add(), stream());
|
||||
}
|
||||
|
||||
void DivMod::eval_cpu(
|
||||
@ -78,70 +246,89 @@ void DivMod::eval_cpu(
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out_a = array::unsafe_weak_copy(out_a),
|
||||
out_b = array::unsafe_weak_copy(out_b),
|
||||
bopt]() mutable {
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
switch (outputs[0].dtype()) {
|
||||
|
||||
switch (out_a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, integral_op);
|
||||
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, integral_op);
|
||||
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, integral_op);
|
||||
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, integral_op);
|
||||
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, integral_op);
|
||||
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, integral_op);
|
||||
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, integral_op);
|
||||
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, integral_op);
|
||||
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, integral_op);
|
||||
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, float_op);
|
||||
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, float_op);
|
||||
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, outputs, float_op);
|
||||
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, float_op);
|
||||
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide());
|
||||
binary(a, b, out, detail::Divide(), stream());
|
||||
}
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder());
|
||||
binary(a, b, out, detail::Remainder(), stream());
|
||||
}
|
||||
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (equal_nan_) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
|
||||
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, detail::NaNEqual>(a, b, out);
|
||||
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, detail::NaNEqual>(a, b, out);
|
||||
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
|
||||
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
|
||||
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
comparison_op<detail::Equal>(a, b, out);
|
||||
comparison_op(a, b, out, detail::Equal(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op<detail::Greater>(inputs[0], inputs[1], out);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
}
|
||||
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op<detail::Less>(inputs[0], inputs[1], out);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
}
|
||||
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
}
|
||||
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
|
||||
}
|
||||
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
||||
}
|
||||
|
||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
binary(in1, in2, out, detail::LogicalOr(), stream());
|
||||
}
|
||||
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum());
|
||||
binary(a, b, out, detail::Maximum(), stream());
|
||||
}
|
||||
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum());
|
||||
binary(a, b, out, detail::Minimum(), stream());
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply());
|
||||
binary(a, b, out, detail::Multiply(), stream());
|
||||
}
|
||||
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power());
|
||||
binary(a, b, out, detail::Power(), stream());
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Subtract());
|
||||
binary(a, b, out, detail::Subtract(), stream());
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto dispatch_type = [&a, &b, &out](auto op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[BitwiseBinary::eval_cpu] Type not supported");
|
||||
break;
|
||||
}
|
||||
};
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
dispatch_type(detail::BitwiseAnd());
|
||||
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
dispatch_type(detail::BitwiseOr());
|
||||
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
dispatch_type(detail::BitwiseXor());
|
||||
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
dispatch_type(detail::LeftShift());
|
||||
binary_int(a, b, out, detail::LeftShift(), stream());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
dispatch_type(detail::RightShift());
|
||||
binary_int(a, b, out, detail::RightShift(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
|
||||
}
|
||||
binary_float(a, b, out, detail::ArcTan2(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3,12 +3,9 @@
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
@ -152,30 +149,12 @@ void binary_op_dispatch_dims(
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
auto a_ptr = a.data<T>();
|
||||
auto b_ptr = b.data<T>();
|
||||
|
||||
auto out_ptr = out.data<U>();
|
||||
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([bopt,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
a_data_size = a.data_size(),
|
||||
b_data_size = b.data_size(),
|
||||
size = a.size(),
|
||||
shape = a.shape(),
|
||||
a_strides = a.strides(),
|
||||
b_strides = b.strides(),
|
||||
strides = out.strides()]() mutable {
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*out_ptr = Op{}(*a_ptr, *b_ptr);
|
||||
return;
|
||||
@ -183,29 +162,28 @@ void binary_op(const array& a, const array& b, array& out) {
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
|
||||
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
|
||||
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
|
||||
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
shape,
|
||||
{std::move(a_strides), std::move(b_strides), std::move(strides)});
|
||||
a_strides = new_strides[0];
|
||||
b_strides = new_strides[1];
|
||||
strides = new_strides[2];
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
auto& a_strides = new_strides[0];
|
||||
auto& b_strides = new_strides[1];
|
||||
auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
@ -262,7 +240,7 @@ void binary_op(const array& a, const array& b, array& out) {
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
@ -274,7 +252,7 @@ void binary_op(const array& a, const array& b, array& out) {
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
@ -286,7 +264,7 @@ void binary_op(const array& a, const array& b, array& out) {
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
@ -298,72 +276,18 @@ void binary_op(const array& a, const array& b, array& out) {
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out) {
|
||||
binary_op<T, T, Op>(a, b, out);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
binary_op<T, T, Op>(a, b, out);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out);
|
||||
break;
|
||||
}
|
||||
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
binary_op<T, T, Op>(a, b, out, bopt);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -4,8 +4,6 @@
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -57,14 +55,7 @@ void binary_op_dispatch_dims(
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Stream stream,
|
||||
Op op) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
||||
const T* a_ptr = a.data<T>();
|
||||
@ -72,14 +63,6 @@ void binary_op_dispatch_dims(
|
||||
U* out_a_ptr = out_a.data<U>();
|
||||
U* out_b_ptr = out_b.data<U>();
|
||||
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = a.size(),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides),
|
||||
op = std::move(op)]() {
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& out_strides = strides[2];
|
||||
@ -116,7 +99,7 @@ void binary_op_dispatch_dims(
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < size; elem += stride) {
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
@ -131,138 +114,50 @@ void binary_op_dispatch_dims(
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
auto stream = out_a.primitive().stream();
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
BinaryOpType bopt) {
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::General) {
|
||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op);
|
||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
auto a_ptr = a.data<T>();
|
||||
auto b_ptr = b.data<T>();
|
||||
auto out_a_ptr = out_a.data<U>();
|
||||
auto out_b_ptr = out_b.data<U>();
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
encoder.dispatch(
|
||||
[a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
});
|
||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = b.size(),
|
||||
op = std::move(op)]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < b.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
});
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = a.size(),
|
||||
op = std::move(op)]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
}
|
||||
});
|
||||
} else { // VectorVector
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = a.size(),
|
||||
op = std::move(op)]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, outputs, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, outputs, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,29 +13,20 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_single(const array& src, array& dst, Stream stream) {
|
||||
void copy_single(const array& src, array& dst) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
|
||||
auto size = dst.size();
|
||||
auto val = static_cast<DstT>(src_ptr[0]);
|
||||
std::fill_n(dst_ptr, size, val);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst, Stream stream) {
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
size_t size = src.data_size();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
|
||||
auto size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + size, dst_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
@ -66,7 +57,6 @@ template <typename SrcT, typename DstT>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
Stream stream,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
@ -80,18 +70,7 @@ void copy_general_general(
|
||||
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
|
||||
auto o_offset_ptr =
|
||||
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr,
|
||||
dst_ptr,
|
||||
size = src.size(),
|
||||
data_shape = data_shape,
|
||||
i_strides = i_strides,
|
||||
o_strides = o_strides,
|
||||
i_offset_ptr,
|
||||
o_offset_ptr]() mutable {
|
||||
auto size = src.size();
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*src_ptr);
|
||||
*dst_ptr = val;
|
||||
@ -143,15 +122,13 @@ void copy_general_general(
|
||||
in.step();
|
||||
out.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst, Stream stream) {
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
dst.strides(),
|
||||
@ -165,7 +142,6 @@ template <typename SrcT, typename DstT>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
Stream stream,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides&,
|
||||
@ -176,7 +152,6 @@ void copy_general(
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
data_shape,
|
||||
i_strides,
|
||||
make_contiguous_strides(data_shape),
|
||||
@ -187,11 +162,10 @@ void copy_general(
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst, Stream stream) {
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
make_contiguous_strides(src.shape()),
|
||||
@ -202,84 +176,67 @@ inline void copy_general(const array& src, array& dst, Stream stream) {
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst, stream);
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::Vector:
|
||||
copy_vector<SrcT, DstT>(src, dst, stream);
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src, dst, stream, std::forward<Args>(args)...);
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float64:
|
||||
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -289,50 +246,49 @@ inline void copy_inplace_dispatch(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float64:
|
||||
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -340,7 +296,13 @@ inline void copy_inplace_dispatch(
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
copy_inplace_dispatch(src, dst, ctype, stream);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch(
|
||||
[src = array::unsafe_weak_copy(src),
|
||||
dst = array::unsafe_weak_copy(dst),
|
||||
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
@ -368,6 +330,27 @@ void copy_inplace(
|
||||
Stream stream,
|
||||
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
|
||||
if (x) {
|
||||
return array::unsafe_weak_copy(*x);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
encoder.dispatch(
|
||||
[src = array::unsafe_weak_copy(src),
|
||||
dst = array::unsafe_weak_copy(dst),
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset,
|
||||
ctype,
|
||||
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
|
||||
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
@ -375,7 +358,6 @@ void copy_inplace(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
stream,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
@ -386,8 +368,9 @@ void copy_inplace(
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
copy_inplace_dispatch(src, dst, ctype, stream);
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -22,14 +22,47 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
struct None {
|
||||
template <typename T>
|
||||
void operator()(T x, T* y) {
|
||||
(*y) = x;
|
||||
}
|
||||
};
|
||||
struct Sum {
|
||||
template <typename T>
|
||||
void operator()(T x, T* y) {
|
||||
(*y) += x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Prod {
|
||||
template <typename T>
|
||||
void operator()(T x, T* y) {
|
||||
(*y) *= x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Max {
|
||||
template <typename T>
|
||||
void operator()(T x, T* y) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Min {
|
||||
template <typename T>
|
||||
void operator()(T x, T* y) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
void gather(
|
||||
const array& src,
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& slice_sizes,
|
||||
Stream stream) {
|
||||
const Shape& slice_sizes) {
|
||||
// If the array is row contiguous then we can do a contiguous copy given
|
||||
// two conditions on the slice size:
|
||||
// - Any number of leading ones in the slice sizes are allowed
|
||||
@ -82,43 +115,23 @@ void gather(
|
||||
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
|
||||
}
|
||||
|
||||
std::vector<const IdxT*> ind_ptrs;
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
for (auto& idx : inds) {
|
||||
ind_ptrs.push_back(idx.data<IdxT>());
|
||||
encoder.set_input_array(idx);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([src_ptr,
|
||||
dst_ptr,
|
||||
ind_ptrs = std::move(ind_ptrs),
|
||||
axes,
|
||||
ind_size,
|
||||
slice_size,
|
||||
src_shape = src.shape(),
|
||||
src_strides = src.strides(),
|
||||
src_it = std::move(src_it),
|
||||
its = std::move(its),
|
||||
can_copy]() mutable {
|
||||
size_t out_idx = 0;
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < ind_ptrs.size(); ++ii) {
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = its[ii].loc;
|
||||
its[ii].step();
|
||||
auto idx_val = offset_neg_idx(ind_ptrs[ii][idx_loc], src_shape[ax]);
|
||||
src_idx += (idx_val * src_strides[ax]);
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||
src_idx += (idx_val * src.strides()[ax]);
|
||||
}
|
||||
|
||||
if (slice_size == 1) {
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx];
|
||||
} else if (can_copy) {
|
||||
std::copy(
|
||||
src_ptr + src_idx,
|
||||
src_ptr + src_idx + slice_size,
|
||||
dst_ptr + out_idx);
|
||||
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
@ -128,7 +141,6 @@ void gather(
|
||||
src_it.reset();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename IdxT>
|
||||
@ -137,50 +149,49 @@ void dispatch_gather(
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& size,
|
||||
Stream stream) {
|
||||
const Shape& size) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather<bool, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<bool, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint8:
|
||||
gather<uint8_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<uint8_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint16:
|
||||
gather<uint16_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<uint16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint32:
|
||||
gather<uint32_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<uint32_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint64:
|
||||
gather<uint64_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<uint64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int8:
|
||||
gather<int8_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<int8_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int16:
|
||||
gather<int16_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<int16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int32:
|
||||
gather<int32_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<int32_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int64:
|
||||
gather<int64_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<int64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float16:
|
||||
gather<float16_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<float16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float32:
|
||||
gather<float, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<float, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float64:
|
||||
gather<double, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<double, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather<bfloat16_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case complex64:
|
||||
gather<complex64_t, IdxT>(src, inds, out, axes, size, stream);
|
||||
gather<complex64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -189,51 +200,63 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
||||
|
||||
std::vector<array> inds;
|
||||
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
|
||||
inds.push_back(array::unsafe_weak_copy(*it));
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
for (auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([axes_ = axes_,
|
||||
slice_sizes_ = slice_sizes_,
|
||||
src = array::unsafe_weak_copy(src),
|
||||
inds = std::move(inds),
|
||||
out = array::unsafe_weak_copy(out)]() mutable {
|
||||
if (inds.empty()) {
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case uint8:
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_, stream());
|
||||
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Gather::eval_cpu] Cannot gather with indices type.");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
template <typename T, typename IdxT>
|
||||
void gather_axis(
|
||||
const array& src,
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis,
|
||||
Stream stream) {
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
@ -262,23 +285,6 @@ void gather_axis(
|
||||
size_post *= ind.shape(i);
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_input_array(ind);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
encoder.dispatch([ind_ptr,
|
||||
src_ptr,
|
||||
dst_ptr,
|
||||
size_pre,
|
||||
size_post,
|
||||
ind_ax_size,
|
||||
src_ax_size,
|
||||
ind_ax_stride,
|
||||
src_ax_stride,
|
||||
dst_ax_stride,
|
||||
ind_it = std::move(ind_it),
|
||||
src_it = std::move(src_it)]() mutable {
|
||||
size_t stride_pre = size_post * ind_ax_size;
|
||||
for (size_t i = 0; i < size_pre; i++) {
|
||||
for (size_t k = 0; k < size_post; k++) {
|
||||
@ -293,7 +299,6 @@ void gather_axis(
|
||||
}
|
||||
dst_ptr += stride_pre;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename IdxT>
|
||||
@ -301,88 +306,97 @@ void dispatch_gather_axis(
|
||||
const array& src,
|
||||
const array& inds,
|
||||
array& out,
|
||||
const int axis,
|
||||
Stream stream) {
|
||||
const int axis) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather_axis<bool, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<bool, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint8:
|
||||
gather_axis<uint8_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint16:
|
||||
gather_axis<uint16_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint32:
|
||||
gather_axis<uint32_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint64:
|
||||
gather_axis<uint64_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int8:
|
||||
gather_axis<int8_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<int8_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int16:
|
||||
gather_axis<int16_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<int16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int32:
|
||||
gather_axis<int32_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<int32_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int64:
|
||||
gather_axis<int64_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<int64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float16:
|
||||
gather_axis<float16_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<float16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float32:
|
||||
gather_axis<float, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<float, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float64:
|
||||
gather_axis<double, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<double, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case complex64:
|
||||
gather_axis<complex64_t, IdxT>(src, inds, out, axis, stream);
|
||||
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
auto& inds = inputs[1];
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_input_array(inds);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([axis_ = axis_,
|
||||
src = array::unsafe_weak_copy(src),
|
||||
inds = array::unsafe_weak_copy(inds),
|
||||
out = array::unsafe_weak_copy(out)]() mutable {
|
||||
switch (inds.dtype()) {
|
||||
case uint8:
|
||||
dispatch_gather_axis<uint8_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_gather_axis<uint16_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_gather_axis<uint32_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_gather_axis<uint64_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_gather_axis<int8_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_gather_axis<int16_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_gather_axis<int32_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_gather_axis<int64_t>(src, inds, out, axis_, stream());
|
||||
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT, typename OpT>
|
||||
@ -390,9 +404,7 @@ void scatter(
|
||||
const array& updates,
|
||||
array& out,
|
||||
const std::vector<array>& inds,
|
||||
const std::vector<int>& axes,
|
||||
const OpT& op,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
int nind = inds.size();
|
||||
auto inds_ndim = updates.ndim() - out.ndim();
|
||||
size_t n_updates = nind ? inds[0].size() : 1;
|
||||
@ -408,45 +420,27 @@ void scatter(
|
||||
ContiguousIterator update_it(updates);
|
||||
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
std::vector<const IdxT*> ind_ptrs;
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(updates);
|
||||
for (auto& idx : inds) {
|
||||
ind_ptrs.push_back(idx.data<IdxT>());
|
||||
encoder.set_input_array(idx);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<InT>(),
|
||||
upd_ptr = updates.data<InT>(),
|
||||
ind_ptrs = std::move(ind_ptrs),
|
||||
axes,
|
||||
n_updates,
|
||||
update_size,
|
||||
op = std::move(op),
|
||||
out_shape = out.shape(),
|
||||
out_strides = out.strides(),
|
||||
out_it = std::move(out_it),
|
||||
update_it = std::move(update_it),
|
||||
its = std::move(its)]() mutable {
|
||||
auto out_ptr = out.data<InT>();
|
||||
auto upd_ptr = updates.data<InT>();
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < ind_ptrs.size(); ++j) {
|
||||
for (int j = 0; j < inds.size(); ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
auto idx_val = offset_neg_idx(ind_ptrs[j][idx_loc], out_shape[ax]);
|
||||
out_offset += (idx_val * out_strides[ax]);
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
op(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
|
||||
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
}
|
||||
out_it.reset();
|
||||
update_it.reset();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT>
|
||||
@ -455,53 +449,22 @@ void dispatch_scatter_inds(
|
||||
const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype,
|
||||
Stream stream) {
|
||||
Scatter::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case Scatter::None:
|
||||
scatter<InT, IdxT>(
|
||||
updates,
|
||||
out,
|
||||
indices,
|
||||
axes,
|
||||
[](auto x, auto* y) { (*y) = x; },
|
||||
stream);
|
||||
scatter<InT, IdxT, None>(updates, out, indices, axes);
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
scatter<InT, IdxT>(
|
||||
updates,
|
||||
out,
|
||||
indices,
|
||||
axes,
|
||||
[](auto x, auto* y) { (*y) += x; },
|
||||
stream);
|
||||
scatter<InT, IdxT, Sum>(updates, out, indices, axes);
|
||||
break;
|
||||
case Scatter::Prod:
|
||||
scatter<InT, IdxT>(
|
||||
updates,
|
||||
out,
|
||||
indices,
|
||||
axes,
|
||||
[](auto x, auto* y) { (*y) *= x; },
|
||||
stream);
|
||||
scatter<InT, IdxT, Prod>(updates, out, indices, axes);
|
||||
break;
|
||||
case Scatter::Max:
|
||||
scatter<InT, IdxT>(
|
||||
updates,
|
||||
out,
|
||||
indices,
|
||||
axes,
|
||||
[](auto x, auto* y) { (*y) = (*y > x) ? *y : x; },
|
||||
stream);
|
||||
scatter<InT, IdxT, Max>(updates, out, indices, axes);
|
||||
break;
|
||||
case Scatter::Min:
|
||||
scatter<InT, IdxT>(
|
||||
updates,
|
||||
out,
|
||||
indices,
|
||||
axes,
|
||||
[](auto x, auto* y) { (*y) = (*y < x) ? *y : x; },
|
||||
stream);
|
||||
scatter<InT, IdxT, Min>(updates, out, indices, axes);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -512,46 +475,36 @@ void dispatch_scatter(
|
||||
const std::vector<array>& inds,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype,
|
||||
Stream stream) {
|
||||
Scatter::ReduceType rtype) {
|
||||
if (inds.empty()) {
|
||||
dispatch_scatter_inds<InT, uint8_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case uint8:
|
||||
dispatch_scatter_inds<InT, uint8_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_inds<InT, uint16_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_inds<InT, uint32_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_inds<InT, uint64_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_inds<InT, int8_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_inds<InT, int16_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_inds<InT, int32_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_inds<InT, int64_t>(
|
||||
out, inds, updates, axes, rtype, stream);
|
||||
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
@ -563,7 +516,6 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
|
||||
auto& updates = inputs.back();
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
@ -571,73 +523,68 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype, stream());
|
||||
|
||||
switch (src.dtype()) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
std::vector<array> inds;
|
||||
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
|
||||
encoder.set_input_array(*it);
|
||||
inds.push_back(array::unsafe_weak_copy(*it));
|
||||
}
|
||||
encoder.set_input_array(updates);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([axes_ = axes_,
|
||||
reduce_type_ = reduce_type_,
|
||||
updates = array::unsafe_weak_copy(updates),
|
||||
inds = std::move(inds),
|
||||
out = array::unsafe_weak_copy(out)]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter<uint8_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter<uint16_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter<uint32_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter<uint64_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter<int8_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter<int16_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter<int32_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter<int64_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter<float16_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter<float>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float64:
|
||||
dispatch_scatter<double>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter<bfloat16_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter<complex64_t>(
|
||||
out, inds, updates, axes_, reduce_type_, stream());
|
||||
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(
|
||||
array& out,
|
||||
const array idx,
|
||||
const array& upd,
|
||||
int axis,
|
||||
const OpT& op,
|
||||
Stream stream) {
|
||||
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
@ -657,11 +604,6 @@ void scatter_axis(
|
||||
auto idx_ax_size = idx.shape(axis);
|
||||
auto dst_ax_size = out.shape(axis);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(idx);
|
||||
encoder.set_input_array(upd);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
@ -670,26 +612,14 @@ void scatter_axis(
|
||||
for (int i = axis + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
encoder.dispatch([idx_ptr,
|
||||
upd_ptr,
|
||||
dst_ptr,
|
||||
size_pre,
|
||||
size_post,
|
||||
idx_ax_size,
|
||||
dst_ax_size,
|
||||
idx_ax_stride,
|
||||
upd_ax_stride,
|
||||
dst_ax_stride,
|
||||
idx_it = std::move(idx_it),
|
||||
upd_it = std::move(upd_it),
|
||||
op = std::move(op)]() mutable {
|
||||
size_t stride_pre = size_post * dst_ax_size;
|
||||
for (size_t i = 0; i < size_pre; i++) {
|
||||
for (size_t k = 0; k < size_post; k++) {
|
||||
for (int j = 0; j < idx_ax_size; ++j) {
|
||||
auto ind_val = offset_neg_idx(
|
||||
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
|
||||
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
|
||||
OpT{}(
|
||||
upd_ptr[upd_it.loc + j * upd_ax_stride],
|
||||
dst_ptr + k + ind_val * dst_ax_stride);
|
||||
}
|
||||
idx_it.step();
|
||||
@ -697,7 +627,6 @@ void scatter_axis(
|
||||
}
|
||||
dst_ptr += stride_pre;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT>
|
||||
@ -706,16 +635,13 @@ void dispatch_scatter_axis_op(
|
||||
const array& idx,
|
||||
const array& updates,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType rtype,
|
||||
Stream stream) {
|
||||
ScatterAxis::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case ScatterAxis::None:
|
||||
scatter_axis<InT, IdxT>(
|
||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; }, stream);
|
||||
scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
scatter_axis<InT, IdxT>(
|
||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; }, stream);
|
||||
scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -726,40 +652,31 @@ void dispatch_scatter_axis(
|
||||
const array& idx,
|
||||
const array& updates,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType rtype,
|
||||
Stream stream) {
|
||||
ScatterAxis::ReduceType rtype) {
|
||||
switch (idx.dtype()) {
|
||||
case uint8:
|
||||
dispatch_scatter_axis_op<InT, uint8_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_axis_op<InT, uint16_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_axis_op<InT, uint32_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_axis_op<InT, uint64_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_axis_op<InT, int8_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_axis_op<InT, int16_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_axis_op<InT, int32_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_axis_op<InT, int64_t>(
|
||||
out, idx, updates, axis, rtype, stream);
|
||||
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
@ -779,64 +696,63 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype, stream());
|
||||
|
||||
switch (src.dtype()) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(idx);
|
||||
encoder.set_input_array(updates);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([axis_ = axis_,
|
||||
reduce_type_ = reduce_type_,
|
||||
idx = array::unsafe_weak_copy(idx),
|
||||
updates = array::unsafe_weak_copy(updates),
|
||||
out = array::unsafe_weak_copy(out)]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_axis<bool>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_axis<uint8_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_axis<uint16_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_axis<uint32_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_axis<uint64_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_axis<int8_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_axis<int16_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_axis<int32_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_axis<int64_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter_axis<float16_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter_axis<float>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float64:
|
||||
dispatch_scatter_axis<double>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter_axis<bfloat16_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter_axis<complex64_t>(
|
||||
out, idx, updates, axis_, reduce_type_, stream());
|
||||
out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -326,8 +326,7 @@ void _qmm_dispatch_typed(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
@ -335,48 +334,18 @@ void _qmm_dispatch_typed(
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w] {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
|
||||
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@ -384,7 +353,6 @@ void _qmm_dispatch_typed(
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
@ -395,20 +363,19 @@ void _qmm_dispatch(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed(
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
@ -453,45 +410,19 @@ void _bs_qmm_dispatch_typed(
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
lhs_indices_ptr,
|
||||
rhs_indices_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
lhs_indices_shape = lhs_indices.shape(),
|
||||
lhs_indices_strides = lhs_indices.strides(),
|
||||
rhs_indices_shape = rhs_indices.shape(),
|
||||
rhs_indices_strides = rhs_indices.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
indices_size = lhs_indices.size(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w]() {
|
||||
for (int i = 0; i < indices_size; i++) {
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices_shape, lhs_indices_strides)];
|
||||
i, lhs_indices.shape(), lhs_indices.strides())];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices_shape, rhs_indices_strides)];
|
||||
i, rhs_indices.shape(), rhs_indices.strides())];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
|
||||
scales_ptr +
|
||||
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr +
|
||||
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@ -499,7 +430,6 @@ void _bs_qmm_dispatch_typed(
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void _bs_qmm_dispatch(
|
||||
@ -512,8 +442,7 @@ void _bs_qmm_dispatch(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_bs_qmm_dispatch_typed<float>(
|
||||
@ -526,8 +455,7 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_bs_qmm_dispatch_typed<float16_t>(
|
||||
@ -540,8 +468,7 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
@ -554,8 +481,7 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_dispatch(
|
||||
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -626,6 +566,26 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
@ -636,10 +596,8 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_,
|
||||
stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
transpose_);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@ -709,27 +667,13 @@ void dispatch_quantize(
|
||||
array& scales,
|
||||
array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
Stream stream) {
|
||||
int group_size) {
|
||||
auto w_ptr = w.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w_ptr,
|
||||
out_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
bits,
|
||||
group_size,
|
||||
w_size = w.size()]() {
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
|
||||
});
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu(
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (copied) {
|
||||
encoder.add_temporary(w);
|
||||
}
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w = array::unsafe_weak_copy(w),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_]() mutable {
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
if (copied) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -140,34 +140,23 @@ void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Stream stream) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
U init) {
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto in_ptr = x.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
|
||||
});
|
||||
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
encoder.dispatch(
|
||||
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
|
||||
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@ -178,24 +167,14 @@ void reduction_op(
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
@ -211,7 +190,6 @@ void reduction_op(
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@ -220,20 +198,12 @@ void reduction_op(
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
size = out.size()]() mutable {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
in_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@ -245,17 +215,8 @@ void reduction_op(
|
||||
plan.strides.pop_back();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
@ -263,7 +224,7 @@ void reduction_op(
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
@ -280,21 +241,13 @@ void reduction_op(
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
@ -305,7 +258,6 @@ void reduction_op(
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -434,12 +386,11 @@ void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
|
||||
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
|
||||
} else {
|
||||
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
|
||||
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
|
||||
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
|
||||
} else {
|
||||
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
|
||||
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
|
||||
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
|
||||
} else {
|
||||
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
|
||||
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -470,20 +420,27 @@ void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
|
||||
}
|
||||
}
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
reduce_type_ = reduce_type_,
|
||||
axes_ = axes_]() mutable {
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
@ -491,28 +448,24 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
@ -523,43 +476,34 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
@ -568,64 +512,52 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -160,38 +160,29 @@ void scan_op(
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const Op& op,
|
||||
U init,
|
||||
Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
U init) {
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.strides()[axis] == 1) {
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<U>(),
|
||||
count = in.size() / in.shape(axis),
|
||||
stride = in.shape(axis),
|
||||
reverse,
|
||||
inclusive,
|
||||
op = std::move(op),
|
||||
init]() {
|
||||
contiguous_scan(
|
||||
in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<U>(),
|
||||
count = in.size() / in.shape(axis) / in.strides()[axis],
|
||||
size = in.shape(axis),
|
||||
stride = in.strides()[axis],
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
in.size() / in.shape(axis),
|
||||
in.shape(axis),
|
||||
reverse,
|
||||
inclusive,
|
||||
op = std::move(op),
|
||||
init]() {
|
||||
op,
|
||||
init);
|
||||
} else {
|
||||
strided_scan(
|
||||
in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
|
||||
});
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
in.size() / in.shape(axis) / in.strides()[axis],
|
||||
in.shape(axis),
|
||||
in.strides()[axis],
|
||||
reverse,
|
||||
inclusive,
|
||||
op,
|
||||
init);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Scan op supports only contiguous inputs");
|
||||
@ -205,19 +196,18 @@ void scan_dispatch(
|
||||
array& out,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
Stream stream) {
|
||||
bool inclusive) {
|
||||
switch (rtype) {
|
||||
case Scan::Sum: {
|
||||
auto op = [](U y, T x) { return y + x; };
|
||||
auto init = static_cast<U>(0);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Prod: {
|
||||
auto op = [](U y, T x) { return y * x; };
|
||||
auto init = static_cast<U>(1);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Min: {
|
||||
@ -225,7 +215,7 @@ void scan_dispatch(
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Max: {
|
||||
@ -233,7 +223,7 @@ void scan_dispatch(
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -244,17 +234,26 @@ void scan_dispatch(
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Ensure contiguity
|
||||
auto in = inputs[0];
|
||||
bool copied = false;
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General, stream());
|
||||
in = arr_copy;
|
||||
copied = true;
|
||||
encoder.add_temporary(arr_copy);
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
reduce_type_ = reduce_type_,
|
||||
reverse_ = reverse_,
|
||||
inclusive_ = inclusive_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_: {
|
||||
// We could do a full dtype x dtype switch but this is the only case
|
||||
@ -264,68 +263,66 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// floats perhaps we should add the full all-to-all dispatch.
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||
scan_dispatch<bool, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
} else {
|
||||
scan_dispatch<bool, bool>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case uint8:
|
||||
scan_dispatch<uint8_t, uint8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint16:
|
||||
scan_dispatch<uint16_t, uint16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint32:
|
||||
scan_dispatch<uint32_t, uint32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint64:
|
||||
scan_dispatch<uint64_t, uint64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int8:
|
||||
scan_dispatch<int8_t, int8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int16:
|
||||
scan_dispatch<int16_t, int16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int32:
|
||||
scan_dispatch<int32_t, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int64:
|
||||
scan_dispatch<int64_t, int64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float16:
|
||||
scan_dispatch<float16_t, float16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float32:
|
||||
scan_dispatch<float, float>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float64:
|
||||
scan_dispatch<double, double>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case bfloat16:
|
||||
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
break;
|
||||
}
|
||||
if (copied) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(std::move(in));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -16,51 +16,70 @@ void select_op(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
Op op,
|
||||
Stream stream) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
c = array::unsafe_weak_copy(c),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op,
|
||||
topt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
|
||||
ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint8:
|
||||
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
|
||||
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint16:
|
||||
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
|
||||
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint32:
|
||||
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
|
||||
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint64:
|
||||
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
|
||||
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int8:
|
||||
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
|
||||
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int16:
|
||||
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
|
||||
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int32:
|
||||
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
|
||||
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int64:
|
||||
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
|
||||
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case float16:
|
||||
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
|
||||
ternary_op<bool, float16_t, float16_t, float16_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
case float32:
|
||||
ternary_op<bool, float, float, float>(a, b, c, out, op);
|
||||
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case float64:
|
||||
ternary_op<bool, double, double, double>(a, b, c, out, op);
|
||||
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case bfloat16:
|
||||
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
|
||||
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
case complex64:
|
||||
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
|
||||
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& condition = inputs[0];
|
||||
const auto& a = inputs[1];
|
||||
const auto& b = inputs[2];
|
||||
select_op(condition, a, b, out, detail::Select());
|
||||
select_op(condition, a, b, out, detail::Select(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -105,15 +105,11 @@ struct StridedIterator {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
void sort(array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@ -127,13 +123,7 @@ void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Perform sorting in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride]() mutable {
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
@ -143,14 +133,10 @@ void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@ -176,17 +162,8 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride]() mutable {
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
@ -210,42 +187,30 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
void partition(array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
auto axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride,
|
||||
kth]() mutable {
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
@ -256,19 +221,10 @@ void partition(const array& in, array& out, int axis, int kth, Stream stream) {
|
||||
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argpartition(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
int kth,
|
||||
Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@ -297,18 +253,9 @@ void argpartition(
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride,
|
||||
kth]() mutable {
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
@ -332,7 +279,6 @@ void argpartition(
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_, stream());
|
||||
return argsort<bool>(in, out, axis_);
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_, stream());
|
||||
return argsort<uint8_t>(in, out, axis_);
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_, stream());
|
||||
return argsort<uint16_t>(in, out, axis_);
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_, stream());
|
||||
return argsort<uint32_t>(in, out, axis_);
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_, stream());
|
||||
return argsort<uint64_t>(in, out, axis_);
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_, stream());
|
||||
return argsort<int8_t>(in, out, axis_);
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_, stream());
|
||||
return argsort<int16_t>(in, out, axis_);
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_, stream());
|
||||
return argsort<int32_t>(in, out, axis_);
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_, stream());
|
||||
return argsort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_, stream());
|
||||
return argsort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_, stream());
|
||||
return argsort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_, stream());
|
||||
return argsort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_, stream());
|
||||
return argsort<bfloat16_t>(in, out, axis_);
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_, stream());
|
||||
return argsort<complex64_t>(in, out, axis_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(in, out, axis_, stream());
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(in, out, axis_, stream());
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(in, out, axis_, stream());
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(in, out, axis_, stream());
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(in, out, axis_, stream());
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(in, out, axis_, stream());
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(in, out, axis_, stream());
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(in, out, axis_, stream());
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(in, out, axis_, stream());
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(in, out, axis_, stream());
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(in, out, axis_, stream());
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(in, out, axis_, stream());
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(in, out, axis_, stream());
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(in, out, axis_, stream());
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_, stream());
|
||||
return argpartition<bool>(in, out, axis_, kth_);
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_);
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_);
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_);
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_);
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<int8_t>(in, out, axis_, kth_);
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<int16_t>(in, out, axis_, kth_);
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<int32_t>(in, out, axis_, kth_);
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_, stream());
|
||||
return argpartition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_, stream());
|
||||
return argpartition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_);
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(in, out, axis_, kth_, stream());
|
||||
return partition<bool>(out, axis_, kth_);
|
||||
case uint8:
|
||||
return partition<uint8_t>(in, out, axis_, kth_, stream());
|
||||
return partition<uint8_t>(out, axis_, kth_);
|
||||
case uint16:
|
||||
return partition<uint16_t>(in, out, axis_, kth_, stream());
|
||||
return partition<uint16_t>(out, axis_, kth_);
|
||||
case uint32:
|
||||
return partition<uint32_t>(in, out, axis_, kth_, stream());
|
||||
return partition<uint32_t>(out, axis_, kth_);
|
||||
case uint64:
|
||||
return partition<uint64_t>(in, out, axis_, kth_, stream());
|
||||
return partition<uint64_t>(out, axis_, kth_);
|
||||
case int8:
|
||||
return partition<int8_t>(in, out, axis_, kth_, stream());
|
||||
return partition<int8_t>(out, axis_, kth_);
|
||||
case int16:
|
||||
return partition<int16_t>(in, out, axis_, kth_, stream());
|
||||
return partition<int16_t>(out, axis_, kth_);
|
||||
case int32:
|
||||
return partition<int32_t>(in, out, axis_, kth_, stream());
|
||||
return partition<int32_t>(out, axis_, kth_);
|
||||
case int64:
|
||||
return partition<int64_t>(in, out, axis_, kth_, stream());
|
||||
return partition<int64_t>(out, axis_, kth_);
|
||||
case float32:
|
||||
return partition<float>(in, out, axis_, kth_, stream());
|
||||
return partition<float>(out, axis_, kth_);
|
||||
case float64:
|
||||
return partition<double>(in, out, axis_, kth_, stream());
|
||||
return partition<double>(out, axis_, kth_);
|
||||
case float16:
|
||||
return partition<float16_t>(in, out, axis_, kth_, stream());
|
||||
return partition<float16_t>(out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
|
||||
return partition<bfloat16_t>(out, axis_, kth_);
|
||||
case complex64:
|
||||
return partition<complex64_t>(in, out, axis_, kth_, stream());
|
||||
return partition<complex64_t>(out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,12 +1,10 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -128,57 +126,28 @@ void ternary_op(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
Op op,
|
||||
TernaryOpType topt) {
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<U>();
|
||||
|
||||
if (topt == TernaryOpType::ScalarScalarScalar) {
|
||||
encoder.dispatch(
|
||||
[a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
});
|
||||
} else if (topt == TernaryOpType::VectorVectorVector) {
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
out_ptr,
|
||||
op = std::move(op),
|
||||
size = out.size()]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
c_ptr++;
|
||||
out_ptr++;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
|
||||
encoder.dispatch(
|
||||
|
||||
[a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
out_ptr,
|
||||
op = std::move(op),
|
||||
size = out.size(),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
ternary_op_dispatch_dims<T1, T2, T3, U>(
|
||||
a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
|
||||
});
|
||||
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto op = detail::Abs{};
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(in, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(in, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(in, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(in, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(in, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(in, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
unary_signed(in, out, detail::Abs(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
unary_fp(in, out, detail::ArcCos(), stream());
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
unary_fp(in, out, detail::ArcCosh(), stream());
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
unary_fp(in, out, detail::ArcSin(), stream());
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
unary_fp(in, out, detail::ArcSinh(), stream());
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
unary_fp(in, out, detail::ArcTan(), stream());
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
unary_fp(in, out, detail::ArcTanh(), stream());
|
||||
}
|
||||
|
||||
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_int(in, out, detail::BitwiseInvert());
|
||||
unary_int(in, out, detail::BitwiseInvert(), stream());
|
||||
}
|
||||
|
||||
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
unary_fp(in, out, detail::Ceil(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
|
||||
unary_complex(inputs[0], out, detail::Conjugate(), stream());
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cos());
|
||||
unary_fp(in, out, detail::Cos(), stream());
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
unary_fp(in, out, detail::Cosh(), stream());
|
||||
}
|
||||
|
||||
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
unary_real_fp(in, out, detail::Erf(), stream());
|
||||
}
|
||||
|
||||
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
unary_real_fp(in, out, detail::ErfInv(), stream());
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Exp());
|
||||
unary_fp(in, out, detail::Exp(), stream());
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
unary_fp(in, out, detail::Expm1(), stream());
|
||||
}
|
||||
|
||||
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
unary_fp(in, out, detail::Floor(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& in = inputs[0];
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
unary_fp(in, out, detail::Log(), stream());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
unary_fp(in, out, detail::Log2(), stream());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
unary_fp(in, out, detail::Log10(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
unary_fp(in, out, detail::Log1p(), stream());
|
||||
}
|
||||
|
||||
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
unary(in, out, detail::LogicalNot(), stream());
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
unary(in, out, detail::Negative(), stream());
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
unary_complex_to_float(inputs[0], out, detail::Real(), stream());
|
||||
}
|
||||
|
||||
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
unary_fp(in, out, detail::Round(), stream());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
unary_fp(in, out, detail::Sigmoid(), stream());
|
||||
}
|
||||
|
||||
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
unary(in, out, detail::Sign(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sin());
|
||||
unary_fp(in, out, detail::Sin(), stream());
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
unary_fp(in, out, detail::Sinh(), stream());
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
unary(in, out, detail::Square(), stream());
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
unary_fp(in, out, detail::Rsqrt(), stream());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
unary_fp(in, out, detail::Sqrt(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tan());
|
||||
unary_fp(in, out, detail::Tan(), stream());
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
unary_fp(in, out, detail::Tanh(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -7,7 +7,6 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -39,53 +38,48 @@ void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op) {
|
||||
set_unary_output_data(a, out);
|
||||
const T* src = a.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
encoder.dispatch([src,
|
||||
dst,
|
||||
contig = a.flags().contiguous,
|
||||
data_size = a.data_size(),
|
||||
size = a.size(),
|
||||
shapes = a.shape(),
|
||||
strides = a.strides()]() mutable {
|
||||
auto ndim = shapes.size();
|
||||
if (contig) {
|
||||
auto ndim = a.ndim();
|
||||
if (a.flags().contiguous) {
|
||||
auto size = a.data_size();
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (data_size >= N) {
|
||||
while (size >= N) {
|
||||
simd::store(dst, Op{}(simd::load<T, N>(src)));
|
||||
data_size -= N;
|
||||
size -= N;
|
||||
src += N;
|
||||
dst += N;
|
||||
}
|
||||
while (data_size > 0) {
|
||||
while (size > 0) {
|
||||
*dst = Op{}(*src);
|
||||
data_size--;
|
||||
size--;
|
||||
dst++;
|
||||
src++;
|
||||
}
|
||||
} else {
|
||||
size_t shape = ndim > 0 ? shapes.back() : 1;
|
||||
size_t stride = ndim > 0 ? strides.back() : 1;
|
||||
size_t shape = ndim > 0 ? a.shape().back() : 1;
|
||||
size_t stride = ndim > 0 ? a.strides().back() : 1;
|
||||
if (ndim <= 1) {
|
||||
unary_op<T, U, Op>(src, dst, shape, stride);
|
||||
return;
|
||||
}
|
||||
auto it = ContiguousIterator(shapes, strides, ndim - 1);
|
||||
for (size_t elem = 0; elem < size; elem += shape) {
|
||||
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
|
||||
for (size_t elem = 0; elem < a.size(); elem += shape) {
|
||||
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary(const array& a, array& out, Op op) {
|
||||
void unary(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
unary_op<bool>(a, out, op);
|
||||
@ -130,10 +124,47 @@ void unary(const array& a, array& out, Op op) {
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_fp(const array& a, array& out, Op op) {
|
||||
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream err;
|
||||
err << "[unary_real] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
});
|
||||
}
|
||||
template <typename Op>
|
||||
void unary_fp(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
@ -155,10 +186,84 @@ void unary_fp(const array& a, array& out, Op op) {
|
||||
err << "[unary_fp] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_int(const array& a, array& out, Op op) {
|
||||
void unary_signed(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(a, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(a, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(a, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(a, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(a, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(a, out, op);
|
||||
break;
|
||||
case float64:
|
||||
unary_op<double>(a, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(a, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(a, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_complex(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_int(const array& a, array& out, Op op, Stream stream) {
|
||||
set_unary_output_data(a, out);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op = op]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case uint8:
|
||||
unary_op<uint8_t>(a, out, op);
|
||||
@ -189,6 +294,7 @@ void unary_int(const array& a, array& out, Op op) {
|
||||
err << "[unary_int] Does not support " << out.dtype();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user