improvements to scatter / gather (#1541)

This commit is contained in:
Awni Hannun 2024-10-30 19:30:54 -07:00 committed by GitHub
parent 960e3f0f05
commit 4f72c66911
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 195 additions and 248 deletions

View File

@ -9,7 +9,7 @@ from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
dst[tuple(idx)] = x
mx.eval(dst)
idx = []
@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
if device == torch.device("mps"):
torch.mps.synchronize()
@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
@ -54,7 +54,7 @@ if __name__ == "__main__":
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(200_000,),
(20_000_000,),
(10000, 64),
(100, 64),
@ -91,6 +91,6 @@ if __name__ == "__main__":
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(hadamard)
if(MLX_METAL_JIT)

View File

@ -113,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}
// Set all the buffers
@ -131,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
set_vector_bytes(compute_encoder, src.shape(), 2);
set_vector_bytes(compute_encoder, src.strides(), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
set_vector_bytes(compute_encoder, slice_sizes_, 5);
set_vector_bytes(compute_encoder, axes_, 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, idx_shapes, 7);
set_vector_bytes(compute_encoder, idx_strides, 8);
set_vector_bytes(compute_encoder, idx_contigs, 9);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Copy src into out
auto copy_type =
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
CopyType copy_type;
if (inputs[0].data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (inputs[0].flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(inputs[0], out, copy_type);
auto& upd = inputs.back();
// Empty update
if (inputs.back().size() == 0) {
if (upd.size() == 0) {
return;
}
@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
size_t idx_size = nidx ? inputs[1].size() : 1;
// Bail from fast path (1d index specialization) if scatter dims aren't
// the outermost dims and contiguous since update access won't be raster
// order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
}
// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
auto idx_to_out = idx_size / out.size();
int nwork;
if (idx_ndim <= 1 || idx_to_out < 1) {
nwork = 1;
} else if (idx_to_out <= 4) {
nwork = 4;
} else if (idx_to_out < 16) {
nwork = 8;
} else if (idx_to_out < 32) {
nwork = 16;
} else {
nwork = 32;
}
std::string lib_name;
@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_name = "min";
break;
}
auto upd_contig = upd.flags().row_contiguous;
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
kname << "scatter" << type_to_name(out) << idx_type_name;
kname << "_" << op_name << "_" << nidx << "_"
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_type,
nidx,
idx_args,
idx_arr);
idx_arr,
upd_contig,
nwork);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
auto& upd = inputs.back();
size_t nthreads = upd.size();
compute_encoder->setComputePipelineState(kernel);
@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set update info
uint upd_ndim = upd.ndim();
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't compalain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
set_vector_bytes(compute_encoder, upd.shape(), 3);
set_vector_bytes(compute_encoder, upd.strides(), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
set_vector_bytes(compute_encoder, out.shape(), 7);
set_vector_bytes(compute_encoder, out.strides(), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
idx_contigs.push_back(false);
}
set_vector_bytes(compute_encoder, idx_shapes, 11);
set_vector_bytes(compute_encoder, idx_strides, 12);
set_vector_bytes(compute_encoder, idx_contigs, 13);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
auto grid_y = (nthreads / upd_size);
grid_y = (grid_y + nwork - 1) / nwork;
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
}
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
{4}
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
}}
[[kernel]] void scatter{0}_{4}(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
updates,
out,
upd_shape,
@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
out_strides,
out_ndim,
axes,
idx_size,
idxs,
gid);
}}

View File

@ -25,11 +25,13 @@ METAL_FUNC void gather_impl(
idx_loc = index.x * indices.strides[indices.ndim * i];
} else {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);

View File

@ -9,6 +9,7 @@ struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const constant bool* row_contiguous;
const int ndim;
};

View File

@ -4,73 +4,54 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
size_t out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
if (upd_ndim > 1) {
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
out_idx += out_offset;
} else {
out_idx += gid.x;
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
}
template <typename T, typename IdxT, typename Op, int NIDX>
template <
typename T,
typename IdxT,
typename Op,
int NIDX,
bool UPD_ROW_CONTIG,
int NWORK>
METAL_FUNC void scatter_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const device T* updates,
device mlx_atomic<T>* out,
const constant int* upd_shape,
const constant size_t* upd_strides,
const constant size_t& upd_ndim,
const constant size_t& upd_size,
const constant int* out_shape,
const constant size_t* out_strides,
const constant size_t& out_ndim,
const constant int* axes,
const constant size_t& idx_size,
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto ind_idx = gid.y * NWORK;
size_t out_offset = 0;
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
out_offset =
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
size_t out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto upd_idx = ind_idx * upd_size + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}
}

View File

@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
def _linear_indices(N, scale, align_corners, dim, ndims):
@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
weight = mx.expand_dims(weight, -1)
return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
(indices_l.astype(mx.uint32), 1 - weight),
(indices_r.astype(mx.uint32), weight),
)
@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
return (
(indices_l1.astype(mx.int32), weight_l1),
(indices_r1.astype(mx.int32), weight_r1),
(indices_l2.astype(mx.int32), weight_l2),
(indices_r2.astype(mx.int32), weight_r2),
(indices_l1.astype(mx.uint32), weight_l1),
(indices_r1.astype(mx.uint32), weight_r1),
(indices_l2.astype(mx.uint32), weight_l2),
(indices_r2.astype(mx.uint32), weight_r2),
)

View File

@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
a_mlx = mx.array(a_np)
if ax == None:
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
idx_np = np.random.permutation(a_np.size)
values_np = np.random.randint(low=0, high=100, size=(16,))
else:
shape = list(a_np.shape)
shape[ax] = 2
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
idx_np = np.broadcast_to(idx_np, shape)
values_np = np.random.randint(low=0, high=100, size=shape)
idx_np.astype(np.int32)