Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-20 16:11:14 +08:00)

improvements to scatter / gather (#1541)

Parent: 960e3f0f05
Commit: 4f72c66911
@@ -9,7 +9,7 @@ from time_utils import measure_runtime
 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)
 
     idx = []
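The subscript change above is a compatibility fix: unpacking inside a subscript, as in dst[*idx] = x, only parses on Python 3.11+, while dst[tuple(idx)] = x is equivalent on all supported versions. A minimal sketch in plain NumPy (not the benchmark itself):

    import numpy as np

    dst = np.zeros((4, 4))
    idx = [np.array([0, 1]), np.array([2, 3])]
    x = np.ones(2)
    dst[tuple(idx)] = x  # same as dst[idx[0], idx[1]] = x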
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
 
 
 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
 
-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")
 
 
@@ -54,7 +54,7 @@ if __name__ == "__main__":
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (2_000_00,),
+        (200_000,),
         (20_000_000,),
         (10000, 64),
         (100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
 
     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"X {x_shape}, Indices {idx_shape}")
+        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
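For context, measure_runtime comes from the repo's time_utils benchmark helper; a hypothetical stand-in (the real implementation and signature may differ) would look like:

    import time

    def measure_runtime(fn, **kwargs):
        # Warm up, then report mean wall-clock time per call in milliseconds.
        for _ in range(5):
            fn(**kwargs)
        n = 100
        tic = time.perf_counter()
        for _ in range(n):
            fn(**kwargs)
        return (time.perf_counter() - tic) / n * 1e3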
@@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
-make_jit_source(scatter)
-make_jit_source(gather)
+make_jit_source(scatter kernels/indexing.h)
+make_jit_source(gather kernels/indexing.h)
 make_jit_source(hadamard)
 
 if(MLX_METAL_JIT)
@@ -113,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
   std::vector<size_t> idx_strides;
+  std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
         idx_shapes.end(),
         inputs[i + 1].shape().begin(),
         inputs[i + 1].shape().end());
 
     idx_strides.insert(
         idx_strides.end(),
         inputs[i + 1].strides().begin(),
         inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
   }
 
   // Set all the buffers
@@ -131,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 1);
 
   // Set source info
-  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
-  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
+  set_vector_bytes(compute_encoder, src.shape(), 2);
+  set_vector_bytes(compute_encoder, src.strides(), 3);
   compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
-  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
-  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
+  set_vector_bytes(compute_encoder, slice_sizes_, 5);
+  set_vector_bytes(compute_encoder, axes_, 6);
 
   // Set index info
   //
   // We don't need to check for empty idx_shapes because gather has a
   // idx_ndim == 0 specialization
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
+  set_vector_bytes(compute_encoder, idx_shapes, 7);
+  set_vector_bytes(compute_encoder, idx_strides, 8);
+  set_vector_bytes(compute_encoder, idx_contigs, 9);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
 
   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Copy src into out
-  auto copy_type =
-      inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  CopyType copy_type;
+  if (inputs[0].data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (inputs[0].flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
   copy_gpu(inputs[0], out, copy_type);
 
+  auto& upd = inputs.back();
+
   // Empty update
-  if (inputs.back().size() == 0) {
+  if (upd.size() == 0) {
     return;
   }
 
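A small sketch of the copy-type selection above (Python, names illustrative): row-contiguous sources can now take the cheaper vectorized copy path instead of the general strided one.

    def choose_copy_type(data_size, row_contiguous):
        if data_size == 1:
            return "Scalar"   # broadcast a single element
        if row_contiguous:
            return "Vector"   # flat, memcpy-like copy
        return "General"      # strided copy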
@@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  bool index_nd1_specialization = (idx_ndim == 1);
+  size_t idx_size = nidx ? inputs[1].size() : 1;
 
-  // Bail from fast path (1d index specialization) if scatter dims aren't
-  // the outermost dims and contiguous since update access won't be raster
-  // order.
-  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= (axes_[i] == i);
-  }
-
-  // Bail from fast path (1d index specialization) if any of the dims are
-  // broadcasted, since we can't rely on linear indexing in that case.
-  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  auto idx_to_out = idx_size / out.size();
+  int nwork;
+  if (idx_ndim <= 1 || idx_to_out < 1) {
+    nwork = 1;
+  } else if (idx_to_out <= 4) {
+    nwork = 4;
+  } else if (idx_to_out < 16) {
+    nwork = 8;
+  } else if (idx_to_out < 32) {
+    nwork = 16;
+  } else {
+    nwork = 32;
   }
 
   std::string lib_name;
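The nwork selection above trades thread count for serial work when there are many index rows per output element; a Python rendering of the same bucketing (illustrative only):

    def pick_nwork(idx_ndim, idx_size, out_size):
        # More index elements per output element -> more scatter updates
        # handled serially by each thread.
        idx_to_out = idx_size // out_size
        if idx_ndim <= 1 or idx_to_out < 1:
            return 1
        if idx_to_out <= 4:
            return 4
        if idx_to_out < 16:
            return 8
        if idx_to_out < 32:
            return 16
        return 32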
@@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
       op_name = "min";
       break;
   }
+  auto upd_contig = upd.flags().row_contiguous;
   {
     std::ostringstream kname;
-    if (index_nd1_specialization) {
-      kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
-    } else {
-      kname << "scatter" << type_to_name(out) << idx_type_name;
-    }
-    kname << "_" << op_name << "_" << nidx;
+    kname << "scatter" << type_to_name(out) << idx_type_name;
+    kname << "_" << op_name << "_" << nidx << "_"
+          << (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
     lib_name = kname.str();
     kernel_name = kname.str();
   }
 
   auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils()
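Since upd_contig and nwork are now baked into the kernel name, each (contiguity, work-per-thread) combination JIT-compiles to its own specialized kernel. A sketch of the resulting name mangling (Python, names illustrative):

    def scatter_kernel_name(out_type, idx_type, op, nidx, upd_contig, nwork):
        # e.g. "scatterfloat32int32_sum_1_updc_true_nwork4"
        contig = "true" if upd_contig else "false"
        return f"scatter{out_type}{idx_type}_{op}_{nidx}_updc_{contig}_nwork{nwork}"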
@@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
         op_type,
         nidx,
         idx_args,
-        idx_arr);
+        idx_arr,
+        upd_contig,
+        nwork);
     return kernel_source.str();
   });
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
 
-  auto& upd = inputs.back();
   size_t nthreads = upd.size();
 
   compute_encoder->setComputePipelineState(kernel);
@@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);
 
   // Set update info
-  uint upd_ndim = upd.ndim();
+  size_t upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  if (index_nd1_specialization) {
-    compute_encoder->setBytes(
-        out.shape().data(), out.shape().size() * sizeof(int), 3);
-    compute_encoder->setBytes(
-        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
-
-    size_t out_ndim = out.ndim();
-    compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
-    if (upd_ndim <= 1) {
-      // Placeholder so Metal doesn't compalain
-      int shape_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 6);
-    } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
-    }
-    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
-    compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
-
-    // Set index buffers
-    for (int i = 0; i < nidx; ++i) {
-      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
-    }
-
-    // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-
-  } else {
-    // Collect all idx shapes and strides into one place
-    std::vector<int> idx_shapes;
-    std::vector<size_t> idx_strides;
-
-    for (int i = 0; i < nidx; ++i) {
-      idx_shapes.insert(
-          idx_shapes.end(),
-          inputs[i + 1].shape().begin(),
-          inputs[i + 1].shape().end());
-
-      idx_strides.insert(
-          idx_strides.end(),
-          inputs[i + 1].strides().begin(),
-          inputs[i + 1].strides().end());
-    }
-
-    if (upd_ndim == 0) {
-      // Need placeholders so Metal doesn't compalain
-      int shape_ = 0;
-      size_t stride_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 3);
-      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
-    } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
-      compute_encoder->setBytes(
-          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
-    }
-    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
-
-    // Set output info
-    size_t out_ndim = out.ndim();
-    if (out_ndim == 0) {
-      // Need placeholders so Metal doesn't compalain
-      int shape_ = 0;
-      size_t stride_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 7);
-      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
-    } else {
-      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
-      compute_encoder->setBytes(
-          out.strides().data(), out_ndim * sizeof(size_t), 8);
-    }
-    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
-
-    // Set index info
-    if (idx_ndim == 0) {
-      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
-      // error in the metal API.
-      idx_shapes.push_back(0);
-      idx_strides.push_back(0);
-    }
-    compute_encoder->setBytes(
-        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-    compute_encoder->setBytes(
-        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
-
-    // Set index buffers
-    for (int i = 0; i < nidx; ++i) {
-      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
-    }
-
-    // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
+  // Collect all idx shapes and strides into one place
+  std::vector<int> idx_shapes;
+  std::vector<size_t> idx_strides;
+  // To access .data() use char instead of bool
+  // bool is 1 byte in Metal so this is safe
+  std::vector<char> idx_contigs;
+  for (int i = 0; i < nidx; ++i) {
+    idx_shapes.insert(
+        idx_shapes.end(),
+        inputs[i + 1].shape().begin(),
+        inputs[i + 1].shape().end());
+    idx_strides.insert(
+        idx_strides.end(),
+        inputs[i + 1].strides().begin(),
+        inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
+  }
+
+  if (upd_ndim == 0) {
+    // Need placeholders so Metal doesn't compalain
+    int shape_ = 0;
+    size_t stride_ = 0;
+    compute_encoder->setBytes(&shape_, sizeof(int), 3);
+    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
+  } else {
+    set_vector_bytes(compute_encoder, upd.shape(), 3);
+    set_vector_bytes(compute_encoder, upd.strides(), 4);
+  }
+  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
+  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+
+  // Set output info
+  size_t out_ndim = out.ndim();
+  if (out_ndim == 0) {
+    // Need placeholders so Metal doesn't compalain
+    int shape_ = 0;
+    size_t stride_ = 0;
+    compute_encoder->setBytes(&shape_, sizeof(int), 7);
+    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+  } else {
+    set_vector_bytes(compute_encoder, out.shape(), 7);
+    set_vector_bytes(compute_encoder, out.strides(), 8);
+  }
+  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
+  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+
+  // Set index info
+  if (idx_ndim == 0) {
+    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+    // error in the metal API.
+    idx_shapes.push_back(0);
+    idx_strides.push_back(0);
+    idx_contigs.push_back(false);
+  }
+  set_vector_bytes(compute_encoder, idx_shapes, 11);
+  set_vector_bytes(compute_encoder, idx_strides, 12);
+  set_vector_bytes(compute_encoder, idx_contigs, 13);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
+  compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
+
+  // Set index buffers
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
+  }
+
+  // Launch grid
+  auto grid_y = (nthreads / upd_size);
+  grid_y = (grid_y + nwork - 1) / nwork;
+  MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size != 1024) {
+    throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
+  }
+  MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core
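The dispatch math above shrinks the grid's y extent by the work-per-thread factor; in Python terms (illustrative names):

    def scatter_grid(nthreads, upd_size, nwork):
        # Each thread now handles up to `nwork` index rows, so the y extent
        # of the dispatch shrinks by a factor of nwork, rounded up.
        grid_y = nthreads // upd_size
        grid_y = (grid_y + nwork - 1) // nwork
        return (upd_size, grid_y, 1)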
@@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
     const constant int* axes [[buffer(6)]],
     const constant int* idx_shapes [[buffer(7)]],
    const constant size_t* idx_strides [[buffer(8)]],
-    const constant int& idx_ndim [[buffer(9)]],
+    const constant bool* idx_contigs [[buffer(9)]],
+    const constant int& idx_ndim [[buffer(10)]],
     {4}
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {{
   Indices<{2}, {3}> idxs{{
-      {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
+      {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
 
   return gather_impl<{1}, {2}, {3}, {6}>(
       src,
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
 )";
 
 constexpr std::string_view scatter_kernels = R"(
-[[kernel]] void scatter_1d_index{0}_{4}(
-    const device {1}* updates [[buffer(1)]],
-    device mlx_atomic<{1}>* out [[buffer(2)]],
-    const constant int* out_shape [[buffer(3)]],
-    const constant size_t* out_strides [[buffer(4)]],
-    const constant size_t& out_ndim [[buffer(5)]],
-    const constant int* upd_shape [[buffer(6)]],
-    const constant size_t& upd_ndim [[buffer(7)]],
-    const constant size_t& upd_size [[buffer(8)]],
-    {5}
-    uint2 gid [[thread_position_in_grid]]) {{
-  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
-  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
-      updates,
-      out,
-      out_shape,
-      out_strides,
-      out_ndim,
-      upd_shape,
-      upd_ndim,
-      upd_size,
-      idx_buffers,
-      gid);
-}}
-
-[[kernel]] void scatter{0}_{4}(
+[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
     const device {1}* updates [[buffer(1)]],
     device mlx_atomic<{1}>* out [[buffer(2)]],
     const constant int* upd_shape [[buffer(3)]],
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
     const constant int* axes [[buffer(10)]],
     const constant int* idx_shapes [[buffer(11)]],
     const constant size_t* idx_strides [[buffer(12)]],
-    const constant int& idx_ndim [[buffer(13)]],
+    const constant bool* idx_contigs [[buffer(13)]],
+    const constant int& idx_ndim [[buffer(14)]],
+    const constant size_t& idx_size [[buffer(15)]],
     {5}
     uint2 gid [[thread_position_in_grid]]) {{
-  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
+  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
 
-  return scatter_impl<{1}, {2}, {3}, {4}>(
+  return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
       updates,
       out,
       upd_shape,
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
       out_strides,
       out_ndim,
       axes,
+      idx_size,
       idxs,
       gid);
 }}
@@ -25,11 +25,13 @@ METAL_FUNC void gather_impl(
       idx_loc = index.x * indices.strides[indices.ndim * i];
     } else {
       idx_loc = index.x * indices.strides[indices.ndim * i];
-      idx_loc += elem_to_loc(
-          index.y,
-          &indices.shapes[indices.ndim * i + 1],
-          &indices.strides[indices.ndim * i + 1],
-          indices.ndim - 1);
+      idx_loc += indices.row_contiguous[i]
+          ? index.y
+          : elem_to_loc(
+                index.y,
+                &indices.shapes[indices.ndim * i + 1],
+                &indices.strides[indices.ndim * i + 1],
+                indices.ndim - 1);
     }
     auto ax = axes[i];
     auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
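The row_contiguous check above lets the kernel use the linear element index directly instead of the general shape/stride walk; in Python terms (a sketch of elem_to_loc, not the Metal source):

    def idx_location(i, row_contiguous, shape, strides):
        # Row-contiguous arrays can be addressed linearly; otherwise map the
        # linear element index through shape/strides (elem_to_loc).
        if row_contiguous:
            return i
        loc = 0
        for dim in reversed(range(len(shape))):
            loc += (i % shape[dim]) * strides[dim]
            i //= shape[dim]
        return loc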
@@ -9,6 +9,7 @@ struct Indices {
   const array<const device IdxT*, NIDX> buffers;
   const constant int* shapes;
   const constant size_t* strides;
+  const constant bool* row_contiguous;
   const int ndim;
 };
 
@@ -4,73 +4,54 @@
 
 #include "mlx/backend/metal/kernels/indexing.h"
 
-template <typename T, typename IdxT, typename Op, int NIDX>
-METAL_FUNC void scatter_1d_index_impl(
-    const device T* updates [[buffer(1)]],
-    device mlx_atomic<T>* out [[buffer(2)]],
-    const constant int* out_shape [[buffer(3)]],
-    const constant size_t* out_strides [[buffer(4)]],
-    const constant size_t& out_ndim [[buffer(5)]],
-    const constant int* upd_shape [[buffer(6)]],
-    const constant size_t& upd_ndim [[buffer(7)]],
-    const constant size_t& upd_size [[buffer(8)]],
-    const thread array<const device IdxT*, NIDX>& idx_buffers,
-    uint2 gid [[thread_position_in_grid]]) {
-  Op op;
-
-  size_t out_idx = 0;
-  for (int i = 0; i < NIDX; i++) {
-    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
-    out_idx += idx_val * out_strides[i];
-  }
-
-  if (upd_ndim > 1) {
-    auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
-    out_idx += out_offset;
-  } else {
-    out_idx += gid.x;
-  }
-
-  op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
-}
-
-template <typename T, typename IdxT, typename Op, int NIDX>
+template <
+    typename T,
+    typename IdxT,
+    typename Op,
+    int NIDX,
+    bool UPD_ROW_CONTIG,
+    int NWORK>
 METAL_FUNC void scatter_impl(
-    const device T* updates [[buffer(1)]],
-    device mlx_atomic<T>* out [[buffer(2)]],
-    const constant int* upd_shape [[buffer(3)]],
-    const constant size_t* upd_strides [[buffer(4)]],
-    const constant size_t& upd_ndim [[buffer(5)]],
-    const constant size_t& upd_size [[buffer(6)]],
-    const constant int* out_shape [[buffer(7)]],
-    const constant size_t* out_strides [[buffer(8)]],
-    const constant size_t& out_ndim [[buffer(9)]],
-    const constant int* axes [[buffer(10)]],
+    const device T* updates,
+    device mlx_atomic<T>* out,
+    const constant int* upd_shape,
+    const constant size_t* upd_strides,
+    const constant size_t& upd_ndim,
+    const constant size_t& upd_size,
+    const constant int* out_shape,
+    const constant size_t* out_strides,
+    const constant size_t& out_ndim,
+    const constant int* axes,
+    const constant size_t& idx_size,
     const thread Indices<IdxT, NIDX>& indices,
     uint2 gid [[thread_position_in_grid]]) {
   Op op;
-  auto ind_idx = gid.y;
-  auto ind_offset = gid.x;
-
-  size_t out_idx = 0;
-  for (int i = 0; i < NIDX; ++i) {
-    auto idx_loc = elem_to_loc(
-        ind_idx,
-        &indices.shapes[indices.ndim * i],
-        &indices.strides[indices.ndim * i],
-        indices.ndim);
-    auto ax = axes[i];
-    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
-    out_idx += idx_val * out_strides[ax];
-  }
 
+  auto ind_idx = gid.y * NWORK;
+  size_t out_offset = 0;
   if (upd_size > 1) {
-    auto out_offset = elem_to_loc(
-        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
-    out_idx += out_offset;
+    out_offset =
+        elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
   }
 
-  auto upd_idx =
-      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
-  op.atomic_update(out, updates[upd_idx], out_idx);
+  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
+    size_t out_idx = out_offset;
+    for (int i = 0; i < NIDX; ++i) {
+      auto idx_loc = indices.row_contiguous[i]
+          ? ind_idx
+          : elem_to_loc(
+                ind_idx,
+                &indices.shapes[indices.ndim * i],
+                &indices.strides[indices.ndim * i],
+                indices.ndim);
+      auto ax = axes[i];
+      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
+      out_idx += idx_val * out_strides[ax];
+    }
+    auto upd_idx = ind_idx * upd_size + gid.x;
+    if constexpr (!UPD_ROW_CONTIG) {
+      upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
+    }
+    op.atomic_update(out, updates[upd_idx], out_idx);
+  }
 }
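Rewritten this way, each scatter thread walks up to NWORK consecutive index rows instead of one, amortizing the index decoding; roughly, in Python (apply_update stands in for op.atomic_update):

    def scatter_thread(gid_x, gid_y, nwork, idx_size, upd_size,
                       updates, apply_update):
        ind_idx = gid_y * nwork
        for _ in range(nwork):
            if ind_idx >= idx_size:
                break
            # Decode ind_idx into an output location, then apply the update
            apply_update(ind_idx, updates[ind_idx * upd_size + gid_x])
            ind_idx += 1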
@@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
 
 
 def _nearest_indices(N, scale, dim, ndims):
-    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
+    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
 
 
 def _linear_indices(N, scale, align_corners, dim, ndims):
@@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
     weight = mx.expand_dims(weight, -1)
 
     return (
-        (indices_l.astype(mx.int32), 1 - weight),
-        (indices_r.astype(mx.int32), weight),
+        (indices_l.astype(mx.uint32), 1 - weight),
+        (indices_r.astype(mx.uint32), weight),
     )
 
 
@@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
     indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
 
     return (
-        (indices_l1.astype(mx.int32), weight_l1),
-        (indices_r1.astype(mx.int32), weight_r1),
-        (indices_l2.astype(mx.int32), weight_l2),
-        (indices_r2.astype(mx.int32), weight_r2),
+        (indices_l1.astype(mx.uint32), weight_l1),
+        (indices_r1.astype(mx.uint32), weight_r1),
+        (indices_l2.astype(mx.uint32), weight_l2),
+        (indices_r2.astype(mx.uint32), weight_r2),
     )
 
 
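The int32 to uint32 switch is presumably safe because these interpolation indices are clipped to [0, N - 1] and so never negative; unsigned index types let the gather path skip negative-index handling. A small sketch with MLX:

    import mlx.core as mx

    N = 4
    # Clipped, non-negative indices can be cast to uint32
    idx = mx.clip(mx.arange(8) // 2, a_min=0, a_max=N - 1).astype(mx.uint32)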
@@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
         a_mlx = mx.array(a_np)
 
         if ax == None:
-            idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
+            idx_np = np.random.permutation(a_np.size)
             values_np = np.random.randint(low=0, high=100, size=(16,))
         else:
             shape = list(a_np.shape)
             shape[ax] = 2
-            idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
+            idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
+            idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
+            idx_np = np.broadcast_to(idx_np, shape)
             values_np = np.random.randint(low=0, high=100, size=shape)
 
         idx_np.astype(np.int32)
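The test now draws indices without replacement, presumably because scattering into duplicate locations is order-dependent (concurrent atomic updates can land in any order), so unique indices keep the expected result well defined. A minimal illustration:

    import numpy as np

    # With duplicates, the last write "wins" in NumPy, but scatter order on
    # the GPU is unspecified; unique indices avoid the ambiguity.
    idx = np.random.choice(10, replace=False, size=(2,))
    assert len(set(idx.tolist())) == len(idx)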