MLX
Loading...
Searching...
No Matches
indexing.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_stdlib>
4
5using namespace metal;
6
// Indexing utilities

// Groups NumIdx index buffers with the metadata pointers that describe them.
//   buffers : one pointer per index array, in the `device` address space;
//             fixed at construction (the member is const).
//   shapes  : pointer to const ints in the `constant` address space.
//   strides : pointer to const size_t values in the `constant` address space.
//   ndim    : dimension count, fixed at construction.
// NOTE(review): how shapes/strides are laid out (presumably ndim entries per
// index array, back to back) is decided by the kernels that fill this struct;
// confirm at the call sites.
// Member order matters: callers build this struct with aggregate
// initialization, so the fields must stay in this order.
template <typename IndexT, int NumIdx>
struct Indices {
  const metal::array<const device IndexT*, NumIdx> buffers;
  const constant int* shapes;
  const constant size_t* strides;
  const int ndim;
};
18
// Converts a possibly-negative index into an unsigned offset.
// A negative idx is shifted up by `size`, so -1 becomes size - 1.
// Any other value, including every value of an unsigned IdxT, is returned as-is.
// No bounds checking is done: an idx below -size does not give an
// in-range offset.
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  // Only signed index types can hold a negative value.
  const bool may_be_negative = !is_unsigned_v<IdxT>;
  if (may_be_negative && idx < 0) {
    return idx + size;
  }
  return idx;
}
27
// Expands to one kernel-parameter declaration, `idx<slot>`: a pointer to
// const IDX_TYPE in the `device` address space, bound to buffer index `slot`.
// The expansion ends with a comma, so another parameter must follow it.
#define IDX_ARG_N(IDX_TYPE, slot) const device IDX_TYPE* idx##slot [[buffer(slot)]],

// IDX_ARG_k(IDX_TYPE) declares k such parameters in ascending order,
// idx21 .. idx(20 + k), bound to buffers 21 .. 20 + k.
// IDX_ARG_0 expands to nothing.
#define IDX_ARG_0(IDX_TYPE)
#define IDX_ARG_1(IDX_TYPE) IDX_ARG_0(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 21)
#define IDX_ARG_2(IDX_TYPE) IDX_ARG_1(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 22)
#define IDX_ARG_3(IDX_TYPE) IDX_ARG_2(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 23)
#define IDX_ARG_4(IDX_TYPE) IDX_ARG_3(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 24)
#define IDX_ARG_5(IDX_TYPE) IDX_ARG_4(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 25)
#define IDX_ARG_6(IDX_TYPE) IDX_ARG_5(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 26)
#define IDX_ARG_7(IDX_TYPE) IDX_ARG_6(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 27)
#define IDX_ARG_8(IDX_TYPE) IDX_ARG_7(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 28)
#define IDX_ARG_9(IDX_TYPE) IDX_ARG_8(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 29)
#define IDX_ARG_10(IDX_TYPE) IDX_ARG_9(IDX_TYPE) IDX_ARG_N(IDX_TYPE, 30)
41
// Expands to the identifier `idx<slot>` followed by a comma. This is the name
// that IDX_ARG_N(..., slot) gives its parameter (presumably used to fill an
// array such as Indices::buffers; confirm at the call sites).
#define IDX_ARR_N(slot) idx##slot,

// IDX_ARR_k() lists idx21 .. idx(20 + k) in ascending order, each followed by
// a comma. IDX_ARR_0() expands to nothing.
#define IDX_ARR_0()
#define IDX_ARR_1() idx21,
#define IDX_ARR_2() idx21, idx22,
#define IDX_ARR_3() idx21, idx22, idx23,
#define IDX_ARR_4() idx21, idx22, idx23, idx24,
#define IDX_ARR_5() idx21, idx22, idx23, idx24, idx25,
#define IDX_ARR_6() idx21, idx22, idx23, idx24, idx25, idx26,
#define IDX_ARR_7() idx21, idx22, idx23, idx24, idx25, idx26, idx27,
#define IDX_ARR_8() idx21, idx22, idx23, idx24, idx25, idx26, idx27, idx28,
#define IDX_ARR_9() \
  idx21, idx22, idx23, idx24, idx25, idx26, idx27, idx28, idx29,
#define IDX_ARR_10() \
  idx21, idx22, idx23, idx24, idx25, idx26, idx27, idx28, idx29, idx30,
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size)
Definition indexing.h:20
Definition bf16.h:265
Definition indexing.h:12
const constant int * shapes
Definition indexing.h:14
const int ndim
Definition indexing.h:16
const constant size_t * strides
Definition indexing.h:15
const array< const device IdxT *, NIDX > buffers
Definition indexing.h:13