mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
@@ -16,11 +16,6 @@ inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||
return idx;
|
||||
@@ -169,14 +164,11 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
||||
|
||||
if (inds.empty()) {
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
@@ -201,12 +193,142 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case int64:
|
||||
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Gather::eval] Cannot gather with floating point indices.");
|
||||
"[Gather::eval_cpu] Cannot gather with indices type.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
template <typename T, typename IdxT>
|
||||
void gather_axis(
|
||||
const array& src,
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
strides = src.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
auto ind_ptr = ind.data<IdxT>();
|
||||
auto src_ptr = src.data<T>();
|
||||
auto dst_ptr = out.data<T>();
|
||||
auto ind_ax_stride = ind.strides(axis);
|
||||
auto src_ax_stride = src.strides(axis);
|
||||
auto dst_ax_stride = out.strides(axis);
|
||||
auto ind_ax_size = ind.shape(axis);
|
||||
auto src_ax_size = src.shape(axis);
|
||||
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
size_pre *= ind.shape(i);
|
||||
}
|
||||
for (int i = axis + 1; i < ind.ndim(); ++i) {
|
||||
size_post *= ind.shape(i);
|
||||
}
|
||||
size_t stride_pre = size_post * ind_ax_size;
|
||||
for (size_t i = 0; i < size_pre; i++) {
|
||||
for (size_t k = 0; k < size_post; k++) {
|
||||
for (int j = 0; j < ind_ax_size; ++j) {
|
||||
auto ind_val = offset_neg_idx(
|
||||
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
|
||||
dst_ptr[k + j * dst_ax_stride] =
|
||||
src_ptr[src_it.loc + ind_val * src_ax_stride];
|
||||
}
|
||||
ind_it.step();
|
||||
src_it.step();
|
||||
}
|
||||
dst_ptr += stride_pre;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IdxT>
|
||||
void dispatch_gather_axis(
|
||||
const array& src,
|
||||
const array& inds,
|
||||
array& out,
|
||||
const int axis) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather_axis<bool, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint8:
|
||||
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint16:
|
||||
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint32:
|
||||
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint64:
|
||||
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int8:
|
||||
gather_axis<int8_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int16:
|
||||
gather_axis<int16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int32:
|
||||
gather_axis<int32_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int64:
|
||||
gather_axis<int64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float16:
|
||||
gather_axis<float16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float32:
|
||||
gather_axis<float, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case complex64:
|
||||
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& src = inputs[0];
|
||||
auto& inds = inputs[1];
|
||||
switch (inds.dtype()) {
|
||||
case uint8:
|
||||
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -296,14 +418,11 @@ void dispatch_scatter(
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
if (inds.empty()) {
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
@@ -328,12 +447,9 @@ void dispatch_scatter(
|
||||
case int64:
|
||||
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
|
||||
"[Scatter::eval_cpu] Cannot scatter with indices type.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,7 +461,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& updates = inputs.back();
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
copy(src, out, CopyType::General);
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
@@ -390,4 +508,167 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(
|
||||
array& out,
|
||||
const array idx,
|
||||
const array& upd,
|
||||
int axis,
|
||||
const OpT& op) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
auto idx_ptr = idx.data<IdxT>();
|
||||
auto upd_ptr = upd.data<T>();
|
||||
auto dst_ptr = out.data<T>();
|
||||
auto idx_ax_stride = idx.strides(axis);
|
||||
auto upd_ax_stride = upd.strides(axis);
|
||||
auto dst_ax_stride = out.strides(axis);
|
||||
auto idx_ax_size = idx.shape(axis);
|
||||
auto dst_ax_size = out.shape(axis);
|
||||
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
size_t stride_pre = size_post * dst_ax_size;
|
||||
for (size_t i = 0; i < size_pre; i++) {
|
||||
for (size_t k = 0; k < size_post; k++) {
|
||||
for (int j = 0; j < idx_ax_size; ++j) {
|
||||
auto ind_val = offset_neg_idx(
|
||||
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
|
||||
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
|
||||
dst_ptr + k + ind_val * dst_ax_stride);
|
||||
}
|
||||
idx_it.step();
|
||||
upd_it.step();
|
||||
}
|
||||
dst_ptr += stride_pre;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT>
|
||||
void dispatch_scatter_axis_op(
|
||||
array& out,
|
||||
const array& idx,
|
||||
const array& updates,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case ScatterAxis::None:
|
||||
scatter_axis<InT, IdxT>(
|
||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
scatter_axis<InT, IdxT>(
|
||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void dispatch_scatter_axis(
|
||||
array& out,
|
||||
const array& idx,
|
||||
const array& updates,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType rtype) {
|
||||
switch (idx.dtype()) {
|
||||
case uint8:
|
||||
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
auto& updates = inputs[2];
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter_axis<complex64_t>(
|
||||
out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user