mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	| @@ -1,4 +1,4 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
| #include <algorithm> | ||||
| #include <cassert> | ||||
| #include <numeric> | ||||
| @@ -39,9 +39,15 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   auto& s = stream(); | ||||
|   auto& d = metal::device(s.device); | ||||
|  | ||||
|   int idx_ndim = nidx ? inputs[1].ndim() : 0; | ||||
|   size_t ndim = src.ndim(); | ||||
|  | ||||
|   std::ostringstream kname; | ||||
|   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; | ||||
|   kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx; | ||||
|   if (idx_ndim <= 1) { | ||||
|     kname << "_" << idx_ndim; | ||||
|   } | ||||
|  | ||||
|   auto compute_encoder = d.get_command_encoder(s.index); | ||||
|   auto kernel = d.get_kernel(kname.str()); | ||||
| @@ -51,15 +57,11 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     slice_size *= s; | ||||
|   } | ||||
|  | ||||
|   size_t ndim = src.ndim(); | ||||
|   size_t nthreads = out.size(); | ||||
|   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||
|   if (thread_group_size > nthreads) { | ||||
|     thread_group_size = nthreads; | ||||
|   } | ||||
|  | ||||
|   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); | ||||
|   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); | ||||
|   // Launch 2D grid of threads: indices x slice | ||||
|   size_t dim0 = out.size() / slice_size; | ||||
|   size_t dim1 = slice_size; | ||||
|   auto group_dims = get_block_dims(dim0, dim1, 1); | ||||
|   MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); | ||||
|  | ||||
|   compute_encoder->setComputePipelineState(kernel); | ||||
|  | ||||
| @@ -90,7 +92,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   auto arg_enc = d.argument_encoder(arg_descs); | ||||
|  | ||||
|   // Allocate and fill buffers for shapes and strides | ||||
|   int idx_ndim = nidx ? inputs[1].ndim() : 0; | ||||
|   auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); | ||||
|   auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); | ||||
|   for (int i = 0; i < nidx; ++i) { | ||||
| @@ -130,12 +131,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   set_array_buffer(compute_encoder, src, 0); | ||||
|   compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1); | ||||
|   set_array_buffer(compute_encoder, out, 2); | ||||
|  | ||||
|   compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); | ||||
|   compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); | ||||
|   compute_encoder->setBytes(&ndim, sizeof(size_t), 5); | ||||
|   compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); | ||||
|   compute_encoder->setBytes(&slice_size, sizeof(size_t), 7); | ||||
|   compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8); | ||||
|   compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7); | ||||
|  | ||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <metal_atomic> | ||||
| #include <metal_texture> | ||||
| @@ -36,29 +36,38 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) { | ||||
|   return idx; | ||||
| } | ||||
|  | ||||
| template <typename T, typename IdxT, int NIDX> | ||||
| // IDX_NDIM is the number of dimensions of the indices arrays. Compile-time | ||||
| // special case for 0 and 1. Anything >= 2 uses the general case | ||||
| template <typename T, typename IdxT, int NIDX, int IDX_NDIM> | ||||
| [[kernel]] void gather( | ||||
|     const device T *src [[buffer(0)]], | ||||
|     const device Indices<IdxT, NIDX>& indices [[buffer(1)]], | ||||
|     const constant Indices<IdxT, NIDX>& indices [[buffer(1)]], | ||||
|     device T *out [[buffer(2)]], | ||||
|     const device int *src_shape [[buffer(3)]], | ||||
|     const device size_t *src_strides [[buffer(4)]], | ||||
|     const device size_t& src_ndim [[buffer(5)]], | ||||
|     const device int *slice_sizes [[buffer(6)]], | ||||
|     const device size_t& slice_size [[buffer(7)]], | ||||
|     const device int *axes [[buffer(8)]], | ||||
|     uint gid [[thread_position_in_grid]]) { | ||||
|     const constant int *src_shape [[buffer(3)]], | ||||
|     const constant size_t *src_strides [[buffer(4)]], | ||||
|     const constant size_t& src_ndim [[buffer(5)]], | ||||
|     const constant int *slice_sizes [[buffer(6)]], | ||||
|     const constant int *axes [[buffer(7)]], | ||||
|     uint2 index [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]]) { | ||||
|  | ||||
|   auto ind_idx = gid / slice_size; | ||||
|   auto ind_offset = gid % slice_size; | ||||
|   auto ind_idx = index.x; | ||||
|   auto ind_offset = index.y; | ||||
|  | ||||
|   size_t src_idx = 0; | ||||
|   for (int i = 0; i < NIDX; ++i) { | ||||
|     auto idx_loc = elem_to_loc( | ||||
|         ind_idx, | ||||
|         &indices.shapes[indices.ndim * i], | ||||
|         &indices.strides[indices.ndim * i], | ||||
|         indices.ndim); | ||||
|     size_t idx_loc; | ||||
|     if (IDX_NDIM == 0) { | ||||
|       idx_loc = 0; | ||||
|     } else if (IDX_NDIM == 1) { | ||||
|       idx_loc = ind_idx * indices.strides[indices.ndim * i]; | ||||
|     } else { | ||||
|       idx_loc = elem_to_loc( | ||||
|           ind_idx, | ||||
|           &indices.shapes[indices.ndim * i], | ||||
|           &indices.strides[indices.ndim * i], | ||||
|           indices.ndim); | ||||
|     } | ||||
|     auto ax = axes[i]; | ||||
|     auto idx_val = offset_neg_idx( | ||||
|         indices.buffers[i][idx_loc], src_shape[ax]); | ||||
| @@ -67,22 +76,49 @@ template <typename T, typename IdxT, int NIDX> | ||||
|  | ||||
|   auto src_offset = elem_to_loc( | ||||
|       ind_offset, slice_sizes, src_strides, src_ndim); | ||||
|   out[gid] = src[src_idx + src_offset]; | ||||
|  | ||||
|   size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x; | ||||
|   out[out_idx] = src[src_offset + src_idx]; | ||||
| } | ||||
|  | ||||
| #define instantiate_gather4(name, src_type, ind_type, nindex) \ | ||||
| template [[host_name("gather" name "_" #nindex)]] \ | ||||
| [[kernel]] void gather( \ | ||||
| template [[host_name("gather" name "_" #nindex "_0")]] \ | ||||
| [[kernel]] void gather<src_type, ind_type, nindex, 0>( \ | ||||
|     const device src_type *src [[buffer(0)]], \ | ||||
|     const device Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     device src_type *out [[buffer(2)]], \ | ||||
|     const device int *src_shape [[buffer(3)]], \ | ||||
|     const device size_t *src_strides [[buffer(4)]], \ | ||||
|     const device size_t& src_ndim [[buffer(5)]], \ | ||||
|     const device int *slice_sizes [[buffer(6)]], \ | ||||
|     const device size_t& slice_size [[buffer(7)]], \ | ||||
|     const device int* axes [[buffer(8)]], \ | ||||
|     uint gid [[thread_position_in_grid]]); | ||||
|     const constant int *src_shape [[buffer(3)]], \ | ||||
|     const constant size_t *src_strides [[buffer(4)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(5)]], \ | ||||
|     const constant int *slice_sizes [[buffer(6)]], \ | ||||
|     const constant int* axes [[buffer(7)]], \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); \ | ||||
| template [[host_name("gather" name "_" #nindex "_1")]] \ | ||||
| [[kernel]] void gather<src_type, ind_type, nindex, 1>( \ | ||||
|     const device src_type *src [[buffer(0)]], \ | ||||
|     const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     device src_type *out [[buffer(2)]], \ | ||||
|     const constant int *src_shape [[buffer(3)]], \ | ||||
|     const constant size_t *src_strides [[buffer(4)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(5)]], \ | ||||
|     const constant int *slice_sizes [[buffer(6)]], \ | ||||
|     const constant int* axes [[buffer(7)]], \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); \ | ||||
| template [[host_name("gather" name "_" #nindex)]] \ | ||||
| [[kernel]] void gather<src_type, ind_type, nindex, 2>( \ | ||||
|     const device src_type *src [[buffer(0)]], \ | ||||
|     const constant Indices<ind_type, nindex>& indices [[buffer(1)]], \ | ||||
|     device src_type *out [[buffer(2)]], \ | ||||
|     const constant int *src_shape [[buffer(3)]], \ | ||||
|     const constant size_t *src_strides [[buffer(4)]], \ | ||||
|     const constant size_t& src_ndim [[buffer(5)]], \ | ||||
|     const constant int *slice_sizes [[buffer(6)]], \ | ||||
|     const constant int* axes [[buffer(7)]], \ | ||||
|     uint2 index [[thread_position_in_grid]], \ | ||||
|     uint2 grid_dim [[threads_per_grid]]); | ||||
|  | ||||
|  | ||||
| // Special for case NIDX=0 | ||||
| instantiate_gather4("bool_", bool, bool, 0) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun