MLX
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Element-wise ternary kernel launched over a 1D grid.
//
// The thread at grid position `index` reads a[index], b[index] and c[index],
// hands them to an instance of Op, and stores the result in d[index]. All
// four buffers are addressed with the same linear offset.
template <typename T, typename Op>
[[kernel]] void ternary_v(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    uint index [[thread_position_in_grid]]) {
  Op op{};
  d[index] = op(a[index], b[index], c[index]);
}
12
// Element-wise ternary kernel launched over a 2D grid.
//
// The 2D thread position is flattened into a single linear offset
// (index.x + grid_dim.x * index.y) that is used for all four buffers. The
// row term is widened to size_t before the multiplication, so the offset is
// not computed in the 32-bit type of the grid coordinates.
template <typename T, typename Op>
[[kernel]] void ternary_v2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Offset of the first element of this thread's row.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  d[pos] = Op()(a[pos], b[pos], c[pos]);
}
24
// Ternary kernel for inputs described by one stride each, over a 1D grid.
//
// Each input location is obtained by passing the thread's grid position and
// that input's stride to elem_to_loc_1. The result of Op is written to
// d[index], i.e. the output is addressed directly by the grid position.
template <typename T, typename Op>
[[kernel]] void ternary_g_nd1(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]) {
  const auto a_loc = elem_to_loc_1(index, a_strides);
  const auto b_loc = elem_to_loc_1(index, b_strides);
  const auto c_loc = elem_to_loc_1(index, c_strides);

  Op op{};
  d[index] = op(a[a_loc], b[b_loc], c[c_loc]);
}
40
// Ternary kernel for inputs described by two strides each, over a 2D grid.
//
// Input locations come from elem_to_loc_2 applied to the 2D grid position and
// the per-input strides. The output is written at the flattened grid position
// index.x + grid_dim.x * index.y, evaluated in size_t.
template <typename T, typename Op>
[[kernel]] void ternary_g_nd2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flattened output position; grid_dim.x is widened before multiplying.
  const size_t dst = size_t(grid_dim.x) * index.y + index.x;

  const auto a_loc = elem_to_loc_2(index, a_strides);
  const auto b_loc = elem_to_loc_2(index, b_strides);
  const auto c_loc = elem_to_loc_2(index, c_strides);

  d[dst] = Op()(a[a_loc], b[b_loc], c[c_loc]);
}
58
// Ternary kernel for inputs described by three strides each, over a 3D grid.
//
// Input locations come from elem_to_loc_3 applied to the 3D grid position and
// the per-input strides. The output is written at the flattened grid position
// index.x + grid_dim.x * (index.y + grid_dim.y * index.z), with both grid
// extents widened to size_t before they are multiplied.
template <typename T, typename Op>
[[kernel]] void ternary_g_nd3(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const auto a_loc = elem_to_loc_3(index, a_strides);
  const auto b_loc = elem_to_loc_3(index, b_strides);
  const auto c_loc = elem_to_loc_3(index, c_strides);

  // Flatten (x, y, z) step by step: rows before this one, then elements.
  const size_t row = size_t(grid_dim.y) * index.z + index.y;
  const size_t dst = size_t(grid_dim.x) * row + index.x;

  d[dst] = Op()(a[a_loc], b[b_loc], c[c_loc]);
}
77
// Ternary kernel for inputs with a compile-time number of dimensions DIM,
// launched over a 3D grid.
//
// A single call to elem_to_loc_3_nd<DIM> takes the grid position, the shape
// and the three stride arrays and returns one location per input; its
// x / y / z components are used to index a / b / c respectively. The output
// is written at the flattened grid position, evaluated in size_t.
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_g_nd(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    constant const size_t c_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const auto locs =
      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);

  // Flatten (x, y, z) step by step: rows before this one, then elements.
  const size_t row = size_t(grid_dim.y) * index.z + index.y;
  const size_t dst = size_t(grid_dim.x) * row + index.x;

  d[dst] = Op()(a[locs.x], b[locs.y], c[locs.z]);
}
96
// Ternary kernel for inputs with a runtime number of dimensions, launched
// over a 3D grid.
//
// Parameters:
//   a, b, c    - input buffers; `a` holds bool values, `b` and `c` hold T.
//   d          - output buffer.
//   shape      - shape array forwarded to elem_to_loc_3_nd.
//   *_strides  - per-input stride arrays forwarded to elem_to_loc_3_nd.
//   ndim       - number of dimensions forwarded to elem_to_loc_3_nd.
//   index      - this thread's position in the grid.
//   grid_dim   - total threads per grid in each dimension.
//
// elem_to_loc_3_nd returns one location per input; its x / y / z components
// index a / b / c respectively. The result of Op is stored at the flattened
// grid position index.x + grid_dim.x * (index.y + grid_dim.y * index.z).
template <typename T, typename Op>
[[kernel]] void ternary_g(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx =
      elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
  // Widen the grid extents to size_t *before* multiplying. With plain uint
  // operands the whole expression is evaluated in 32 bits and wraps for large
  // grids before being stored in out_idx. This matches the flattening used by
  // ternary_g_nd3 and ternary_g_nd.
  size_t out_idx =
      index.x + size_t(grid_dim.x) * (index.y + size_t(grid_dim.y) * index.z);
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:161
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:173
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:258
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:167
void ternary_g_nd3(device const bool *a, device const T *b, device const T *c, device T *d, constant const size_t a_strides[3], constant const size_t b_strides[3], constant const size_t c_strides[3], uint3 index, uint3 grid_dim)
Definition ternary.h:60
void ternary_g_nd1(device const bool *a, device const T *b, device const T *c, device T *d, constant const size_t &a_strides, constant const size_t &b_strides, constant const size_t &c_strides, uint index)
Definition ternary.h:26
void ternary_v2(device const bool *a, device const T *b, device const T *c, device T *d, uint2 index, uint2 grid_dim)
Definition ternary.h:14
void ternary_v(device const bool *a, device const T *b, device const T *c, device T *d, uint index)
Definition ternary.h:4
void ternary_g_nd(device const bool *a, device const T *b, device const T *c, device T *d, constant const int shape[DIM], constant const size_t a_strides[DIM], constant const size_t b_strides[DIM], constant const size_t c_strides[DIM], uint3 index, uint3 grid_dim)
Definition ternary.h:79
void ternary_g(device const bool *a, device const T *b, device const T *c, device T *d, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition ternary.h:98
void ternary_g_nd2(device const bool *a, device const T *b, device const T *c, device T *d, constant const size_t a_strides[2], constant const size_t b_strides[2], constant const size_t c_strides[2], uint2 index, uint2 grid_dim)
Definition ternary.h:42