mlx/mlx/backend/cpu/indexing.cpp
2025-03-20 16:48:43 -07:00

759 lines
22 KiB
C++

// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
struct None {
template <typename T>
void operator()(T x, T* y) {
(*y) = x;
}
};
struct Sum {
template <typename T>
void operator()(T x, T* y) {
(*y) += x;
}
};
struct Prod {
template <typename T>
void operator()(T x, T* y) {
(*y) *= x;
}
};
struct Max {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y > x) ? *y : x;
}
};
struct Min {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y < x) ? *y : x;
}
};
template <typename T, typename IdxT>
void gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size
// If the array is col contiguous then the reverse is the case:
// - Any number of trailing ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size from the end
bool can_copy = false;
if (src.flags().row_contiguous) {
can_copy = true;
// Ignore leading 1s
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
// Check the remaining
i++;
for (; i < src.ndim() && can_copy; ++i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
} else if (src.flags().col_contiguous) {
can_copy = true;
// Ignore trailing 1s
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
// Skip the next slice size and check the remaining
i--;
for (; i >= 0 && can_copy; --i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
size_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
}
}
template <typename IdxT>
void dispatch_gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
break;
case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size);
break;
case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size);
break;
case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size);
break;
case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size);
break;
case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size);
break;
case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size);
break;
case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size);
break;
case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size);
break;
case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size);
break;
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case float64:
gather<double, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size);
break;
}
}
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
inds.push_back(array::unsafe_weak_copy(*it));
}
auto& encoder = cpu::get_command_encoder(stream());
for (auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
slice_sizes_ = slice_sizes_,
src = array::unsafe_weak_copy(src),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename T, typename IdxT>
void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
auto dst_ptr = out.data<T>();
auto ind_ax_stride = ind.strides(axis);
auto src_ax_stride = src.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto ind_ax_size = ind.shape(axis);
auto src_ax_size = src.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= ind.shape(i);
}
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
}
ind_it.step();
src_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename IdxT>
void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case float64:
gather_axis<double, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break;
}
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(src);
encoder.set_input_array(inds);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
src = array::unsafe_weak_copy(src),
inds = array::unsafe_weak_copy(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename InT, typename IdxT, typename OpT>
void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_inds(
array& out,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT, None>(updates, out, indices, axes);
break;
case Scatter::Sum:
scatter<InT, IdxT, Sum>(updates, out, indices, axes);
break;
case Scatter::Prod:
scatter<InT, IdxT, Prod>(updates, out, indices, axes);
break;
case Scatter::Max:
scatter<InT, IdxT, Max>(updates, out, indices, axes);
break;
case Scatter::Min:
scatter<InT, IdxT, Min>(updates, out, indices, axes);
break;
}
}
template <typename InT>
void dispatch_scatter(
array& out,
const std::vector<array>& inds,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
case uint16:
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
break;
case uint32:
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
break;
case uint64:
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
break;
case int8:
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
break;
case int16:
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
break;
case int32:
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
break;
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
default:
throw std::runtime_error(
"[Scatter::eval_cpu] Cannot scatter with indices type.");
}
}
void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
encoder.set_input_array(*it);
inds.push_back(array::unsafe_weak_copy(*it));
}
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
reduce_type_ = reduce_type_,
updates = array::unsafe_weak_copy(updates),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
});
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
auto dst_ptr = out.data<T>();
auto idx_ax_stride = idx.strides(axis);
auto upd_ax_stride = upd.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
OpT{}(
upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_axis_op(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
break;
}
}
template <typename InT>
void dispatch_scatter_axis(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
break;
default:
throw std::runtime_error(
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
}
}
void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
auto& idx = inputs[1];
auto& updates = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
reduce_type_ = reduce_type_,
idx = array::unsafe_weak_copy(idx),
updates = array::unsafe_weak_copy(updates),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
});
}
} // namespace mlx::core