MLX
 
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Scalar-scalar binary kernel: both inputs are read only at element 0, so
// every thread stores the same Op result into its own slot of `c`.
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  *(c + index) = Op()(*a, *b);
}
11
// Scalar-vector binary kernel: `a` is read only at element 0 (broadcast),
// while `b` and `c` are both indexed by this thread's linear grid position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  device const T& scalar = a[0];
  device const T& rhs = b[index];
  c[index] = Op()(scalar, rhs);
}
20
// Vector-scalar binary kernel: `b` is read only at element 0 (broadcast),
// while `a` and `c` are both indexed by this thread's linear grid position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  device const T& lhs = a[index];
  device const T& scalar = b[0];
  c[index] = Op()(lhs, scalar);
}
29
// Vector-vector binary kernel: both inputs and the output are indexed by the
// same linear grid position, one element per thread.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  device const T& lhs = a[index];
  device const T& rhs = b[index];
  c[index] = Op()(lhs, rhs);
}
38
// Scalar-vector binary kernel dispatched on a 2D grid. The 2D thread position
// is flattened row-major (x fastest) into a single offset into `b` and `c`;
// `a` is read only at element 0.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // index.y is widened to int64_t before the multiply, so the flattened
  // offset is computed in 64-bit arithmetic.
  const int64_t row = int64_t(index.y);
  const int64_t offset = row * grid_dim.x + index.x;
  c[offset] = Op()(a[0], b[offset]);
}
49
// Vector-scalar binary kernel dispatched on a 2D grid. The 2D thread position
// is flattened row-major (x fastest) into a single offset into `a` and `c`;
// `b` is read only at element 0.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // index.y is widened to int64_t before the multiply, so the flattened
  // offset is computed in 64-bit arithmetic.
  const int64_t row = int64_t(index.y);
  const int64_t offset = row * grid_dim.x + index.x;
  c[offset] = Op()(a[offset], b[0]);
}
60
// Vector-vector binary kernel dispatched on a 2D grid. The 2D thread position
// is flattened row-major (x fastest) into one offset shared by `a`, `b`, `c`.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // index.y is widened to int64_t before the multiply, so the flattened
  // offset is computed in 64-bit arithmetic.
  const int64_t row = int64_t(index.y);
  const int64_t offset = row * grid_dim.x + index.x;
  c[offset] = Op()(a[offset], b[offset]);
}
71
// General (strided) binary kernel for 1D inputs. Each input has its own
// stride; the element offset into each input comes from elem_to_loc_1
// (utils.h). The output `c` is written at the linear thread position.
//
// IdxT: integer type used for the input offsets (defaults to int64_t).
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int64_t& a_stride,
    constant const int64_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  const IdxT a_offset = elem_to_loc_1<IdxT>(index, a_stride);
  const IdxT b_offset = elem_to_loc_1<IdxT>(index, b_stride);
  c[index] = Op()(a[a_offset], b[b_offset]);
}
84
// General (strided) binary kernel for 2D inputs. Input offsets come from
// elem_to_loc_2 (utils.h) with each input's two strides. The output is
// written contiguously, flattening the 2D thread position row-major
// (x fastest, row length grid_dim.x).
//
// IdxT: integer type used for input and output offsets (defaults to int64_t).
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int64_t a_strides[2],
    constant const int64_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const IdxT a_offset = elem_to_loc_2<IdxT>(index, a_strides);
  const IdxT b_offset = elem_to_loc_2<IdxT>(index, b_strides);
  // grid_dim.x is cast to IdxT first so the row product uses IdxT width.
  const IdxT out_idx = IdxT(grid_dim.x) * index.y + index.x;
  c[out_idx] = Op()(a[a_offset], b[b_offset]);
}
99
// General (strided) binary kernel for 3D inputs. Input offsets come from
// elem_to_loc_3 (utils.h) with each input's three strides. The output is
// written contiguously, flattening the 3D thread position row-major
// (x fastest, then y, then z).
//
// IdxT: integer type used for input and output offsets (defaults to int64_t).
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int64_t a_strides[3],
    constant const int64_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const IdxT a_offset = elem_to_loc_3<IdxT>(index, a_strides);
  const IdxT b_offset = elem_to_loc_3<IdxT>(index, b_strides);
  // Combined (y, z) row number; grid_dim.y is cast to IdxT so the product
  // uses IdxT width.
  const IdxT row = index.y + IdxT(grid_dim.y) * index.z;
  const IdxT out_idx = index.x + grid_dim.x * row;
  c[out_idx] = Op()(a[a_offset], b[b_offset]);
}
114
// General (strided) binary kernel for inputs of arbitrary rank `ndim`.
//
// Each thread handles up to N consecutive elements along the last axis,
// starting at column N * index.x, and stops early at the axis length
// shape[ndim - 1]. Starting input offsets come from elem_to_loc_2_nd
// (utils.h); subsequent elements step by each input's last-axis stride.
// The output is written contiguously, with rows of length shape[ndim - 1]
// indexed by (index.y + grid_dim.y * index.z).
//
// N:    elements processed per thread along the last axis (default 1).
// IdxT: integer type used for offsets (defaults to int64_t).
// NOTE(review): reads shape[ndim - 1], so ndim is presumably >= 1 — confirm
// with the host-side dispatch.
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const uint first_col = N * index.x;

  // loc.x / loc.y are the current offsets into `a` / `b`.
  auto loc = elem_to_loc_2_nd<IdxT>(
      {first_col, index.y, index.z}, shape, a_strides, b_strides, ndim);

  const int last_dim = shape[ndim - 1];
  const IdxT a_step = a_strides[ndim - 1];
  const IdxT b_step = b_strides[ndim - 1];
  const IdxT out = first_col + last_dim * (index.y + IdxT(grid_dim.y) * index.z);

  const int start = int(first_col);
  for (int i = 0; i < N && start + i < last_dim; ++i) {
    c[out + i] = Op()(a[loc.x], b[loc.y]);
    loc.x += a_step;
    loc.y += b_step;
  }
}
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t &stride)
Definition utils.h:126
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, int ndim)
Definition utils.h:145
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2])
Definition utils.h:131
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3])
Definition utils.h:136
void binary_vv2(device const T *a, device const T *b, device U *c, uint2 index, uint2 grid_dim)
Definition binary.h:62
void binary_ss(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:4
void binary_sv(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:13
void binary_vs2(device const T *a, device const T *b, device U *c, uint2 index, uint2 grid_dim)
Definition binary.h:51
void binary_vs(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:22
void binary_g_nd2(device const T *a, device const T *b, device U *c, constant const int64_t a_strides[2], constant const int64_t b_strides[2], uint2 index, uint2 grid_dim)
Definition binary.h:86
void binary_sv2(device const T *a, device const T *b, device U *c, uint2 index, uint2 grid_dim)
Definition binary.h:40
void binary_g(device const T *a, device const T *b, device U *c, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition binary.h:121
void binary_g_nd3(device const T *a, device const T *b, device U *c, constant const int64_t a_strides[3], constant const int64_t b_strides[3], uint3 index, uint3 grid_dim)
Definition binary.h:101
void binary_g_nd1(device const T *a, device const T *b, device U *c, constant const int64_t &a_stride, constant const int64_t &b_stride, uint index)
Definition binary.h:73
void binary_vv(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:31
constexpr int N
Definition neon_fp16_simd.h:9