Add remove_index utility (#2173)

This commit is contained in:
Cheng
2025-05-14 09:09:56 +09:00
committed by GitHub
parent 3aa9cf3f9e
commit eca2f3eb97
4 changed files with 26 additions and 44 deletions

View File

@@ -257,15 +257,11 @@ void gather_axis(
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto shape = remove_index(ind.shape(), axis);
ContiguousIterator ind_it(
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
ContiguousIterator src_it(
shape, remove_index(src.strides(), axis), src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto shape = remove_index(idx.shape(), axis);
ContiguousIterator idx_it(
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
ContiguousIterator upd_it(
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();