mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Add remove_index utility (#2173)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
@@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = src.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||
compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(src.shape(axis_), 8);
|
||||
@@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
|
||||
compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
|
||||
compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||
|
Reference in New Issue
Block a user