MLX
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Element-wise ternary kernel in which every buffer is addressed with the
// same linear thread index.
//
// a, b, c: the three operands handed to Op, in that order (a is boolean).
// d:       output buffer; element `index` receives the result.
// index:   position of this thread in the 1-D dispatch grid.
template <typename T, typename Op>
[[kernel]] void ternary_v(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    uint index [[thread_position_in_grid]]) {
  // Read the three operands for this element, then apply the functor.
  const bool a_val = a[index];
  const T b_val = b[index];
  const T c_val = c[index];
  d[index] = Op()(a_val, b_val, c_val);
}
12
// Strided ternary kernel for one-dimensional operands.
//
// Each input has its own single stride; the thread index is mapped through
// elem_to_loc_1 to obtain the element offset inside that input. The output
// is written at the thread index itself.
//
// a, b, c:     the three operands handed to Op, in that order (a is boolean).
// d:           output buffer.
// *_strides:   one stride per input.
// index:       position of this thread in the 1-D dispatch grid.
template <typename T, typename Op>
[[kernel]] void ternary_g_nd1(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]) {
  // Per-input element offsets for this thread.
  const auto a_loc = elem_to_loc_1(index, a_strides);
  const auto b_loc = elem_to_loc_1(index, b_strides);
  const auto c_loc = elem_to_loc_1(index, c_strides);

  d[index] = Op()(a[a_loc], b[b_loc], c[c_loc]);
}
28
// Strided ternary kernel for two-dimensional operands.
//
// The 2-D thread position is mapped through elem_to_loc_2 with each input's
// pair of strides to find where to read. The output offset is derived from
// the thread position and the grid width: x + grid_dim.x * y.
//
// a, b, c:     the three operands handed to Op, in that order (a is boolean).
// d:           output buffer.
// *_strides:   two strides per input.
// index:       position of this thread in the 2-D dispatch grid.
// grid_dim:    total number of threads along each grid axis.
template <typename T, typename Op>
[[kernel]] void ternary_g_nd2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Per-input element offsets for this thread.
  const auto a_loc = elem_to_loc_2(index, a_strides);
  const auto b_loc = elem_to_loc_2(index, b_strides);
  const auto c_loc = elem_to_loc_2(index, c_strides);

  // Widen to size_t before multiplying so the product is not computed in
  // 32-bit arithmetic.
  const size_t row_start = static_cast<size_t>(grid_dim.x) * index.y;
  const size_t out_loc = row_start + index.x;

  d[out_loc] = Op()(a[a_loc], b[b_loc], c[c_loc]);
}
46
// Strided ternary kernel for three-dimensional operands.
//
// The 3-D thread position is mapped through elem_to_loc_3 with each input's
// triple of strides to find where to read. The output offset is derived from
// the thread position and the grid extents:
//   x + grid_dim.x * (y + grid_dim.y * z).
//
// a, b, c:     the three operands handed to Op, in that order (a is boolean).
// d:           output buffer.
// *_strides:   three strides per input.
// index:       position of this thread in the 3-D dispatch grid.
// grid_dim:    total number of threads along each grid axis.
template <typename T, typename Op>
[[kernel]] void ternary_g_nd3(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Per-input element offsets for this thread.
  const auto a_loc = elem_to_loc_3(index, a_strides);
  const auto b_loc = elem_to_loc_3(index, b_strides);
  const auto c_loc = elem_to_loc_3(index, c_strides);

  // Build the output offset in two steps, widening to size_t before each
  // multiplication so no product is computed in 32-bit arithmetic.
  const size_t row = static_cast<size_t>(grid_dim.y) * index.z + index.y;
  const size_t out_loc = static_cast<size_t>(grid_dim.x) * row + index.x;

  d[out_loc] = Op()(a[a_loc], b[b_loc], c[c_loc]);
}
65
// Strided ternary kernel with the number of dimensions fixed at compile time.
//
// A single call to elem_to_loc_3_nd<DIM> yields the element offsets for all
// three inputs at once (.x for a, .y for b, .z for c). The output offset is
// derived from the 3-D thread position and the grid extents:
//   x + grid_dim.x * (y + grid_dim.y * z).
//
// a, b, c:     the three operands handed to Op, in that order (a is boolean).
// d:           output buffer.
// shape:       DIM extents, forwarded to elem_to_loc_3_nd.
// *_strides:   DIM strides per input.
// index:       position of this thread in the 3-D dispatch grid.
// grid_dim:    total number of threads along each grid axis.
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_g_nd(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    constant const size_t c_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Offsets into a, b and c packed as (x, y, z).
  const auto locs =
      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);

  // Widen to size_t before each multiplication so no product is computed in
  // 32-bit arithmetic.
  const size_t row = static_cast<size_t>(grid_dim.y) * index.z + index.y;
  const size_t out_loc = static_cast<size_t>(grid_dim.x) * row + index.x;

  d[out_loc] = Op()(a[locs.x], b[locs.y], c[locs.z]);
}
84
// Strided ternary kernel with the number of dimensions supplied at run time.
//
// A single call to elem_to_loc_3_nd yields the element offsets for all three
// inputs at once (.x for a, .y for b, .z for c). The output offset is derived
// from the 3-D thread position and the grid extents:
//   x + grid_dim.x * (y + grid_dim.y * z).
//
// a, b, c:     the three operands handed to Op, in that order (a is boolean).
// d:           output buffer.
// shape:       extents, forwarded to elem_to_loc_3_nd.
// *_strides:   strides per input, forwarded to elem_to_loc_3_nd.
// ndim:        number of dimensions, forwarded to elem_to_loc_3_nd.
// index:       position of this thread in the 3-D dispatch grid.
// grid_dim:    total number of threads along each grid axis.
template <typename T, typename Op>
[[kernel]] void ternary_g(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx =
      elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
  // grid_dim and index are 32-bit; widen to size_t before multiplying so the
  // offset cannot wrap around for outputs larger than 2^32 elements. This
  // matches the other ternary_g_nd* kernels.
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:123
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:135
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:220
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:129
void ternary_g_nd3(device const bool *a, device const T *b, device const T *c, device T *d, constant const size_t a_strides[3], constant const size_t b_strides[3], constant const size_t c_strides[3], uint3 index, uint3 grid_dim)
Definition ternary.h:48
void ternary_g_nd1(device const bool *a, device const T *b, device const T *c, device T *d, constant const size_t &a_strides, constant const size_t &b_strides, constant const size_t &c_strides, uint index)
Definition ternary.h:14
void ternary_v(device const bool *a, device const T *b, device const T *c, device T *d, uint index)
Definition ternary.h:4
void ternary_g_nd(device const bool *a, device const T *b, device const T *c, device T *d, constant const int shape[DIM], constant const size_t a_strides[DIM], constant const size_t b_strides[DIM], constant const size_t c_strides[DIM], uint3 index, uint3 grid_dim)
Definition ternary.h:67
void ternary_g(device const bool *a, device const T *b, device const T *c, device T *d, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition ternary.h:86
void ternary_g_nd2(device const bool *a, device const T *b, device const T *c, device T *d, constant const size_t a_strides[2], constant const size_t b_strides[2], constant const size_t c_strides[2], uint2 index, uint2 grid_dim)
Definition ternary.h:30