MLX
Loading...
Searching...
No Matches
gather.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
7template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
8METAL_FUNC void gather_impl(
9 const device T* src [[buffer(0)]],
10 device T* out [[buffer(1)]],
11 const constant int* src_shape [[buffer(2)]],
12 const constant size_t* src_strides [[buffer(3)]],
13 const constant size_t& src_ndim [[buffer(4)]],
14 const constant int* slice_sizes [[buffer(5)]],
15 const constant int* axes [[buffer(6)]],
16 const thread Indices<IdxT, NIDX>& indices,
17 uint3 index [[thread_position_in_grid]],
18 uint3 grid_dim [[threads_per_grid]]) {
19 size_t src_idx = 0;
20 for (int i = 0; i < NIDX; ++i) {
21 size_t idx_loc;
22 if (IDX_NDIM == 0) {
23 idx_loc = 0;
24 } else if (IDX_NDIM == 1) {
25 idx_loc = index.x * indices.strides[indices.ndim * i];
26 } else {
27 idx_loc = index.x * indices.strides[indices.ndim * i];
28 idx_loc += elem_to_loc(
29 index.y,
30 &indices.shapes[indices.ndim * i + 1],
31 &indices.strides[indices.ndim * i + 1],
32 indices.ndim - 1);
33 }
34 auto ax = axes[i];
35 auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
36 src_idx += idx_val * src_strides[ax];
37 }
38
39 auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
40
41 size_t out_idx = index.z;
42 if (IDX_NDIM == 1) {
43 out_idx += static_cast<size_t>(grid_dim.z) * index.x;
44 } else if (IDX_NDIM >= 2) {
45 out_idx +=
46 grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
47 }
48 out[out_idx] = src[src_offset + src_idx];
49}
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
METAL_FUNC void gather_impl(const device T *src, device T *out, const constant int *src_shape, const constant size_t *src_strides, const constant size_t &src_ndim, const constant int *slice_sizes, const constant int *axes, const thread Indices< IdxT, NIDX > &indices, uint3 index, uint3 grid_dim)
Definition gather.h:8
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size)
Definition indexing.h:16
Definition indexing.h:8