improvements to scatter / gather (#1541)
@@ -9,7 +9,7 @@ from time_utils import measure_runtime

 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)

     idx = []
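
Note (illustration, not part of the diff): dst[*idx] = x relies on
star-unpacking inside a subscript, which only parses on Python 3.11+,
while dst[tuple(idx)] = x is equivalent on every version, since
a[i, j] is sugar for a[(i, j)]. A minimal NumPy sketch:

    import numpy as np

    dst = np.zeros((4, 4))
    idx = [np.array([0, 1]), np.array([2, 3])]
    dst[tuple(idx)] = 7.0  # writes dst[0, 2] and dst[1, 3]
    assert dst[0, 2] == 7.0 and dst[1, 3] == 7.0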
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):


 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()

@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")

@@ -54,7 +54,7 @@ if __name__ == "__main__":
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (2_000_00,),
+        (200_000,),
         (20_000_000,),
         (10000, 64),
         (100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":

     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"X {x_shape}, Indices {idx_shape}")
+        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

@@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
 make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
-make_jit_source(scatter)
-make_jit_source(gather)
+make_jit_source(scatter kernels/indexing.h)
+make_jit_source(gather kernels/indexing.h)
 make_jit_source(hadamard)

 if(MLX_METAL_JIT)

@@ -113,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
   std::vector<size_t> idx_strides;
-
+  std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
         idx_shapes.end(),
         inputs[i + 1].shape().begin(),
         inputs[i + 1].shape().end());
-
     idx_strides.insert(
         idx_strides.end(),
         inputs[i + 1].strides().begin(),
         inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
   }

   // Set all the buffers
@@ -131,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 1);

   // Set source info
-  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
-  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
+  set_vector_bytes(compute_encoder, src.shape(), 2);
+  set_vector_bytes(compute_encoder, src.strides(), 3);
   compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
-  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
-  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
+  set_vector_bytes(compute_encoder, slice_sizes_, 5);
+  set_vector_bytes(compute_encoder, axes_, 6);

   // Set index info
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
+  //
+  // We don't need to check for empty idx_shapes because gather has a
+  // idx_ndim == 0 specialization
+  set_vector_bytes(compute_encoder, idx_shapes, 7);
+  set_vector_bytes(compute_encoder, idx_strides, 8);
+  set_vector_bytes(compute_encoder, idx_contigs, 9);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);

   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   // Copy src into out
-  auto copy_type =
-      inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  CopyType copy_type;
+  if (inputs[0].data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (inputs[0].flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
   copy_gpu(inputs[0], out, copy_type);

+  auto& upd = inputs.back();
+
   // Empty update
-  if (inputs.back().size() == 0) {
+  if (upd.size() == 0) {
     return;
   }

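
Note (illustration, not part of the diff): a Python sketch of the
copy-type ladder above; the row-contiguous branch is the new fast path,
letting a contiguous source take a flat copy instead of the fully
strided General path. The strings stand in for the CopyType enum.

    def choose_copy_type(data_size: int, row_contiguous: bool) -> str:
        if data_size == 1:
            return "Scalar"  # broadcast a single element
        if row_contiguous:
            return "Vector"  # flat, memcpy-like copy (new fast path)
        return "General"     # arbitrary strides

    assert choose_copy_type(1, False) == "Scalar"
    assert choose_copy_type(1024, True) == "Vector"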
@@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);

   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  bool index_nd1_specialization = (idx_ndim == 1);
   size_t idx_size = nidx ? inputs[1].size() : 1;

-  // Bail from fast path (1d index specialization) if scatter dims aren't
-  // the outermost dims and contiguous since update access won't be raster
-  // order.
-  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= (axes_[i] == i);
-  }
-
-  // Bail from fast path (1d index specialization) if any of the dims are
-  // broadcasted, since we can't rely on linear indexing in that case.
-  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  auto idx_to_out = idx_size / out.size();
+  int nwork;
+  if (idx_ndim <= 1 || idx_to_out < 1) {
+    nwork = 1;
+  } else if (idx_to_out <= 4) {
+    nwork = 4;
+  } else if (idx_to_out < 16) {
+    nwork = 8;
+  } else if (idx_to_out < 32) {
+    nwork = 16;
+  } else {
+    nwork = 32;
   }

   std::string lib_name;
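
Note (illustration, not part of the diff): a Python sketch of the
work-batching heuristic above. The integer ratio of index elements to
output elements decides how many updates one thread loops over.

    def choose_nwork(idx_ndim: int, idx_size: int, out_size: int) -> int:
        idx_to_out = idx_size // out_size  # integer division, as in the C++
        if idx_ndim <= 1 or idx_to_out < 1:
            return 1
        if idx_to_out <= 4:
            return 4
        if idx_to_out < 16:
            return 8
        if idx_to_out < 32:
            return 16
        return 32

    # 1M index rows scattered into a 1K-element output: 32 rows per thread.
    assert choose_nwork(2, 1_000_000, 1_000) == 32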
@@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
       op_name = "min";
       break;
   }

+  auto upd_contig = upd.flags().row_contiguous;
   {
     std::ostringstream kname;
-    if (index_nd1_specialization) {
-      kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
-    } else {
-      kname << "scatter" << type_to_name(out) << idx_type_name;
-    }
-    kname << "_" << op_name << "_" << nidx;
+    kname << "scatter" << type_to_name(out) << idx_type_name;
+    kname << "_" << op_name << "_" << nidx << "_"
+          << (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
     lib_name = kname.str();
     kernel_name = kname.str();
   }

   auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils()
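
Note (illustration, not part of the diff): baking upd_contig and nwork
into the name means each configuration JIT-compiles to its own cached
pipeline. A sketch of the scheme (the exact strings type_to_name and
idx_type_name produce are an assumption here):

    def scatter_kernel_name(out_t, idx_t, op, nidx, upd_contig, nwork):
        updc = "updc_true" if upd_contig else "updc_false"
        return f"scatter{out_t}{idx_t}_{op}_{nidx}_{updc}_nwork{nwork}"

    assert (scatter_kernel_name("float32", "uint32", "sum", 1, True, 4)
            == "scatterfloat32uint32_sum_1_updc_true_nwork4")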
@@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
         op_type,
         nidx,
         idx_args,
-        idx_arr);
+        idx_arr,
+        upd_contig,
+        nwork);
     return kernel_source.str();
   });

   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);

-  auto& upd = inputs.back();
   size_t nthreads = upd.size();

   compute_encoder->setComputePipelineState(kernel);
@@ -291,54 +296,27 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);

   // Set update info
-  uint upd_ndim = upd.ndim();
+  size_t upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  if (index_nd1_specialization) {
-    compute_encoder->setBytes(
-        out.shape().data(), out.shape().size() * sizeof(int), 3);
-    compute_encoder->setBytes(
-        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
-
-    size_t out_ndim = out.ndim();
-    compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
-    if (upd_ndim <= 1) {
-      // Placeholder so Metal doesn't compalain
-      int shape_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 6);
-    } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
-    }
-    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
-    compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
-
-    // Set index buffers
-    for (int i = 0; i < nidx; ++i) {
-      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
-    }
-
-    // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-
-  } else {
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
   std::vector<size_t> idx_strides;
+  // To access .data() use char instead of bool
+  // bool is 1 byte in Metal so this is safe
+  std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
         idx_shapes.end(),
         inputs[i + 1].shape().begin(),
         inputs[i + 1].shape().end());
-
     idx_strides.insert(
         idx_strides.end(),
         inputs[i + 1].strides().begin(),
         inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
   }

   if (upd_ndim == 0) {
@@ -348,9 +326,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&shape_, sizeof(int), 3);
     compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
   } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
-      compute_encoder->setBytes(
-          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
+    set_vector_bytes(compute_encoder, upd.shape(), 3);
+    set_vector_bytes(compute_encoder, upd.strides(), 4);
   }
   compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
   compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
@@ -364,9 +341,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&shape_, sizeof(int), 7);
     compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
   } else {
-      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
-      compute_encoder->setBytes(
-          out.strides().data(), out_ndim * sizeof(size_t), 8);
+    set_vector_bytes(compute_encoder, out.shape(), 7);
+    set_vector_bytes(compute_encoder, out.strides(), 8);
   }
   compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
   compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
@@ -377,12 +353,13 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // error in the metal API.
     idx_shapes.push_back(0);
     idx_strides.push_back(0);
+    idx_contigs.push_back(false);
   }
-    compute_encoder->setBytes(
-        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-    compute_encoder->setBytes(
-        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+  set_vector_bytes(compute_encoder, idx_shapes, 11);
+  set_vector_bytes(compute_encoder, idx_strides, 12);
+  set_vector_bytes(compute_encoder, idx_contigs, 13);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
   compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);

   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -390,10 +367,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  auto grid_y = (nthreads / upd_size);
+  grid_y = (grid_y + nwork - 1) / nwork;
+  MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size != 1024) {
+    throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
+  }
+  MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 } // namespace mlx::core

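
Note (illustration, not part of the diff): the dispatch now divides the
index rows among threads with a ceiling division by nwork, shrinking the
grid's y extent accordingly.

    def scatter_grid(nthreads: int, upd_size: int, nwork: int):
        grid_y = nthreads // upd_size           # one row per index element
        grid_y = (grid_y + nwork - 1) // nwork  # batched nwork at a time
        return (upd_size, grid_y, 1)

    # 1M updates of 64 elements each, 8 rows per thread:
    assert scatter_grid(64_000_000, 64, 8) == (64, 125_000, 1)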
@@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
     const constant int* axes [[buffer(6)]],
     const constant int* idx_shapes [[buffer(7)]],
     const constant size_t* idx_strides [[buffer(8)]],
-    const constant int& idx_ndim [[buffer(9)]],
+    const constant bool* idx_contigs [[buffer(9)]],
+    const constant int& idx_ndim [[buffer(10)]],
     {4}
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {{
   Indices<{2}, {3}> idxs{{
-    {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
+    {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};

   return gather_impl<{1}, {2}, {3}, {6}>(
       src,
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
 )";

 constexpr std::string_view scatter_kernels = R"(
-[[kernel]] void scatter_1d_index{0}_{4}(
-    const device {1}* updates [[buffer(1)]],
-    device mlx_atomic<{1}>* out [[buffer(2)]],
-    const constant int* out_shape [[buffer(3)]],
-    const constant size_t* out_strides [[buffer(4)]],
-    const constant size_t& out_ndim [[buffer(5)]],
-    const constant int* upd_shape [[buffer(6)]],
-    const constant size_t& upd_ndim [[buffer(7)]],
-    const constant size_t& upd_size [[buffer(8)]],
-    {5}
-    uint2 gid [[thread_position_in_grid]]) {{
-  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
-  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
-      updates,
-      out,
-      out_shape,
-      out_strides,
-      out_ndim,
-      upd_shape,
-      upd_ndim,
-      upd_size,
-      idx_buffers,
-      gid);
-}}
-
-[[kernel]] void scatter{0}_{4}(
+[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
     const device {1}* updates [[buffer(1)]],
     device mlx_atomic<{1}>* out [[buffer(2)]],
     const constant int* upd_shape [[buffer(3)]],
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
     const constant int* axes [[buffer(10)]],
     const constant int* idx_shapes [[buffer(11)]],
     const constant size_t* idx_strides [[buffer(12)]],
-    const constant int& idx_ndim [[buffer(13)]],
+    const constant bool* idx_contigs [[buffer(13)]],
+    const constant int& idx_ndim [[buffer(14)]],
     const constant size_t& idx_size [[buffer(15)]],
     {5}
     uint2 gid [[thread_position_in_grid]]) {{
-  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
+  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};

-  return scatter_impl<{1}, {2}, {3}, {4}>(
+  return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
       updates,
       out,
       upd_shape,
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
       out_strides,
       out_ndim,
       axes,
+      idx_size,
       idxs,
       gid);
 }}

@@ -25,7 +25,9 @@ METAL_FUNC void gather_impl(
       idx_loc = index.x * indices.strides[indices.ndim * i];
     } else {
       idx_loc = index.x * indices.strides[indices.ndim * i];
-      idx_loc += elem_to_loc(
+      idx_loc += indices.row_contiguous[i]
+          ? index.y
+          : elem_to_loc(
                 index.y,
                 &indices.shapes[indices.ndim * i + 1],
                 &indices.strides[indices.ndim * i + 1],
@@ -9,6 +9,7 @@ struct Indices {
   const array<const device IdxT*, NIDX> buffers;
   const constant int* shapes;
   const constant size_t* strides;
+  const constant bool* row_contiguous;
   const int ndim;
 };

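
Note (illustration, not part of the diff): the point of carrying a
per-array row_contiguous flag. For a row-contiguous array the linear
element number is already the memory offset, so the kernels can skip the
general stride walk; a Python model of elem_to_loc shows the identity:

    def elem_to_loc(elem: int, shape, strides) -> int:
        loc = 0
        for dim in reversed(range(len(shape))):
            loc += (elem % shape[dim]) * strides[dim]
            elem //= shape[dim]
        return loc

    shape, strides = (4, 5), (5, 1)  # row-contiguous 4x5 array
    assert all(elem_to_loc(i, shape, strides) == i for i in range(20))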
@@ -4,57 +4,42 @@

 #include "mlx/backend/metal/kernels/indexing.h"

-template <typename T, typename IdxT, typename Op, int NIDX>
-METAL_FUNC void scatter_1d_index_impl(
-    const device T* updates [[buffer(1)]],
-    device mlx_atomic<T>* out [[buffer(2)]],
-    const constant int* out_shape [[buffer(3)]],
-    const constant size_t* out_strides [[buffer(4)]],
-    const constant size_t& out_ndim [[buffer(5)]],
-    const constant int* upd_shape [[buffer(6)]],
-    const constant size_t& upd_ndim [[buffer(7)]],
-    const constant size_t& upd_size [[buffer(8)]],
-    const thread array<const device IdxT*, NIDX>& idx_buffers,
-    uint2 gid [[thread_position_in_grid]]) {
-  Op op;
-
-  size_t out_idx = 0;
-  for (int i = 0; i < NIDX; i++) {
-    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
-    out_idx += idx_val * out_strides[i];
-  }
-
-  if (upd_ndim > 1) {
-    auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
-    out_idx += out_offset;
-  } else {
-    out_idx += gid.x;
-  }
-
-  op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
-}
-
-template <typename T, typename IdxT, typename Op, int NIDX>
+template <
+    typename T,
+    typename IdxT,
+    typename Op,
+    int NIDX,
+    bool UPD_ROW_CONTIG,
+    int NWORK>
 METAL_FUNC void scatter_impl(
-    const device T* updates [[buffer(1)]],
-    device mlx_atomic<T>* out [[buffer(2)]],
-    const constant int* upd_shape [[buffer(3)]],
-    const constant size_t* upd_strides [[buffer(4)]],
-    const constant size_t& upd_ndim [[buffer(5)]],
-    const constant size_t& upd_size [[buffer(6)]],
-    const constant int* out_shape [[buffer(7)]],
-    const constant size_t* out_strides [[buffer(8)]],
-    const constant size_t& out_ndim [[buffer(9)]],
-    const constant int* axes [[buffer(10)]],
+    const device T* updates,
+    device mlx_atomic<T>* out,
+    const constant int* upd_shape,
+    const constant size_t* upd_strides,
+    const constant size_t& upd_ndim,
+    const constant size_t& upd_size,
+    const constant int* out_shape,
+    const constant size_t* out_strides,
+    const constant size_t& out_ndim,
+    const constant int* axes,
+    const constant size_t& idx_size,
     const thread Indices<IdxT, NIDX>& indices,
     uint2 gid [[thread_position_in_grid]]) {
   Op op;
-  auto ind_idx = gid.y;
-  auto ind_offset = gid.x;
-
-  size_t out_idx = 0;
+  auto ind_idx = gid.y * NWORK;
+  size_t out_offset = 0;
+  if (upd_size > 1) {
+    out_offset =
+        elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
+  }
+
+  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
+    size_t out_idx = out_offset;
     for (int i = 0; i < NIDX; ++i) {
-    auto idx_loc = elem_to_loc(
+      auto idx_loc = indices.row_contiguous[i]
+          ? ind_idx
+          : elem_to_loc(
                 ind_idx,
                 &indices.shapes[indices.ndim * i],
                 &indices.strides[indices.ndim * i],
@@ -63,14 +48,10 @@ METAL_FUNC void scatter_impl(
       auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
       out_idx += idx_val * out_strides[ax];
     }

-  if (upd_size > 1) {
-    auto out_offset = elem_to_loc(
-        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
-    out_idx += out_offset;
-  }
-
-  auto upd_idx =
-      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
+    auto upd_idx = ind_idx * upd_size + gid.x;
+    if constexpr (!UPD_ROW_CONTIG) {
+      upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
+    }
+
     op.atomic_update(out, updates[upd_idx], out_idx);
+  }
 }

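
Note (illustration, not part of the diff): a NumPy reference for the
batched inner loop of scatter_impl, reduced to one index array, a sum
reduction, and scalar updates. Each "thread" (one grid-y slot) now
applies nwork consecutive index rows instead of one.

    import numpy as np

    def scatter_sum_batched(out, updates, idx, nwork):
        grid_y = (len(idx) + nwork - 1) // nwork
        for gid_y in range(grid_y):
            ind_idx = gid_y * nwork
            for _ in range(nwork):
                if ind_idx >= len(idx):
                    break
                out[idx[ind_idx]] += updates[ind_idx]
                ind_idx += 1

    out = np.zeros(4)
    scatter_sum_batched(out, np.ones(6), np.array([0, 1, 1, 2, 3, 3]), nwork=4)
    assert out.tolist() == [1.0, 2.0, 1.0, 2.0]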
@@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):


 def _nearest_indices(N, scale, dim, ndims):
-    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
+    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)


 def _linear_indices(N, scale, align_corners, dim, ndims):
@@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
     weight = mx.expand_dims(weight, -1)

     return (
-        (indices_l.astype(mx.int32), 1 - weight),
-        (indices_r.astype(mx.int32), weight),
+        (indices_l.astype(mx.uint32), 1 - weight),
+        (indices_r.astype(mx.uint32), weight),
     )

@@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
     indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)

     return (
-        (indices_l1.astype(mx.int32), weight_l1),
-        (indices_r1.astype(mx.int32), weight_r1),
-        (indices_l2.astype(mx.int32), weight_l2),
-        (indices_r2.astype(mx.int32), weight_r2),
+        (indices_l1.astype(mx.uint32), weight_l1),
+        (indices_r1.astype(mx.uint32), weight_r1),
+        (indices_l2.astype(mx.uint32), weight_l2),
+        (indices_r2.astype(mx.uint32), weight_r2),
     )

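
Note (illustration, not part of the diff): the indices these helpers
build are clipped to [0, N - 1], so they are never negative and fit an
unsigned type; presumably uint32 also lets the gather path skip the
negative-index fixup. A NumPy sketch of why the cast is safe:

    import numpy as np

    N = 8
    idx = np.clip(np.arange(-2, 6), 0, N - 1)  # clipped: never negative
    idx_u = idx.astype(np.uint32)              # no sign to lose
    assert (idx_u == idx).all()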
@@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
             a_mlx = mx.array(a_np)

             if ax == None:
-                idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
+                idx_np = np.random.permutation(a_np.size)
                 values_np = np.random.randint(low=0, high=100, size=(16,))
             else:
                 shape = list(a_np.shape)
                 shape[ax] = 2
-                idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
+                idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
+                idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
+                idx_np = np.broadcast_to(idx_np, shape)
                 values_np = np.random.randint(low=0, high=100, size=shape)

             idx_np.astype(np.int32)
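
Note (illustration, not part of the diff): plausibly why the test now
samples indices without replacement. A scatter assignment to duplicate
destinations has no defined winner under parallel execution, so only
unique indices compare reliably against sequential NumPy:

    import numpy as np

    out = np.zeros(4)
    dup = np.array([1, 1])             # duplicate destinations
    out[dup] = np.array([10.0, 20.0])  # NumPy: last write wins...
    assert out[1] == 20.0              # ...a GPU scatter may pick either

    uniq = np.random.permutation(4)    # unique destinations
    out[uniq] = np.arange(4.0)         # deterministic in any order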