mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-20 16:11:14 +08:00
improvements to scatter / gather (#1541)
This commit is contained in:
parent
960e3f0f05
commit
4f72c66911
@ -9,7 +9,7 @@ from time_utils import measure_runtime
|
||||
|
||||
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
||||
def scatter(dst, x, idx):
|
||||
dst[*idx] = x
|
||||
dst[tuple(idx)] = x
|
||||
mx.eval(dst)
|
||||
|
||||
idx = []
|
||||
@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
||||
|
||||
|
||||
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
||||
def gather(dst, x, idx, device):
|
||||
dst[*idx] = x
|
||||
def scatter(dst, x, idx, device):
|
||||
dst[tuple(idx)] = x
|
||||
if device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
|
||||
@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
||||
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
||||
|
||||
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
|
||||
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
|
||||
print(f"PyTorch: {runtime:.3f}ms")
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ if __name__ == "__main__":
|
||||
(100_000, 64),
|
||||
(1_000_000, 64),
|
||||
(100_000,),
|
||||
(2_000_00,),
|
||||
(200_000,),
|
||||
(20_000_000,),
|
||||
(10000, 64),
|
||||
(100, 64),
|
||||
@ -91,6 +91,6 @@ if __name__ == "__main__":
|
||||
|
||||
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
||||
print("=" * 20)
|
||||
print(f"X {x_shape}, Indices {idx_shape}")
|
||||
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
|
||||
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
|
||||
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
||||
|
@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(scatter kernels/indexing.h)
|
||||
make_jit_source(gather kernels/indexing.h)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if(MLX_METAL_JIT)
|
||||
|
@ -113,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
// Set all the buffers
|
||||
@ -131,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
|
||||
set_vector_bytes(compute_encoder, src.shape(), 2);
|
||||
set_vector_bytes(compute_encoder, src.strides(), 3);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
|
||||
set_vector_bytes(compute_encoder, slice_sizes_, 5);
|
||||
set_vector_bytes(compute_encoder, axes_, 6);
|
||||
|
||||
// Set index info
|
||||
//
|
||||
// We don't need to check for empty idx_shapes because gather has a
|
||||
// idx_ndim == 0 specialization
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
|
||||
set_vector_bytes(compute_encoder, idx_shapes, 7);
|
||||
set_vector_bytes(compute_encoder, idx_strides, 8);
|
||||
set_vector_bytes(compute_encoder, idx_contigs, 9);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Copy src into out
|
||||
auto copy_type =
|
||||
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
CopyType copy_type;
|
||||
if (inputs[0].data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (inputs[0].flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(inputs[0], out, copy_type);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
|
||||
// Empty update
|
||||
if (inputs.back().size() == 0) {
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
bool index_nd1_specialization = (idx_ndim == 1);
|
||||
size_t idx_size = nidx ? inputs[1].size() : 1;
|
||||
|
||||
// Bail from fast path (1d index specialization) if scatter dims aren't
|
||||
// the outermost dims and contiguous since update access won't be raster
|
||||
// order.
|
||||
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= (axes_[i] == i);
|
||||
}
|
||||
|
||||
// Bail from fast path (1d index specialization) if any of the dims are
|
||||
// broadcasted, since we can't rely on linear indexing in that case.
|
||||
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
|
||||
index_nd1_specialization &= inputs[i].flags().row_contiguous;
|
||||
auto idx_to_out = idx_size / out.size();
|
||||
int nwork;
|
||||
if (idx_ndim <= 1 || idx_to_out < 1) {
|
||||
nwork = 1;
|
||||
} else if (idx_to_out <= 4) {
|
||||
nwork = 4;
|
||||
} else if (idx_to_out < 16) {
|
||||
nwork = 8;
|
||||
} else if (idx_to_out < 32) {
|
||||
nwork = 16;
|
||||
} else {
|
||||
nwork = 32;
|
||||
}
|
||||
|
||||
std::string lib_name;
|
||||
@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_name = "min";
|
||||
break;
|
||||
}
|
||||
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
if (index_nd1_specialization) {
|
||||
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
|
||||
} else {
|
||||
kname << "scatter" << type_to_name(out) << idx_type_name;
|
||||
}
|
||||
kname << "_" << op_name << "_" << nidx;
|
||||
kname << "_" << op_name << "_" << nidx << "_"
|
||||
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
|
||||
lib_name = kname.str();
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
op_type,
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr);
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork);
|
||||
return kernel_source.str();
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
size_t nthreads = upd.size();
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@ -291,54 +296,27 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
uint upd_ndim = upd.ndim();
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
if (index_nd1_specialization) {
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
// To access .data() use char instead of bool
|
||||
// bool is 1 byte in Metal so this is safe
|
||||
std::vector<char> idx_contigs;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end());
|
||||
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
|
||||
}
|
||||
|
||||
if (upd_ndim == 0) {
|
||||
@ -348,9 +326,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
|
||||
set_vector_bytes(compute_encoder, upd.shape(), 3);
|
||||
set_vector_bytes(compute_encoder, upd.strides(), 4);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
@ -364,9 +341,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 7);
|
||||
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
|
||||
} else {
|
||||
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out_ndim * sizeof(size_t), 8);
|
||||
set_vector_bytes(compute_encoder, out.shape(), 7);
|
||||
set_vector_bytes(compute_encoder, out.strides(), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
@ -377,12 +353,13 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
idx_contigs.push_back(false);
|
||||
}
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
|
||||
set_vector_bytes(compute_encoder, idx_shapes, 11);
|
||||
set_vector_bytes(compute_encoder, idx_strides, 12);
|
||||
set_vector_bytes(compute_encoder, idx_contigs, 13);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
|
||||
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
@ -390,10 +367,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
auto grid_y = (nthreads / upd_size);
|
||||
grid_y = (grid_y + nwork - 1) / nwork;
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
|
||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
const constant bool* idx_contigs [[buffer(9)]],
|
||||
const constant int& idx_ndim [[buffer(10)]],
|
||||
{4}
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
src,
|
||||
@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
upd_shape,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
idx_buffers,
|
||||
gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
const constant bool* idx_contigs [[buffer(13)]],
|
||||
const constant int& idx_ndim [[buffer(14)]],
|
||||
const constant size_t& idx_size [[buffer(15)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idx_size,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
|
@ -25,7 +25,9 @@ METAL_FUNC void gather_impl(
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = index.x * indices.strides[indices.ndim * i];
|
||||
idx_loc += elem_to_loc(
|
||||
idx_loc += indices.row_contiguous[i]
|
||||
? index.y
|
||||
: elem_to_loc(
|
||||
index.y,
|
||||
&indices.shapes[indices.ndim * i + 1],
|
||||
&indices.strides[indices.ndim * i + 1],
|
||||
|
@ -9,6 +9,7 @@ struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant size_t* strides;
|
||||
const constant bool* row_contiguous;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
|
@ -4,57 +4,42 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_1d_index_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
if (upd_ndim > 1) {
|
||||
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
} else {
|
||||
out_idx += gid.x;
|
||||
}
|
||||
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
bool UPD_ROW_CONTIG,
|
||||
int NWORK>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T* updates [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const device T* updates,
|
||||
device mlx_atomic<T>* out,
|
||||
const constant int* upd_shape,
|
||||
const constant size_t* upd_strides,
|
||||
const constant size_t& upd_ndim,
|
||||
const constant size_t& upd_size,
|
||||
const constant int* out_shape,
|
||||
const constant size_t* out_strides,
|
||||
const constant size_t& out_ndim,
|
||||
const constant int* axes,
|
||||
const constant size_t& idx_size,
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
auto ind_idx = gid.y * NWORK;
|
||||
size_t out_offset = 0;
|
||||
if (upd_size > 1) {
|
||||
out_offset =
|
||||
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
}
|
||||
|
||||
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
|
||||
size_t out_idx = out_offset;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
auto idx_loc = indices.row_contiguous[i]
|
||||
? ind_idx
|
||||
: elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
@ -63,14 +48,10 @@ METAL_FUNC void scatter_impl(
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
if (upd_size > 1) {
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
out_idx += out_offset;
|
||||
auto upd_idx = ind_idx * upd_size + gid.x;
|
||||
if constexpr (!UPD_ROW_CONTIG) {
|
||||
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
|
||||
}
|
||||
|
||||
auto upd_idx =
|
||||
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx);
|
||||
}
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
||||
|
||||
|
||||
def _nearest_indices(N, scale, dim, ndims):
|
||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
|
||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
|
||||
|
||||
|
||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
weight = mx.expand_dims(weight, -1)
|
||||
|
||||
return (
|
||||
(indices_l.astype(mx.int32), 1 - weight),
|
||||
(indices_r.astype(mx.int32), weight),
|
||||
(indices_l.astype(mx.uint32), 1 - weight),
|
||||
(indices_r.astype(mx.uint32), weight),
|
||||
)
|
||||
|
||||
|
||||
@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
|
||||
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
|
||||
|
||||
return (
|
||||
(indices_l1.astype(mx.int32), weight_l1),
|
||||
(indices_r1.astype(mx.int32), weight_r1),
|
||||
(indices_l2.astype(mx.int32), weight_l2),
|
||||
(indices_r2.astype(mx.int32), weight_r2),
|
||||
(indices_l1.astype(mx.uint32), weight_l1),
|
||||
(indices_r1.astype(mx.uint32), weight_r1),
|
||||
(indices_l2.astype(mx.uint32), weight_l2),
|
||||
(indices_r2.astype(mx.uint32), weight_r2),
|
||||
)
|
||||
|
||||
|
||||
|
@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_mlx = mx.array(a_np)
|
||||
|
||||
if ax == None:
|
||||
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
|
||||
idx_np = np.random.permutation(a_np.size)
|
||||
values_np = np.random.randint(low=0, high=100, size=(16,))
|
||||
else:
|
||||
shape = list(a_np.shape)
|
||||
shape[ax] = 2
|
||||
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
|
||||
idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
|
||||
idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
|
||||
idx_np = np.broadcast_to(idx_np, shape)
|
||||
values_np = np.random.randint(low=0, high=100, size=shape)
|
||||
|
||||
idx_np.astype(np.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user