mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Add remove_index utility
This commit is contained in:
parent
659a51919f
commit
17bce1c092
@ -165,4 +165,11 @@ void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -14,10 +14,8 @@ template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
|
@ -257,15 +257,11 @@ void gather_axis(
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
strides = src.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
||||
auto shape = remove_index(ind.shape(), axis);
|
||||
ContiguousIterator ind_it(
|
||||
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
|
||||
ContiguousIterator src_it(
|
||||
shape, remove_index(src.strides(), axis), src.ndim() - 1);
|
||||
|
||||
auto ind_ptr = ind.data<IdxT>();
|
||||
auto src_ptr = src.data<T>();
|
||||
@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
||||
auto shape = remove_index(idx.shape(), axis);
|
||||
ContiguousIterator idx_it(
|
||||
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
|
||||
ContiguousIterator upd_it(
|
||||
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
|
||||
|
||||
auto idx_ptr = idx.data<IdxT>();
|
||||
auto upd_ptr = upd.data<T>();
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = src.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||
compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(src.shape(axis_), 8);
|
||||
@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||
|
Loading…
Reference in New Issue
Block a user