mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Add remove_index utility
This commit is contained in:
parent
659a51919f
commit
17bce1c092
@ -165,4 +165,11 @@ void shared_buffer_reshape(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const Strides& out_strides,
|
const Strides& out_strides,
|
||||||
array& out);
|
array& out);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||||
|
vec.erase(std::next(vec.begin(), index));
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -14,10 +14,8 @@ template <typename InT, typename OpT>
|
|||||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||||
auto axis_size = in.shape()[axis];
|
auto axis_size = in.shape()[axis];
|
||||||
auto axis_stride = in.strides()[axis];
|
auto axis_stride = in.strides()[axis];
|
||||||
Strides strides = in.strides();
|
Strides strides = remove_index(in.strides(), axis);
|
||||||
Shape shape = in.shape();
|
Shape shape = remove_index(in.shape(), axis);
|
||||||
strides.erase(strides.begin() + axis);
|
|
||||||
shape.erase(shape.begin() + axis);
|
|
||||||
auto in_ptr = in.data<InT>();
|
auto in_ptr = in.data<InT>();
|
||||||
auto out_ptr = out.data<uint32_t>();
|
auto out_ptr = out.data<uint32_t>();
|
||||||
|
|
||||||
|
@ -257,15 +257,11 @@ void gather_axis(
|
|||||||
const array& ind,
|
const array& ind,
|
||||||
array& out,
|
array& out,
|
||||||
const int axis) {
|
const int axis) {
|
||||||
auto strides = ind.strides();
|
auto shape = remove_index(ind.shape(), axis);
|
||||||
strides.erase(strides.begin() + axis);
|
ContiguousIterator ind_it(
|
||||||
auto shape = ind.shape();
|
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
|
||||||
shape.erase(shape.begin() + axis);
|
ContiguousIterator src_it(
|
||||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
shape, remove_index(src.strides(), axis), src.ndim() - 1);
|
||||||
|
|
||||||
strides = src.strides();
|
|
||||||
strides.erase(strides.begin() + axis);
|
|
||||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
|
||||||
|
|
||||||
auto ind_ptr = ind.data<IdxT>();
|
auto ind_ptr = ind.data<IdxT>();
|
||||||
auto src_ptr = src.data<T>();
|
auto src_ptr = src.data<T>();
|
||||||
@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
template <typename T, typename IdxT, typename OpT>
|
template <typename T, typename IdxT, typename OpT>
|
||||||
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||||
auto strides = idx.strides();
|
auto shape = remove_index(idx.shape(), axis);
|
||||||
strides.erase(strides.begin() + axis);
|
ContiguousIterator idx_it(
|
||||||
auto shape = idx.shape();
|
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
|
||||||
shape.erase(shape.begin() + axis);
|
ContiguousIterator upd_it(
|
||||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
|
||||||
|
|
||||||
strides = upd.strides();
|
|
||||||
strides.erase(strides.begin() + axis);
|
|
||||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
|
||||||
|
|
||||||
auto idx_ptr = idx.data<IdxT>();
|
auto idx_ptr = idx.data<IdxT>();
|
||||||
auto upd_ptr = upd.data<T>();
|
auto upd_ptr = upd.data<T>();
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/jit/includes.h"
|
#include "mlx/backend/metal/jit/includes.h"
|
||||||
@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder.set_output_array(out, 2);
|
compute_encoder.set_output_array(out, 2);
|
||||||
|
|
||||||
// Set source info
|
// Set source info
|
||||||
auto shape = idx.shape();
|
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||||
shape.erase(shape.begin() + axis_);
|
compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
|
||||||
compute_encoder.set_vector_bytes(shape, 3);
|
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||||
|
|
||||||
auto strides = src.strides();
|
|
||||||
strides.erase(strides.begin() + axis_);
|
|
||||||
compute_encoder.set_vector_bytes(strides, 4);
|
|
||||||
|
|
||||||
strides = idx.strides();
|
|
||||||
strides.erase(strides.begin() + axis_);
|
|
||||||
compute_encoder.set_vector_bytes(strides, 5);
|
|
||||||
compute_encoder.set_bytes(ndim - 1, 6);
|
compute_encoder.set_bytes(ndim - 1, 6);
|
||||||
compute_encoder.set_bytes(axis_, 7);
|
compute_encoder.set_bytes(axis_, 7);
|
||||||
compute_encoder.set_bytes(src.shape(axis_), 8);
|
compute_encoder.set_bytes(src.shape(axis_), 8);
|
||||||
@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder.set_output_array(out, 2);
|
compute_encoder.set_output_array(out, 2);
|
||||||
|
|
||||||
// Set source info
|
// Set source info
|
||||||
auto shape = idx.shape();
|
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||||
shape.erase(shape.begin() + axis_);
|
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
||||||
compute_encoder.set_vector_bytes(shape, 3);
|
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||||
|
|
||||||
auto strides = upd.strides();
|
|
||||||
strides.erase(strides.begin() + axis_);
|
|
||||||
compute_encoder.set_vector_bytes(strides, 4);
|
|
||||||
|
|
||||||
strides = idx.strides();
|
|
||||||
strides.erase(strides.begin() + axis_);
|
|
||||||
compute_encoder.set_vector_bytes(strides, 5);
|
|
||||||
compute_encoder.set_bytes(ndim - 1, 6);
|
compute_encoder.set_bytes(ndim - 1, 6);
|
||||||
compute_encoder.set_bytes(axis_, 7);
|
compute_encoder.set_bytes(axis_, 7);
|
||||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||||
|
Loading…
Reference in New Issue
Block a user