mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter * Use constant address space for shapes and strides * Split gather and scatter to improve compile times * Enable the GPU tests * Update the CI config * Fix scatter dispatch for scalar indices * Remove arg encoder utils --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -215,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) { | ||||
|   return eit->second; | ||||
| } | ||||
|  | ||||
| MTL::ArgumentEncoder* Device::argument_encoder( | ||||
|     const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const { | ||||
|   // NB array here is already autoreleased but the returned argument | ||||
|   // encoder is owned by the caller and must be released/autoreleased | ||||
|   NS::Array* arg_desc_arr = NS::Array::array( | ||||
|       reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size()); | ||||
|   return device_->newArgumentEncoder(arg_desc_arr); | ||||
| } | ||||
|  | ||||
| void Device::register_library( | ||||
|     const std::string& lib_name, | ||||
|     const std::string& lib_path) { | ||||
|   | ||||
| @@ -51,6 +51,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|  | ||||
|   auto compute_encoder = d.get_command_encoder(s.index); | ||||
|   auto kernel = d.get_kernel(kname.str()); | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|  | ||||
|   size_t slice_size = 1; | ||||
|   for (auto s : slice_sizes_) { | ||||
| @@ -63,91 +64,50 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   auto group_dims = get_block_dims(dim0, dim1, 1); | ||||
|   MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); | ||||
|  | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|   // Collect all idx shapes and strides into one place | ||||
|   std::vector<int> idx_shapes; | ||||
|   std::vector<size_t> idx_strides; | ||||
|  | ||||
|   // Make the argument buffer to store the indices for the | ||||
|   // `Indices` struct in kernels/indexing.metal | ||||
|   std::vector<MTL::ArgumentDescriptor*> arg_descs(4); | ||||
|   arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[0]->setIndex(0); | ||||
|   arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); | ||||
|   arg_descs[0]->setArrayLength(nidx); | ||||
|  | ||||
|   // Shapes | ||||
|   arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); | ||||
|   arg_descs[1]->setIndex(nidx + 1); | ||||
|  | ||||
|   // Strides | ||||
|   arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); | ||||
|   arg_descs[2]->setIndex(nidx + 2); | ||||
|  | ||||
|   // Indices ndim | ||||
|   arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); | ||||
|   arg_descs[3]->setIndex(nidx + 3); | ||||
|  | ||||
|   // Get the argument encoder | ||||
|   auto arg_enc = d.argument_encoder(arg_descs); | ||||
|  | ||||
|   // Allocate and fill buffers for shapes and strides | ||||
|   auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); | ||||
|   auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); | ||||
|   for (int i = 0; i < nidx; ++i) { | ||||
|     std::copy( | ||||
|     idx_shapes.insert( | ||||
|         idx_shapes.end(), | ||||
|         inputs[i + 1].shape().begin(), | ||||
|         inputs[i + 1].shape().end(), | ||||
|         static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim); | ||||
|     std::copy( | ||||
|         inputs[i + 1].shape().end()); | ||||
|  | ||||
|     idx_strides.insert( | ||||
|         idx_strides.end(), | ||||
|         inputs[i + 1].strides().begin(), | ||||
|         inputs[i + 1].strides().end(), | ||||
|         static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim); | ||||
|         inputs[i + 1].strides().end()); | ||||
|   } | ||||
|  | ||||
|   // Allocate the argument buffer | ||||
|   auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); | ||||
|  | ||||
|   // Register data with the encoder | ||||
|   arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0); | ||||
|   for (int i = 0; i < nidx; ++i) { | ||||
|     set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); | ||||
|   } | ||||
|   if (idx_ndim > 0) { | ||||
|     arg_enc->setBuffer( | ||||
|         static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1); | ||||
|     compute_encoder->useResource( | ||||
|         static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), | ||||
|         MTL::ResourceUsageRead); | ||||
|     arg_enc->setBuffer( | ||||
|         static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2); | ||||
|     compute_encoder->useResource( | ||||
|         static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), | ||||
|         MTL::ResourceUsageRead); | ||||
|   } | ||||
|   *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim; | ||||
|  | ||||
|   // Set all the buffers | ||||
|   set_array_buffer(compute_encoder, src, 0); | ||||
|   compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1); | ||||
|   set_array_buffer(compute_encoder, out, 2); | ||||
|   set_array_buffer(compute_encoder, out, 1); | ||||
|  | ||||
|   compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); | ||||
|   compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); | ||||
|   compute_encoder->setBytes(&ndim, sizeof(size_t), 5); | ||||
|   compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); | ||||
|   compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7); | ||||
|   // Set source info | ||||
|   compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2); | ||||
|   compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3); | ||||
|   compute_encoder->setBytes(&ndim, sizeof(size_t), 4); | ||||
|   compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5); | ||||
|   compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6); | ||||
|  | ||||
|   // Set index info | ||||
|   // | ||||
|   // We don't need to check for empty idx_shapes because gather has a | ||||
|   // idx_ndim == 0 specialization | ||||
|   compute_encoder->setBytes( | ||||
|       idx_shapes.data(), idx_shapes.size() * sizeof(int), 7); | ||||
|   compute_encoder->setBytes( | ||||
|       idx_strides.data(), idx_strides.size() * sizeof(size_t), 8); | ||||
|   compute_encoder->setBytes(&idx_ndim, sizeof(int), 9); | ||||
|  | ||||
|   // Set index buffers | ||||
|   for (int i = 1; i < nidx + 1; ++i) { | ||||
|     set_array_buffer(compute_encoder, inputs[i], 20 + i); | ||||
|   } | ||||
|  | ||||
|   // Launch grid | ||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|   // Cleanup temporaries | ||||
|   arg_enc->release(); | ||||
|   d.get_command_buffer(s.index)->addCompletedHandler( | ||||
|       [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { | ||||
|         allocator::free(arg_buf); | ||||
|         allocator::free(idx_shapes_buf); | ||||
|         allocator::free(idx_strides_buf); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
| @@ -214,77 +174,33 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|  | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|  | ||||
|   // Make the argument buffer to store the indices for the | ||||
|   // `Indices` struct in kernels/indexing.metal | ||||
|   std::vector<MTL::ArgumentDescriptor*> arg_descs(4); | ||||
|   arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[0]->setIndex(0); | ||||
|   arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); | ||||
|   arg_descs[0]->setArrayLength(nidx); | ||||
|  | ||||
|   // Shapes | ||||
|   arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); | ||||
|   arg_descs[1]->setIndex(nidx + 1); | ||||
|  | ||||
|   // Strides | ||||
|   arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); | ||||
|   arg_descs[2]->setIndex(nidx + 2); | ||||
|  | ||||
|   // Indices ndim | ||||
|   arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); | ||||
|   arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); | ||||
|   arg_descs[3]->setIndex(nidx + 3); | ||||
|  | ||||
|   // Get the argument encoder | ||||
|   auto arg_enc = d.argument_encoder(arg_descs); | ||||
|  | ||||
|   // Allocate and fill buffers for shapes and strides | ||||
|   // Collect all idx shapes and strides into one place | ||||
|   int idx_ndim = nidx ? inputs[1].ndim() : 0; | ||||
|   auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); | ||||
|   auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); | ||||
|   std::vector<int> idx_shapes; | ||||
|   std::vector<size_t> idx_strides; | ||||
|  | ||||
|   for (int i = 0; i < nidx; ++i) { | ||||
|     std::copy( | ||||
|     idx_shapes.insert( | ||||
|         idx_shapes.end(), | ||||
|         inputs[i + 1].shape().begin(), | ||||
|         inputs[i + 1].shape().end(), | ||||
|         static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim); | ||||
|     std::copy( | ||||
|         inputs[i + 1].shape().end()); | ||||
|  | ||||
|     idx_strides.insert( | ||||
|         idx_strides.end(), | ||||
|         inputs[i + 1].strides().begin(), | ||||
|         inputs[i + 1].strides().end(), | ||||
|         static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim); | ||||
|         inputs[i + 1].strides().end()); | ||||
|   } | ||||
|  | ||||
|   // Allocate the argument buffer | ||||
|   auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); | ||||
|   // Set all the buffers | ||||
|   set_array_buffer(compute_encoder, upd, 1); | ||||
|   set_array_buffer(compute_encoder, out, 2); | ||||
|  | ||||
|   // Register data with the encoder | ||||
|   arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0); | ||||
|   for (int i = 0; i < nidx; ++i) { | ||||
|     set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); | ||||
|   } | ||||
|   if (idx_ndim > 0) { | ||||
|     arg_enc->setBuffer( | ||||
|         static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1); | ||||
|     compute_encoder->useResource( | ||||
|         static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), | ||||
|         MTL::ResourceUsageRead); | ||||
|     arg_enc->setBuffer( | ||||
|         static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2); | ||||
|     compute_encoder->useResource( | ||||
|         static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), | ||||
|         MTL::ResourceUsageRead); | ||||
|   } | ||||
|   *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim; | ||||
|  | ||||
|   compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0); | ||||
|   // Set update info | ||||
|   size_t upd_ndim = upd.ndim(); | ||||
|   size_t upd_size = 1; | ||||
|   for (int i = idx_ndim; i < upd.ndim(); ++i) { | ||||
|     upd_size *= upd.shape(i); | ||||
|   } | ||||
|   set_array_buffer(compute_encoder, upd, 1); | ||||
|   set_array_buffer(compute_encoder, out, 2); | ||||
|   if (upd_ndim == 0) { | ||||
|     // Need placeholders so Metal doesn't compalain | ||||
|     int shape_ = 0; | ||||
| @@ -299,6 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); | ||||
|   compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); | ||||
|  | ||||
|   // Set output info | ||||
|   size_t out_ndim = out.ndim(); | ||||
|   if (out_ndim == 0) { | ||||
|     // Need placeholders so Metal doesn't compalain | ||||
| @@ -314,18 +231,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); | ||||
|   compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); | ||||
|  | ||||
|   // Set index info | ||||
|   if (idx_ndim == 0) { | ||||
|     // Add a 0 in idx_shapes and strides to avoid the missing buffer binding | ||||
|     // error in the metal API. | ||||
|     idx_shapes.push_back(0); | ||||
|     idx_strides.push_back(0); | ||||
|   } | ||||
|   compute_encoder->setBytes( | ||||
|       idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); | ||||
|   compute_encoder->setBytes( | ||||
|       idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); | ||||
|   compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); | ||||
|  | ||||
|   // Set index buffers | ||||
|   for (int i = 1; i < nidx + 1; ++i) { | ||||
|     set_array_buffer(compute_encoder, inputs[i], 20 + i); | ||||
|   } | ||||
|  | ||||
|   // Launch grid | ||||
|   MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); | ||||
|   MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); | ||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|   // Cleanup temporaries | ||||
|   arg_enc->release(); | ||||
|   d.get_command_buffer(s.index)->addCompletedHandler( | ||||
|       [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { | ||||
|         allocator::free(arg_buf); | ||||
|         allocator::free(idx_shapes_buf); | ||||
|         allocator::free(idx_strides_buf); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
| @@ -6,6 +6,7 @@ set( | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/complex.h | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/erf.h | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h | ||||
| ) | ||||
| @@ -26,7 +27,8 @@ set( | ||||
|   "softmax" | ||||
|   "sort" | ||||
|   "unary" | ||||
|   "indexing" | ||||
|   "gather" | ||||
|   "scatter" | ||||
| ) | ||||
|  | ||||
| function(build_kernel_base TARGET SRCFILE DEPS) | ||||
|   | ||||
							
								
								
									
										187
									
								
								mlx/backend/metal/kernels/gather.metal
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										187
									
								
								mlx/backend/metal/kernels/gather.metal
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,187 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <metal_atomic> | ||||
|  | ||||
| #include "mlx/backend/metal/kernels/bf16.h" | ||||
| #include "mlx/backend/metal/kernels/indexing.h" | ||||
| #include "mlx/backend/metal/kernels/utils.h" | ||||
|  | ||||
| using namespace metal; | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Gather kernel | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename T, typename IdxT, int NIDX, int IDX_NDIM> | ||||
| METAL_FUNC void gather_impl( | ||||
|     const device T *src [[buffer(0)]], | ||||
|     device T *out [[buffer(1)]], | ||||
|     const constant int *src_shape [[buffer(2)]], | ||||
|     const constant size_t *src_strides [[buffer(3)]], | ||||
|     const constant size_t& src_ndim [[buffer(4)]], | ||||
|     const constant int *slice_sizes [[buffer(5)]], | ||||
|     const constant int *axes [[buffer(6)]], | ||||
|     const thread Indices<IdxT, NIDX>& indices, | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|  | ||||
|   auto ind_idx = index.x; | ||||
|   auto ind_offset = index.y; | ||||
|  | ||||
|   size_t src_idx = 0; | ||||
|   for (int i = 0; i < NIDX; ++i) { | ||||
|     size_t idx_loc; | ||||
|     if (IDX_NDIM == 0) { | ||||
|       idx_loc = 0; | ||||
|     } else if (IDX_NDIM == 1) { | ||||
|       idx_loc = ind_idx * indices.strides[indices.ndim * i]; | ||||
|     } else { | ||||
|       idx_loc = elem_to_loc( | ||||
|           ind_idx, | ||||
|           &indices.shapes[indices.ndim * i], | ||||
|           &indices.strides[indices.ndim * i], | ||||
|           indices.ndim); | ||||
|     } | ||||
|     auto ax = axes[i]; | ||||
|     auto idx_val = offset_neg_idx( | ||||
|         indices.buffers[i][idx_loc], src_shape[ax]); | ||||
|     src_idx += idx_val * src_strides[ax]; | ||||
|   } | ||||
|  | ||||
|   auto src_offset = elem_to_loc( | ||||
|       ind_offset, slice_sizes, src_strides, src_ndim); | ||||
|  | ||||
|   size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x; | ||||
|   out[out_idx] = src[src_offset + src_idx]; | ||||
|  | ||||
| } | ||||
|  | ||||
| #define make_gather_impl(IDX_ARG, IDX_ARR) \ | ||||
| template <typename T, typename IdxT, int NIDX, int IDX_NDIM>  \ | ||||
| [[kernel]] void gather( \ | ||||
|     const device T *src [[buffer(0)]], \ | ||||
|     device T *out [[buffer(1)]], \ | ||||
|     const constant int *src_shape [[buffer(2)]], \ | ||||
|     const constant size_t *src_strides [[buffer(3)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(4)]], \ | ||||
|     const constant int *slice_sizes [[buffer(5)]], \ | ||||
|     const constant int *axes [[buffer(6)]], \ | ||||
|     const constant int *idx_shapes [[buffer(7)]], \ | ||||
|     const constant size_t *idx_strides [[buffer(8)]], \ | ||||
|     const constant int& idx_ndim [[buffer(9)]], \ | ||||
|     IDX_ARG(IdxT) \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]) { \ | ||||
|  \ | ||||
|   Indices<IdxT, NIDX> idxs{ \ | ||||
|       {{IDX_ARR()}}, \ | ||||
|       idx_shapes, \ | ||||
|       idx_strides, \ | ||||
|       idx_ndim}; \ | ||||
|  \ | ||||
|   return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \ | ||||
|       src, \ | ||||
|       out, \ | ||||
|       src_shape, \ | ||||
|       src_strides, \ | ||||
|       src_ndim, \ | ||||
|       slice_sizes, \ | ||||
|       axes, \ | ||||
|       idxs, \ | ||||
|       index, \ | ||||
|       grid_dim); \ | ||||
| }  | ||||
|  | ||||
| #define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) | ||||
|  | ||||
| make_gather(0) | ||||
| make_gather(1) | ||||
| make_gather(2) | ||||
| make_gather(3) | ||||
| make_gather(4) | ||||
| make_gather(5) | ||||
| make_gather(6) | ||||
| make_gather(7) | ||||
| make_gather(8) | ||||
| make_gather(9) | ||||
| make_gather(10) | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Gather instantiations | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| #define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ | ||||
| template [[host_name("gather" name "_" #nidx "" #nd_name)]] \ | ||||
| [[kernel]] void gather<src_t, idx_t, nidx, nd>( \ | ||||
|     const device src_t *src [[buffer(0)]], \ | ||||
|     device src_t *out [[buffer(1)]], \ | ||||
|     const constant int *src_shape [[buffer(2)]], \ | ||||
|     const constant size_t *src_strides [[buffer(3)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(4)]], \ | ||||
|     const constant int *slice_sizes [[buffer(5)]], \ | ||||
|     const constant int *axes [[buffer(6)]], \ | ||||
|     const constant int *idx_shapes [[buffer(7)]], \ | ||||
|     const constant size_t *idx_strides [[buffer(8)]], \ | ||||
|     const constant int& idx_ndim [[buffer(9)]], \ | ||||
|     IDX_ARG(idx_t) \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); | ||||
|  | ||||
| #define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ | ||||
|   instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) | ||||
|  | ||||
| #define instantiate_gather4(name, src_t, idx_t, nidx) \ | ||||
|   instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ | ||||
|   instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \ | ||||
|   instantiate_gather5(name, src_t, idx_t, nidx, 2, ) | ||||
|  | ||||
|  | ||||
| // Special for case NIDX=0 | ||||
| instantiate_gather4("bool_", bool, bool, 0) | ||||
| instantiate_gather4("uint8", uint8_t, bool, 0) | ||||
| instantiate_gather4("uint16", uint16_t, bool, 0) | ||||
| instantiate_gather4("uint32", uint32_t, bool, 0) | ||||
| instantiate_gather4("uint64", uint64_t, bool, 0) | ||||
| instantiate_gather4("int8", int8_t, bool, 0) | ||||
| instantiate_gather4("int16", int16_t, bool, 0) | ||||
| instantiate_gather4("int32", int32_t, bool, 0) | ||||
| instantiate_gather4("int64", int64_t, bool, 0) | ||||
| instantiate_gather4("float16", half, bool, 0) | ||||
| instantiate_gather4("float32", float, bool, 0) | ||||
| instantiate_gather4("bfloat16", bfloat16_t, bool, 0) | ||||
|  | ||||
| #define instantiate_gather3(name, src_type, ind_type) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 1) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 2) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 3) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 4) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 5) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 6) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 7) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 8) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 9) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 10) | ||||
|  | ||||
| #define instantiate_gather(name, src_type) \ | ||||
|   instantiate_gather3(#name "bool_", src_type, bool) \ | ||||
|   instantiate_gather3(#name "uint8", src_type, uint8_t) \ | ||||
|   instantiate_gather3(#name "uint16", src_type, uint16_t) \ | ||||
|   instantiate_gather3(#name "uint32", src_type, uint32_t) \ | ||||
|   instantiate_gather3(#name "uint64", src_type, uint64_t) \ | ||||
|   instantiate_gather3(#name "int8", src_type, int8_t) \ | ||||
|   instantiate_gather3(#name "int16", src_type, int16_t) \ | ||||
|   instantiate_gather3(#name "int32", src_type, int32_t) \ | ||||
|   instantiate_gather3(#name "int64", src_type, int64_t) | ||||
|  | ||||
| instantiate_gather(bool_, bool) | ||||
| instantiate_gather(uint8, uint8_t) | ||||
| instantiate_gather(uint16, uint16_t) | ||||
| instantiate_gather(uint32, uint32_t) | ||||
| instantiate_gather(uint64, uint64_t) | ||||
| instantiate_gather(int8, int8_t) | ||||
| instantiate_gather(int16, int16_t) | ||||
| instantiate_gather(int32, int32_t) | ||||
| instantiate_gather(int64, int64_t) | ||||
| instantiate_gather(float16, half) | ||||
| instantiate_gather(float32, float) | ||||
| instantiate_gather(bfloat16, bfloat16_t) | ||||
							
								
								
									
										54
									
								
								mlx/backend/metal/kernels/indexing.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								mlx/backend/metal/kernels/indexing.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <metal_stdlib> | ||||
|  | ||||
| using namespace metal; | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Indexing utils | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename IdxT, int NIDX> | ||||
| struct Indices { | ||||
|   const array<const device IdxT*, NIDX> buffers; | ||||
|   const constant int* shapes; | ||||
|   const constant size_t* strides; | ||||
|   const int ndim; | ||||
| }; | ||||
|  | ||||
| template <typename IdxT> | ||||
| METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { | ||||
|   if (is_unsigned_v<IdxT>) { | ||||
|     return idx; | ||||
|   } else { | ||||
|     return (idx < 0) ? idx + size : idx; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]], | ||||
|  | ||||
| #define IDX_ARG_0(idx_t) | ||||
| #define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21) | ||||
| #define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22) | ||||
| #define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23) | ||||
| #define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24) | ||||
| #define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25) | ||||
| #define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26) | ||||
| #define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27) | ||||
| #define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28) | ||||
| #define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29) | ||||
| #define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30) | ||||
|  | ||||
| #define IDX_ARR_N(n) idx##n, | ||||
|  | ||||
| #define IDX_ARR_0() | ||||
| #define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21) | ||||
| #define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22) | ||||
| #define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23) | ||||
| #define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24) | ||||
| #define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25) | ||||
| #define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26) | ||||
| #define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27) | ||||
| #define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28) | ||||
| #define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29) | ||||
| #define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30) | ||||
| @@ -1,290 +0,0 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <metal_atomic> | ||||
| #include <metal_texture> | ||||
|  | ||||
| #include "mlx/backend/metal/kernels/bf16.h" | ||||
| #include "mlx/backend/metal/kernels/reduce.h" | ||||
| #include "mlx/backend/metal/kernels/utils.h" | ||||
|  | ||||
| using namespace metal; | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Gather kernel | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename IdxT, int NIDX> | ||||
| struct Indices { | ||||
|   const array<device IdxT*, NIDX> buffers [[id(0)]]; | ||||
|   device int* shapes [[id(NIDX + 1)]]; | ||||
|   device size_t* strides [[id(NIDX + 2)]]; | ||||
|   const int ndim [[id(NIDX + 3)]]; | ||||
| }; | ||||
|  | ||||
| template <typename IdxT> | ||||
| inline size_t offset_neg_idx(IdxT idx, size_t size) { | ||||
|   return (idx < 0) ? idx + size : idx; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline size_t offset_neg_idx(bool idx, size_t) { | ||||
|   return idx; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline size_t offset_neg_idx(uint32_t idx, size_t) { | ||||
|   return idx; | ||||
| } | ||||
|  | ||||
| // IDX_NDIM is the number of dimensions of the indices arrays. Compile-time | ||||
| // special case for 0 and 1. Anything >= 2 uses the general case | ||||
| template <typename T, typename IdxT, int NIDX, int IDX_NDIM> | ||||
| [[kernel]] void gather( | ||||
|     const device T *src [[buffer(0)]], | ||||
|     const constant Indices<IdxT, NIDX>& indices [[buffer(1)]], | ||||
|     device T *out [[buffer(2)]], | ||||
|     const constant int *src_shape [[buffer(3)]], | ||||
|     const constant size_t *src_strides [[buffer(4)]], | ||||
|     const constant size_t& src_ndim [[buffer(5)]], | ||||
|     const constant int *slice_sizes [[buffer(6)]], | ||||
|     const constant int *axes [[buffer(7)]], | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|  | ||||
|   auto ind_idx = index.x; | ||||
|   auto ind_offset = index.y; | ||||
|  | ||||
|   size_t src_idx = 0; | ||||
|   for (int i = 0; i < NIDX; ++i) { | ||||
|     size_t idx_loc; | ||||
|     if (IDX_NDIM == 0) { | ||||
|       idx_loc = 0; | ||||
|     } else if (IDX_NDIM == 1) { | ||||
|       idx_loc = ind_idx * indices.strides[indices.ndim * i]; | ||||
|     } else { | ||||
|       idx_loc = elem_to_loc( | ||||
|           ind_idx, | ||||
|           &indices.shapes[indices.ndim * i], | ||||
|           &indices.strides[indices.ndim * i], | ||||
|           indices.ndim); | ||||
|     } | ||||
|     auto ax = axes[i]; | ||||
|     auto idx_val = offset_neg_idx( | ||||
|         indices.buffers[i][idx_loc], src_shape[ax]); | ||||
|     src_idx += idx_val * src_strides[ax]; | ||||
|   } | ||||
|  | ||||
|   auto src_offset = elem_to_loc( | ||||
|       ind_offset, slice_sizes, src_strides, src_ndim); | ||||
|  | ||||
|   size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x; | ||||
|   out[out_idx] = src[src_offset + src_idx]; | ||||
| } | ||||
|  | ||||
| #define instantiate_gather4(name, src_type, ind_type, nindex) \ | ||||
| template [[host_name("gather" name "_" #nindex "_0")]] \ | ||||
| [[kernel]] void gather<src_type, ind_type, nindex, 0>( \ | ||||
|     const device src_type *src [[buffer(0)]], \ | ||||
|     const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     device src_type *out [[buffer(2)]], \ | ||||
|     const constant int *src_shape [[buffer(3)]], \ | ||||
|     const constant size_t *src_strides [[buffer(4)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(5)]], \ | ||||
|     const constant int *slice_sizes [[buffer(6)]], \ | ||||
|     const constant int* axes [[buffer(7)]], \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); \ | ||||
| template [[host_name("gather" name "_" #nindex "_1")]] \ | ||||
| [[kernel]] void gather<src_type, ind_type, nindex, 1>( \ | ||||
|     const device src_type *src [[buffer(0)]], \ | ||||
|     const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     device src_type *out [[buffer(2)]], \ | ||||
|     const constant int *src_shape [[buffer(3)]], \ | ||||
|     const constant size_t *src_strides [[buffer(4)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(5)]], \ | ||||
|     const constant int *slice_sizes [[buffer(6)]], \ | ||||
|     const constant int* axes [[buffer(7)]], \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); \ | ||||
| template [[host_name("gather" name "_" #nindex)]] \ | ||||
| [[kernel]] void gather<src_type, ind_type, nindex, 2>( \ | ||||
|     const device src_type *src [[buffer(0)]], \ | ||||
|     const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     device src_type *out [[buffer(2)]], \ | ||||
|     const constant int *src_shape [[buffer(3)]], \ | ||||
|     const constant size_t *src_strides [[buffer(4)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(5)]], \ | ||||
|     const constant int *slice_sizes [[buffer(6)]], \ | ||||
|     const constant int* axes [[buffer(7)]], \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); | ||||
|  | ||||
|  | ||||
| // Special for case NIDX=0 | ||||
| instantiate_gather4("bool_", bool, bool, 0) | ||||
| instantiate_gather4("uint8", uint8_t, bool, 0) | ||||
| instantiate_gather4("uint16", uint16_t, bool, 0) | ||||
| instantiate_gather4("uint32", uint32_t, bool, 0) | ||||
| instantiate_gather4("uint64", uint64_t, bool, 0) | ||||
| instantiate_gather4("int8", int8_t, bool, 0) | ||||
| instantiate_gather4("int16", int16_t, bool, 0) | ||||
| instantiate_gather4("int32", int32_t, bool, 0) | ||||
| instantiate_gather4("int64", int64_t, bool, 0) | ||||
| instantiate_gather4("float16", half, bool, 0) | ||||
| instantiate_gather4("float32", float, bool, 0) | ||||
| instantiate_gather4("bfloat16", bfloat16_t, bool, 0) | ||||
|  | ||||
| #define instantiate_gather3(name, src_type, ind_type) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 1) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 2) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 3) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 4) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 5) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 6) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 7) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 8) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 9) \ | ||||
|   instantiate_gather4(name, src_type, ind_type, 10) | ||||
|  | ||||
| #define instantiate_gather(name, src_type) \ | ||||
|   instantiate_gather3(#name "bool_", src_type, bool) \ | ||||
|   instantiate_gather3(#name "uint8", src_type, uint8_t) \ | ||||
|   instantiate_gather3(#name "uint16", src_type, uint16_t) \ | ||||
|   instantiate_gather3(#name "uint32", src_type, uint32_t) \ | ||||
|   instantiate_gather3(#name "uint64", src_type, uint64_t) \ | ||||
|   instantiate_gather3(#name "int8", src_type, int8_t) \ | ||||
|   instantiate_gather3(#name "int16", src_type, int16_t) \ | ||||
|   instantiate_gather3(#name "int32", src_type, int32_t) \ | ||||
|   instantiate_gather3(#name "int64", src_type, int64_t) | ||||
|  | ||||
| instantiate_gather(bool_, bool) | ||||
| instantiate_gather(uint8, uint8_t) | ||||
| instantiate_gather(uint16, uint16_t) | ||||
| instantiate_gather(uint32, uint32_t) | ||||
| instantiate_gather(uint64, uint64_t) | ||||
| instantiate_gather(int8, int8_t) | ||||
| instantiate_gather(int16, int16_t) | ||||
| instantiate_gather(int32, int32_t) | ||||
| instantiate_gather(int64, int64_t) | ||||
| instantiate_gather(float16, half) | ||||
| instantiate_gather(float32, float) | ||||
| instantiate_gather(bfloat16, bfloat16_t) | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Scatter kernel | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| template <typename T, typename IdxT, typename Op, int NIDX> | ||||
| [[kernel]] void scatter( | ||||
|     const device Indices<IdxT, NIDX>& indices [[buffer(0)]], | ||||
|     const device T *updates [[buffer(1)]], | ||||
|     device mlx_atomic<T> *out [[buffer(2)]], | ||||
|     const device int *upd_shape [[buffer(3)]], | ||||
|     const device size_t *upd_strides [[buffer(4)]], | ||||
|     const device size_t& upd_ndim [[buffer(5)]], | ||||
|     const device size_t& upd_size [[buffer(6)]], | ||||
|     const device int *out_shape [[buffer(7)]], | ||||
|     const device size_t *out_strides [[buffer(8)]], | ||||
|     const device size_t& out_ndim [[buffer(9)]], | ||||
|     const device int* axes [[buffer(10)]], | ||||
|     uint2 gid [[thread_position_in_grid]]) { | ||||
|  | ||||
|   Op op; | ||||
|   auto ind_idx = gid.y; | ||||
|   auto ind_offset = gid.x; | ||||
|  | ||||
|   size_t out_idx = 0; | ||||
|   for (int i = 0; i < NIDX; ++i) { | ||||
|     auto idx_loc = elem_to_loc( | ||||
|         ind_idx, | ||||
|         &indices.shapes[indices.ndim * i], | ||||
|         &indices.strides[indices.ndim * i], | ||||
|         indices.ndim); | ||||
|     auto ax = axes[i]; | ||||
|     auto idx_val = offset_neg_idx( | ||||
|         indices.buffers[i][idx_loc], out_shape[ax]); | ||||
|     out_idx += idx_val * out_strides[ax]; | ||||
|   } | ||||
|  | ||||
|   auto out_offset = elem_to_loc( | ||||
|       ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); | ||||
|   auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); | ||||
|   op.atomic_update(out, updates[upd_idx], out_idx + out_offset); | ||||
| } | ||||
|  | ||||
| #define instantiate_scatter4(name, type, ind_type, op_type, nindex) \ | ||||
| template [[host_name("scatter" name "_" #nindex)]] \ | ||||
| [[kernel]] void scatter<type, ind_type, op_type, nindex>( \ | ||||
|     const device Indices<ind_type, nindex>& indices [[buffer(0)]], \ | ||||
|     const device type *updates [[buffer(1)]], \ | ||||
|     device mlx_atomic<type> *out [[buffer(2)]], \ | ||||
|     const device int *upd_shape [[buffer(3)]], \ | ||||
|     const device size_t *upd_strides [[buffer(4)]], \ | ||||
|     const device size_t& upd_ndim [[buffer(5)]], \ | ||||
|     const device size_t& upd_size [[buffer(6)]], \ | ||||
|     const device int *out_shape [[buffer(7)]], \ | ||||
|     const device size_t *out_strides [[buffer(8)]], \ | ||||
|     const device size_t& out_ndim [[buffer(9)]], \ | ||||
|     const device int* axes [[buffer(10)]], \ | ||||
|     uint2 gid [[thread_position_in_grid]]); | ||||
|  | ||||
| // Special case NINDEX=0 | ||||
| #define instantiate_scatter_nd0(name, type) \ | ||||
|   instantiate_scatter4(#name "none", type, bool, None, 0) \ | ||||
|   instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \ | ||||
|   instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \ | ||||
|   instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \ | ||||
|   instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) | ||||
|  | ||||
| #define instantiate_scatter3(name, type, ind_type, op_type) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 1) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 2) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 3) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 4) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 5) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 6) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 7) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 8) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 9) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 10) | ||||
|  | ||||
| #define instantiate_scatter2(name, type, ind_type) \ | ||||
|   instantiate_scatter3(name "_none", type, ind_type, None) \ | ||||
|   instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \ | ||||
|   instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \ | ||||
|   instantiate_scatter3(name "_max", type, ind_type, Max<type>) \ | ||||
|   instantiate_scatter3(name "_min", type, ind_type, Min<type>) | ||||
|  | ||||
| #define instantiate_scatter(name, type) \ | ||||
|   instantiate_scatter2(#name "bool_", type, bool) \ | ||||
|   instantiate_scatter2(#name "uint8", type, uint8_t) \ | ||||
|   instantiate_scatter2(#name "uint16", type, uint16_t) \ | ||||
|   instantiate_scatter2(#name "uint32", type, uint32_t) \ | ||||
|   instantiate_scatter2(#name "uint64", type, uint64_t) \ | ||||
|   instantiate_scatter2(#name "int8", type, int8_t) \ | ||||
|   instantiate_scatter2(#name "int16", type, int16_t) \ | ||||
|   instantiate_scatter2(#name "int32", type, int32_t) \ | ||||
|   instantiate_scatter2(#name "int64", type, int64_t) | ||||
|  | ||||
| // TODO uint64 and int64 unsupported | ||||
| instantiate_scatter_nd0(bool_, bool) | ||||
| instantiate_scatter_nd0(uint8, uint8_t) | ||||
| instantiate_scatter_nd0(uint16, uint16_t) | ||||
| instantiate_scatter_nd0(uint32, uint32_t) | ||||
| instantiate_scatter_nd0(int8, int8_t) | ||||
| instantiate_scatter_nd0(int16, int16_t) | ||||
| instantiate_scatter_nd0(int32, int32_t) | ||||
| instantiate_scatter_nd0(float16, half) | ||||
| instantiate_scatter_nd0(float32, float) | ||||
| instantiate_scatter_nd0(bfloat16, bfloat16_t) | ||||
|  | ||||
| instantiate_scatter(bool_, bool) | ||||
| instantiate_scatter(uint8, uint8_t) | ||||
| instantiate_scatter(uint16, uint16_t) | ||||
| instantiate_scatter(uint32, uint32_t) | ||||
| instantiate_scatter(int8, int8_t) | ||||
| instantiate_scatter(int16, int16_t) | ||||
| instantiate_scatter(int32, int32_t) | ||||
| instantiate_scatter(float16, half) | ||||
| instantiate_scatter(float32, float) | ||||
| instantiate_scatter(bfloat16, bfloat16_t) | ||||
							
								
								
									
										194
									
								
								mlx/backend/metal/kernels/scatter.metal
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								mlx/backend/metal/kernels/scatter.metal
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,194 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <metal_atomic> | ||||
|  | ||||
| #include "mlx/backend/metal/kernels/bf16.h" | ||||
| #include "mlx/backend/metal/kernels/indexing.h" | ||||
| #include "mlx/backend/metal/kernels/reduce.h" | ||||
| #include "mlx/backend/metal/kernels/utils.h" | ||||
|  | ||||
| using namespace metal; | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Scatter kernel | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
|  | ||||
| template <typename T, typename IdxT, typename Op, int NIDX> | ||||
| METAL_FUNC void scatter_impl( | ||||
|     const device T *updates [[buffer(1)]], | ||||
|     device mlx_atomic<T> *out [[buffer(2)]], | ||||
|     const constant int *upd_shape [[buffer(3)]], | ||||
|     const constant size_t *upd_strides [[buffer(4)]], | ||||
|     const constant size_t& upd_ndim [[buffer(5)]], | ||||
|     const constant size_t& upd_size [[buffer(6)]], | ||||
|     const constant int *out_shape [[buffer(7)]], | ||||
|     const constant size_t *out_strides [[buffer(8)]], | ||||
|     const constant size_t& out_ndim [[buffer(9)]], | ||||
|     const constant int* axes [[buffer(10)]], | ||||
|     const thread Indices<IdxT, NIDX>& indices, | ||||
|     uint2 gid [[thread_position_in_grid]]) { | ||||
|  | ||||
|   Op op; | ||||
|   auto ind_idx = gid.y; | ||||
|   auto ind_offset = gid.x; | ||||
|  | ||||
|   size_t out_idx = 0; | ||||
|   for (int i = 0; i < NIDX; ++i) { | ||||
|     auto idx_loc = elem_to_loc( | ||||
|         ind_idx, | ||||
|         &indices.shapes[indices.ndim * i], | ||||
|         &indices.strides[indices.ndim * i], | ||||
|         indices.ndim); | ||||
|     auto ax = axes[i]; | ||||
|     auto idx_val = offset_neg_idx( | ||||
|         indices.buffers[i][idx_loc], out_shape[ax]); | ||||
|     out_idx += idx_val * out_strides[ax]; | ||||
|   } | ||||
|  | ||||
|   auto out_offset = elem_to_loc( | ||||
|       ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); | ||||
|   auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); | ||||
|   op.atomic_update(out, updates[upd_idx], out_idx + out_offset); | ||||
| } | ||||
|  | ||||
| #define make_scatter_impl(IDX_ARG, IDX_ARR) \ | ||||
| template <typename T, typename IdxT, typename Op, int NIDX>  \ | ||||
| [[kernel]] void scatter( \ | ||||
|     const device T *updates [[buffer(1)]], \ | ||||
|     device mlx_atomic<T> *out [[buffer(2)]], \ | ||||
|     const constant int *upd_shape [[buffer(3)]], \ | ||||
|     const constant size_t *upd_strides [[buffer(4)]], \ | ||||
|     const constant size_t& upd_ndim [[buffer(5)]], \ | ||||
|     const constant size_t& upd_size [[buffer(6)]], \ | ||||
|     const constant int *out_shape [[buffer(7)]], \ | ||||
|     const constant size_t *out_strides [[buffer(8)]], \ | ||||
|     const constant size_t& out_ndim [[buffer(9)]], \ | ||||
|     const constant int* axes [[buffer(10)]], \ | ||||
|     const constant int *idx_shapes [[buffer(11)]], \ | ||||
|     const constant size_t *idx_strides [[buffer(12)]], \ | ||||
|     const constant int& idx_ndim [[buffer(13)]], \ | ||||
|     IDX_ARG(IdxT) \ | ||||
|     uint2 gid [[thread_position_in_grid]]) { \ | ||||
|  \ | ||||
|   Indices<IdxT, NIDX> idxs{ \ | ||||
|       {{IDX_ARR()}}, \ | ||||
|       idx_shapes, \ | ||||
|       idx_strides, \ | ||||
|       idx_ndim}; \ | ||||
|  \ | ||||
|   return scatter_impl<T, IdxT, Op, NIDX>( \ | ||||
|       updates, \ | ||||
|       out, \ | ||||
|       upd_shape, \ | ||||
|       upd_strides, \ | ||||
|       upd_ndim, \ | ||||
|       upd_size, \ | ||||
|       out_shape, \ | ||||
|       out_strides, \ | ||||
|       out_ndim, \ | ||||
|       axes, \ | ||||
|       idxs, \ | ||||
|       gid); \ | ||||
| }  | ||||
|  | ||||
| #define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) | ||||
|  | ||||
| make_scatter(0) | ||||
| make_scatter(1) | ||||
| make_scatter(2) | ||||
| make_scatter(3) | ||||
| make_scatter(4) | ||||
| make_scatter(5) | ||||
| make_scatter(6) | ||||
| make_scatter(7) | ||||
| make_scatter(8) | ||||
| make_scatter(9) | ||||
| make_scatter(10) | ||||
|  | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
| // Scatter instantiations | ||||
| ///////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| #define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ | ||||
| template [[host_name("scatter" name "_" #nidx)]] \ | ||||
| [[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \ | ||||
|     const device src_t *updates [[buffer(1)]], \ | ||||
|     device mlx_atomic<src_t> *out [[buffer(2)]], \ | ||||
|     const constant int *upd_shape [[buffer(3)]], \ | ||||
|     const constant size_t *upd_strides [[buffer(4)]], \ | ||||
|     const constant size_t& upd_ndim [[buffer(5)]], \ | ||||
|     const constant size_t& upd_size [[buffer(6)]], \ | ||||
|     const constant int *out_shape [[buffer(7)]], \ | ||||
|     const constant size_t *out_strides [[buffer(8)]], \ | ||||
|     const constant size_t& out_ndim [[buffer(9)]], \ | ||||
|     const constant int* axes [[buffer(10)]], \ | ||||
|     const constant int *idx_shapes [[buffer(11)]], \ | ||||
|     const constant size_t *idx_strides [[buffer(12)]], \ | ||||
|     const constant int& idx_ndim [[buffer(13)]], \ | ||||
|     IDX_ARG(idx_t) \ | ||||
|     uint2 gid [[thread_position_in_grid]]); | ||||
|  | ||||
| #define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ | ||||
|   instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) | ||||
|  | ||||
| // Special case NINDEX=0 | ||||
| #define instantiate_scatter_nd0(name, type) \ | ||||
|   instantiate_scatter4(#name "none", type, bool, None, 0) \ | ||||
|   instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \ | ||||
|   instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \ | ||||
|   instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \ | ||||
|   instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) | ||||
|  | ||||
| #define instantiate_scatter3(name, type, ind_type, op_type) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 1) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 2) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 3) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 4) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 5) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 6) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 7) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 8) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 9) \ | ||||
|   instantiate_scatter4(name, type, ind_type, op_type, 10) | ||||
|  | ||||
| #define instantiate_scatter2(name, type, ind_type) \ | ||||
|   instantiate_scatter3(name "_none", type, ind_type, None) \ | ||||
|   instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \ | ||||
|   instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \ | ||||
|   instantiate_scatter3(name "_max", type, ind_type, Max<type>) \ | ||||
|   instantiate_scatter3(name "_min", type, ind_type, Min<type>) | ||||
|  | ||||
| #define instantiate_scatter(name, type) \ | ||||
|   instantiate_scatter2(#name "bool_", type, bool) \ | ||||
|   instantiate_scatter2(#name "uint8", type, uint8_t) \ | ||||
|   instantiate_scatter2(#name "uint16", type, uint16_t) \ | ||||
|   instantiate_scatter2(#name "uint32", type, uint32_t) \ | ||||
|   instantiate_scatter2(#name "uint64", type, uint64_t) \ | ||||
|   instantiate_scatter2(#name "int8", type, int8_t) \ | ||||
|   instantiate_scatter2(#name "int16", type, int16_t) \ | ||||
|   instantiate_scatter2(#name "int32", type, int32_t) \ | ||||
|   instantiate_scatter2(#name "int64", type, int64_t) | ||||
|  | ||||
| // TODO uint64 and int64 unsupported | ||||
| instantiate_scatter_nd0(bool_, bool) | ||||
| instantiate_scatter_nd0(uint8, uint8_t) | ||||
| instantiate_scatter_nd0(uint16, uint16_t) | ||||
| instantiate_scatter_nd0(uint32, uint32_t) | ||||
| instantiate_scatter_nd0(int8, int8_t) | ||||
| instantiate_scatter_nd0(int16, int16_t) | ||||
| instantiate_scatter_nd0(int32, int32_t) | ||||
| instantiate_scatter_nd0(float16, half) | ||||
| instantiate_scatter_nd0(float32, float) | ||||
| instantiate_scatter_nd0(bfloat16, bfloat16_t) | ||||
|  | ||||
| instantiate_scatter(bool_, bool) | ||||
| instantiate_scatter(uint8, uint8_t) | ||||
| instantiate_scatter(uint16, uint16_t) | ||||
| instantiate_scatter(uint32, uint32_t) | ||||
| instantiate_scatter(int8, int8_t) | ||||
| instantiate_scatter(int16, int16_t) | ||||
| instantiate_scatter(int32, int32_t) | ||||
| instantiate_scatter(float16, half) | ||||
| instantiate_scatter(float32, float) | ||||
| instantiate_scatter(bfloat16, bfloat16_t) | ||||
| @@ -9,20 +9,6 @@ namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| void set_array_buffer( | ||||
|     MTL::ComputeCommandEncoder* compute_encoder, | ||||
|     MTL::ArgumentEncoder* enc, | ||||
|     const array& a, | ||||
|     int idx) { | ||||
|   auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr()); | ||||
|   auto offset = a.data<char>() - | ||||
|       static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents()); | ||||
|   enc->setBuffer(a_buf, offset, idx); | ||||
|   // MTL::Resource usage through argument buffer needs to be explicitly | ||||
|   // flagged to enable hazard tracking | ||||
|   compute_encoder->useResource(a_buf, MTL::ResourceUsageRead); | ||||
| } | ||||
|  | ||||
| void set_array_buffer( | ||||
|     MTL::ComputeCommandEncoder* enc, | ||||
|     const array& a, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani