// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <cassert>
#include <stdexcept>

#include "mlx/allocator.h"
#include "mlx/primitives.h"

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"

namespace mlx::core {

template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
  return idx;
}

struct None {
  template <typename T>
  void operator()(T x, T* y) {
    (*y) = x;
  }
};

struct Sum {
  template <typename T>
  void operator()(T x, T* y) {
    (*y) += x;
  }
};

struct Prod {
  template <typename T>
  void operator()(T x, T* y) {
    (*y) *= x;
  }
};

struct Max {
  template <typename T>
  void operator()(T x, T* y) {
    (*y) = (*y > x) ? *y : x;
  }
};

struct Min {
  template <typename T>
  void operator()(T x, T* y) {
    (*y) = (*y < x) ? *y : x;
  }
};

template <typename T, typename IdxT>
void gather(
    const array& src,
    const std::vector<array>& inds,
    array& out,
    const std::vector<int>& axes,
    const Shape& slice_sizes) {
  // If the array is row contiguous then we can do a contiguous copy given
  // two conditions on the slice size:
  // - Any number of leading ones in the slice sizes are allowed
  // - All other slice sizes match the corresponding dimension except the
  //   first non-singleton slice size
  // If the array is col contiguous then the reverse is the case:
  // - Any number of trailing ones in the slice sizes are allowed
  // - All other slice sizes match the corresponding dimension except the
  //   first non-singleton slice size from the end
  bool can_copy = false;
  if (src.flags().row_contiguous) {
    can_copy = true;

    // Ignore leading 1s
    int i = 0;
    for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
      ;

    // Check the remaining
    i++;
    for (; i < src.ndim() && can_copy; ++i) {
      can_copy = (src.shape(i) == slice_sizes[i]);
    }
  } else if (src.flags().col_contiguous) {
    can_copy = true;

    // Ignore trailing 1s
    int i = slice_sizes.size() - 1;
    for (; i >= 0 && slice_sizes[i] == 1; --i)
      ;

    // Skip the next slice size and check the remaining
    i--;
    for (; i >= 0 && can_copy; --i) {
      can_copy = (src.shape(i) == slice_sizes[i]);
    }
  }
  size_t slice_size = 1;
  for (auto s : slice_sizes) {
    slice_size *= s;
  }
  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
  const T* src_ptr = src.data<T>();
  T* dst_ptr = out.data<T>();

  std::vector<ContiguousIterator> its(inds.begin(), inds.end());
  ContiguousIterator src_it;
  if (!can_copy && src.ndim() > 0) {
    src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
  }
  size_t out_idx = 0;
  for (int idx = 0; idx < ind_size; idx++) {
    // Compute the flat source offset from all index arrays
    size_t src_idx = 0;
    for (int ii = 0; ii < inds.size(); ++ii) {
      auto ax = axes[ii];
      auto idx_loc = its[ii].loc;
      its[ii].step();
      auto idx_val =
          offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
      src_idx += (idx_val * src.strides()[ax]);
    }

    if (slice_size == 1) {
      dst_ptr[out_idx++] = src_ptr[src_idx];
    } else if (can_copy) {
      std::copy(
          src_ptr + src_idx,
          src_ptr + src_idx + slice_size,
          dst_ptr + out_idx);
      out_idx += slice_size;
    } else {
      for (int jj = 0; jj < slice_size; jj++) {
        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
        src_it.step();
      }
      src_it.reset();
    }
  }
}
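// Worked example (added note, not part of the upstream file): for a
// row-contiguous src of shape (4, 5, 6), slice_sizes (1, 3, 6) satisfies the
// can_copy conditions in gather(): the leading 1 is skipped, the first
// non-singleton size (3) may be anything, and the trailing size matches
// src.shape(2). Each slice is then 3 * 6 = 18 contiguous elements, so
// gather() takes the std::copy fast path. With slice_sizes (1, 3, 5) the
// last dimension does not match, and the element-wise ContiguousIterator
// path is used instead.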
template <typename IdxT>
void dispatch_gather(
    const array& src,
    const std::vector<array>& inds,
    array& out,
    const std::vector<int>& axes,
    const Shape& size) {
  switch (out.dtype()) {
    case bool_:
      gather<bool, IdxT>(src, inds, out, axes, size);
      break;
    case uint8:
      gather<uint8_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint16:
      gather<uint16_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint32:
      gather<uint32_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint64:
      gather<uint64_t, IdxT>(src, inds, out, axes, size);
      break;
    case int8:
      gather<int8_t, IdxT>(src, inds, out, axes, size);
      break;
    case int16:
      gather<int16_t, IdxT>(src, inds, out, axes, size);
      break;
    case int32:
      gather<int32_t, IdxT>(src, inds, out, axes, size);
      break;
    case int64:
      gather<int64_t, IdxT>(src, inds, out, axes, size);
      break;
    case float16:
      gather<float16_t, IdxT>(src, inds, out, axes, size);
      break;
    case float32:
      gather<float, IdxT>(src, inds, out, axes, size);
      break;
    case float64:
      gather<double, IdxT>(src, inds, out, axes, size);
      break;
    case bfloat16:
      gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
      break;
    case complex64:
      gather<complex64_t, IdxT>(src, inds, out, axes, size);
      break;
  }
}

void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));

  auto& src = inputs[0];
  std::vector<array> inds;
  for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
    inds.push_back(array::unsafe_weak_copy(*it));
  }
  auto& encoder = cpu::get_command_encoder(stream());
  for (auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.dispatch([axes_ = axes_,
                    slice_sizes_ = slice_sizes_,
                    src = array::unsafe_weak_copy(src),
                    inds = std::move(inds),
                    out = array::unsafe_weak_copy(out)]() mutable {
    if (inds.empty()) {
      // With no index arrays the index type is irrelevant
      dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
      return;
    }
    switch (inds[0].dtype()) {
      case uint8:
        dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case uint16:
        dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case uint32:
        dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case uint64:
        dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case int8:
        dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case int16:
        dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case int32:
        dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
        break;
      case int64:
        dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
        break;
      default:
        throw std::runtime_error(
            "[Gather::eval_cpu] Cannot gather with indices type.");
        break;
    }
  });
}
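// Usage sketch (hypothetical, assuming the mlx::core ops API with reshape,
// arange, and take; shapes and values are illustrative only): taking rows of
// a (4, 3) array lowers to a Gather with axes = {0} and
// slice_sizes = {1, 3}, e.g.
//
//   auto src = reshape(arange(12.0f), {4, 3});
//   auto idx = array({2, -1});     // -1 wraps to row 3 via offset_neg_idx
//   auto rows = take(src, idx, 0); // shape (2, 3): rows 2 and 3 of src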
template <typename T, typename IdxT>
void gather_axis(
    const array& src,
    const array& ind,
    array& out,
    const int axis) {
  auto strides = ind.strides();
  strides.erase(strides.begin() + axis);
  auto shape = ind.shape();
  shape.erase(shape.begin() + axis);
  ContiguousIterator ind_it(shape, strides, src.ndim() - 1);

  strides = src.strides();
  strides.erase(strides.begin() + axis);
  ContiguousIterator src_it(shape, strides, src.ndim() - 1);

  auto ind_ptr = ind.data<IdxT>();
  auto src_ptr = src.data<T>();
  auto dst_ptr = out.data<T>();
  auto ind_ax_stride = ind.strides(axis);
  auto src_ax_stride = src.strides(axis);
  auto dst_ax_stride = out.strides(axis);
  auto ind_ax_size = ind.shape(axis);
  auto src_ax_size = src.shape(axis);

  size_t size_pre = 1;
  size_t size_post = 1;
  for (int i = 0; i < axis; ++i) {
    size_pre *= ind.shape(i);
  }
  for (int i = axis + 1; i < ind.ndim(); ++i) {
    size_post *= ind.shape(i);
  }
  size_t stride_pre = size_post * ind_ax_size;
  for (size_t i = 0; i < size_pre; i++) {
    for (size_t k = 0; k < size_post; k++) {
      for (int j = 0; j < ind_ax_size; ++j) {
        auto ind_val = offset_neg_idx(
            ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
        dst_ptr[k + j * dst_ax_stride] =
            src_ptr[src_it.loc + ind_val * src_ax_stride];
      }
      ind_it.step();
      src_it.step();
    }
    dst_ptr += stride_pre;
  }
}

template <typename IdxT>
void dispatch_gather_axis(
    const array& src,
    const array& inds,
    array& out,
    const int axis) {
  switch (out.dtype()) {
    case bool_:
      gather_axis<bool, IdxT>(src, inds, out, axis);
      break;
    case uint8:
      gather_axis<uint8_t, IdxT>(src, inds, out, axis);
      break;
    case uint16:
      gather_axis<uint16_t, IdxT>(src, inds, out, axis);
      break;
    case uint32:
      gather_axis<uint32_t, IdxT>(src, inds, out, axis);
      break;
    case uint64:
      gather_axis<uint64_t, IdxT>(src, inds, out, axis);
      break;
    case int8:
      gather_axis<int8_t, IdxT>(src, inds, out, axis);
      break;
    case int16:
      gather_axis<int16_t, IdxT>(src, inds, out, axis);
      break;
    case int32:
      gather_axis<int32_t, IdxT>(src, inds, out, axis);
      break;
    case int64:
      gather_axis<int64_t, IdxT>(src, inds, out, axis);
      break;
    case float16:
      gather_axis<float16_t, IdxT>(src, inds, out, axis);
      break;
    case float32:
      gather_axis<float, IdxT>(src, inds, out, axis);
      break;
    case float64:
      gather_axis<double, IdxT>(src, inds, out, axis);
      break;
    case bfloat16:
      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
      break;
    case complex64:
      gather_axis<complex64_t, IdxT>(src, inds, out, axis);
      break;
  }
}

void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));

  auto& src = inputs[0];
  auto& inds = inputs[1];
  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(src);
  encoder.set_input_array(inds);
  encoder.set_output_array(out);
  encoder.dispatch([axis_ = axis_,
                    src = array::unsafe_weak_copy(src),
                    inds = array::unsafe_weak_copy(inds),
                    out = array::unsafe_weak_copy(out)]() mutable {
    switch (inds.dtype()) {
      case uint8:
        dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
        break;
      case uint16:
        dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
        break;
      case uint32:
        dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
        break;
      case uint64:
        dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
        break;
      case int8:
        dispatch_gather_axis<int8_t>(src, inds, out, axis_);
        break;
      case int16:
        dispatch_gather_axis<int16_t>(src, inds, out, axis_);
        break;
      case int32:
        dispatch_gather_axis<int32_t>(src, inds, out, axis_);
        break;
      case int64:
        dispatch_gather_axis<int64_t>(src, inds, out, axis_);
        break;
      default:
        throw std::runtime_error(
            "[GatherAxis::eval_cpu] Cannot gather with indices type.");
        break;
    }
  });
}
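// Semantics note (added for clarity): gather_axis implements
// take_along_axis style indexing. For axis == 0 and 2-D inputs it computes
//
//   out[j, k] = src[ind[j, k], k]
//
// with j running over ind.shape(0) and k over the remaining coordinates
// that ind_it and src_it step through jointly.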
template <typename InT, typename IdxT, typename OpT>
void scatter(
    const array& updates,
    array& out,
    const std::vector<array>& inds,
    const std::vector<int>& axes) {
  int nind = inds.size();
  auto inds_ndim = updates.ndim() - out.ndim();
  size_t n_updates = nind ? inds[0].size() : 1;

  Shape update_shape(
      updates.shape().begin() + inds_ndim, updates.shape().end());
  size_t update_size = 1;
  for (auto us : update_shape) {
    update_size *= us;
  }

  std::vector<ContiguousIterator> its(inds.begin(), inds.end());
  ContiguousIterator update_it(updates);
  ContiguousIterator out_it(update_shape, out.strides(), out.ndim());

  auto out_ptr = out.data<InT>();
  auto upd_ptr = updates.data<InT>();
  for (int i = 0; i < n_updates; ++i) {
    // Compute the destination offset from all index arrays, then apply
    // OpT element-wise over the update slice
    size_t out_offset = 0;
    for (int j = 0; j < inds.size(); ++j) {
      auto ax = axes[j];
      auto idx_loc = its[j].loc;
      its[j].step();
      auto idx_val =
          offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
      out_offset += (idx_val * out.strides()[ax]);
    }
    update_it.seek(i * update_size);
    for (int j = 0; j < update_size; ++j) {
      OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
      update_it.step();
      out_it.step();
    }
    out_it.reset();
    update_it.reset();
  }
}

template <typename InT, typename IdxT>
void dispatch_scatter_inds(
    array& out,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    Scatter::ReduceType rtype) {
  switch (rtype) {
    case Scatter::None:
      scatter<InT, IdxT, None>(updates, out, indices, axes);
      break;
    case Scatter::Sum:
      scatter<InT, IdxT, Sum>(updates, out, indices, axes);
      break;
    case Scatter::Prod:
      scatter<InT, IdxT, Prod>(updates, out, indices, axes);
      break;
    case Scatter::Max:
      scatter<InT, IdxT, Max>(updates, out, indices, axes);
      break;
    case Scatter::Min:
      scatter<InT, IdxT, Min>(updates, out, indices, axes);
      break;
  }
}

template <typename InT>
void dispatch_scatter(
    array& out,
    const std::vector<array>& inds,
    const array& updates,
    const std::vector<int>& axes,
    Scatter::ReduceType rtype) {
  if (inds.empty()) {
    // With no index arrays the index type is irrelevant
    dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
    return;
  }
  switch (inds[0].dtype()) {
    case uint8:
      dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
      break;
    case uint16:
      dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
      break;
    case uint32:
      dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
      break;
    case uint64:
      dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
      break;
    case int8:
      dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
      break;
    case int16:
      dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
      break;
    case int32:
      dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
      break;
    case int64:
      dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
      break;
    default:
      throw std::runtime_error(
          "[Scatter::eval_cpu] Cannot scatter with indices type.");
  }
}
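// Reduction note (added for clarity): when several indices address the same
// output location, OpT decides the result. For example, scattering updates
// {1, 2} at indices {0, 0} into a zeroed (3,) output gives out == {3, 0, 0}
// under Scatter::Sum, but out == {2, 0, 0} under Scatter::None, where the
// last write in index order wins.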
void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() >= 2);

  auto& src = inputs[0];
  auto& updates = inputs.back();

  // Copy src into out (copy allocates memory for out)
  auto ctype =
      src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
  copy(src, out, ctype, stream());

  auto& encoder = cpu::get_command_encoder(stream());
  std::vector<array> inds;
  for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
    encoder.set_input_array(*it);
    inds.push_back(array::unsafe_weak_copy(*it));
  }
  encoder.set_input_array(updates);
  encoder.set_output_array(out);
  encoder.dispatch([axes_ = axes_,
                    reduce_type_ = reduce_type_,
                    updates = array::unsafe_weak_copy(updates),
                    inds = std::move(inds),
                    out = array::unsafe_weak_copy(out)]() mutable {
    switch (out.dtype()) {
      case bool_:
        dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
        break;
      case uint8:
        dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case uint16:
        dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case uint32:
        dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case uint64:
        dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case int8:
        dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case int16:
        dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case int32:
        dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case int64:
        dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case float16:
        dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case float32:
        dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
        break;
      case float64:
        dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
        break;
      case bfloat16:
        dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
        break;
      case complex64:
        dispatch_scatter<complex64_t>(
            out, inds, updates, axes_, reduce_type_);
        break;
    }
  });
}

template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array& idx, const array& upd, int axis) {
  auto strides = idx.strides();
  strides.erase(strides.begin() + axis);
  auto shape = idx.shape();
  shape.erase(shape.begin() + axis);
  ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);

  strides = upd.strides();
  strides.erase(strides.begin() + axis);
  ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);

  auto idx_ptr = idx.data<IdxT>();
  auto upd_ptr = upd.data<T>();
  auto dst_ptr = out.data<T>();
  auto idx_ax_stride = idx.strides(axis);
  auto upd_ax_stride = upd.strides(axis);
  auto dst_ax_stride = out.strides(axis);
  auto idx_ax_size = idx.shape(axis);
  auto dst_ax_size = out.shape(axis);

  size_t size_pre = 1;
  size_t size_post = 1;
  for (int i = 0; i < axis; ++i) {
    size_pre *= idx.shape(i);
  }
  for (int i = axis + 1; i < idx.ndim(); ++i) {
    size_post *= idx.shape(i);
  }
  size_t stride_pre = size_post * dst_ax_size;
  for (size_t i = 0; i < size_pre; i++) {
    for (size_t k = 0; k < size_post; k++) {
      for (int j = 0; j < idx_ax_size; ++j) {
        auto ind_val = offset_neg_idx(
            idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
        OpT{}(
            upd_ptr[upd_it.loc + j * upd_ax_stride],
            dst_ptr + k + ind_val * dst_ax_stride);
      }
      idx_it.step();
      upd_it.step();
    }
    dst_ptr += stride_pre;
  }
}
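// Semantics note (added for clarity): scatter_axis is the write-side
// counterpart of gather_axis, in the spirit of put_along_axis. For
// axis == 0 and 2-D inputs it applies
//
//   OpT{}(upd[j, k], &out[idx[j, k], k])
//
// for each j along the axis and each k in the remaining coordinates.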
template <typename T, typename IdxT>
void dispatch_scatter_axis_op(
    array& out,
    const array& idx,
    const array& updates,
    int axis,
    ScatterAxis::ReduceType rtype) {
  switch (rtype) {
    case ScatterAxis::None:
      scatter_axis<T, IdxT, None>(out, idx, updates, axis);
      break;
    case ScatterAxis::Sum:
      scatter_axis<T, IdxT, Sum>(out, idx, updates, axis);
      break;
  }
}

template <typename T>
void dispatch_scatter_axis(
    array& out,
    const array& idx,
    const array& updates,
    int axis,
    ScatterAxis::ReduceType rtype) {
  switch (idx.dtype()) {
    case uint8:
      dispatch_scatter_axis_op<T, uint8_t>(out, idx, updates, axis, rtype);
      break;
    case uint16:
      dispatch_scatter_axis_op<T, uint16_t>(out, idx, updates, axis, rtype);
      break;
    case uint32:
      dispatch_scatter_axis_op<T, uint32_t>(out, idx, updates, axis, rtype);
      break;
    case uint64:
      dispatch_scatter_axis_op<T, uint64_t>(out, idx, updates, axis, rtype);
      break;
    case int8:
      dispatch_scatter_axis_op<T, int8_t>(out, idx, updates, axis, rtype);
      break;
    case int16:
      dispatch_scatter_axis_op<T, int16_t>(out, idx, updates, axis, rtype);
      break;
    case int32:
      dispatch_scatter_axis_op<T, int32_t>(out, idx, updates, axis, rtype);
      break;
    case int64:
      dispatch_scatter_axis_op<T, int64_t>(out, idx, updates, axis, rtype);
      break;
    default:
      throw std::runtime_error(
          "[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
  }
}

void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() >= 2);

  auto& src = inputs[0];
  auto& idx = inputs[1];
  auto& updates = inputs[2];

  // Copy src into out (copy allocates memory for out)
  auto ctype =
      src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
  copy(src, out, ctype, stream());

  auto& encoder = cpu::get_command_encoder(stream());
  encoder.set_input_array(idx);
  encoder.set_input_array(updates);
  encoder.set_output_array(out);
  encoder.dispatch([axis_ = axis_,
                    reduce_type_ = reduce_type_,
                    idx = array::unsafe_weak_copy(idx),
                    updates = array::unsafe_weak_copy(updates),
                    out = array::unsafe_weak_copy(out)]() mutable {
    switch (out.dtype()) {
      case bool_:
        dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
        break;
      case uint8:
        dispatch_scatter_axis<uint8_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case uint16:
        dispatch_scatter_axis<uint16_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case uint32:
        dispatch_scatter_axis<uint32_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case uint64:
        dispatch_scatter_axis<uint64_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case int8:
        dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
        break;
      case int16:
        dispatch_scatter_axis<int16_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case int32:
        dispatch_scatter_axis<int32_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case int64:
        dispatch_scatter_axis<int64_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case float16:
        dispatch_scatter_axis<float16_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case float32:
        dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
        break;
      case float64:
        dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
        break;
      case bfloat16:
        dispatch_scatter_axis<bfloat16_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
      case complex64:
        dispatch_scatter_axis<complex64_t>(
            out, idx, updates, axis_, reduce_type_);
        break;
    }
  });
}

} // namespace mlx::core