MLX
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Element-wise binary kernel where both inputs are read only at element 0.
//
// Each thread evaluates Op on a[0] and b[0] and stores the result at
// c[index], where `index` is the thread's position in the grid. Every thread
// therefore writes the same value to its own output slot.
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  const T lhs = a[0];
  const T rhs = b[0];
  c[index] = Op()(lhs, rhs);
}
11
// Element-wise binary kernel where `a` is read only at element 0 and `b` is
// read per thread.
//
// Each thread evaluates Op on a[0] and b[index] and stores the result at
// c[index], where `index` is the thread's position in the grid.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  const T lhs = a[0];
  const T rhs = b[index];
  c[index] = Op()(lhs, rhs);
}
20
// Element-wise binary kernel where `a` is read per thread and `b` is read
// only at element 0.
//
// Each thread evaluates Op on a[index] and b[0] and stores the result at
// c[index], where `index` is the thread's position in the grid.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  const T lhs = a[index];
  const T rhs = b[0];
  c[index] = Op()(lhs, rhs);
}
29
// Element-wise binary kernel where both inputs are read per thread.
//
// Each thread evaluates Op on a[index] and b[index] and stores the result at
// c[index], where `index` is the thread's position in the grid.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  const T lhs = a[index];
  const T rhs = b[index];
  c[index] = Op()(lhs, rhs);
}
38
// Binary kernel for inputs addressed through a single stride each.
//
// The thread's grid position `index` is mapped to an offset into `a` and an
// offset into `b` by elem_to_loc_1 using a_stride and b_stride respectively
// (presumably index * stride -- confirm against utils.h). The result of Op on
// the two loaded values is stored at c[index].
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto lhs_offset = elem_to_loc_1(index, a_stride);
  auto rhs_offset = elem_to_loc_1(index, b_stride);
  c[index] = Op()(a[lhs_offset], b[rhs_offset]);
}
51
// Binary kernel for inputs addressed through two strides each, launched on a
// 2D grid.
//
// elem_to_loc_2 maps the 2D thread position to an offset into `a` (via
// a_strides) and into `b` (via b_strides). The output is written at
// index.x + grid_dim.x * index.y, where grid_dim is the number of threads in
// the grid.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto lhs_offset = elem_to_loc_2(index, a_strides);
  auto rhs_offset = elem_to_loc_2(index, b_strides);
  // Widen to size_t before multiplying so the product is not formed in
  // 32-bit arithmetic.
  const size_t row_start = static_cast<size_t>(grid_dim.x) * index.y;
  const size_t out_offset = index.x + row_start;
  c[out_offset] = Op()(a[lhs_offset], b[rhs_offset]);
}
66
// Binary kernel for inputs addressed through three strides each, launched on
// a 3D grid.
//
// elem_to_loc_3 maps the 3D thread position to an offset into `a` (via
// a_strides) and into `b` (via b_strides). The output is written at
// index.x + grid_dim.x * (index.y + grid_dim.y * index.z), where grid_dim is
// the number of threads in the grid.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto lhs_offset = elem_to_loc_3(index, a_strides);
  auto rhs_offset = elem_to_loc_3(index, b_strides);
  // Each product is widened to size_t first so it is not formed in 32-bit
  // arithmetic.
  const size_t row = index.y + static_cast<size_t>(grid_dim.y) * index.z;
  const size_t out_offset = index.x + static_cast<size_t>(grid_dim.x) * row;
  c[out_offset] = Op()(a[lhs_offset], b[rhs_offset]);
}
82
// Binary kernel for inputs described by a compile-time number of dimensions
// (DIM), launched on a 3D grid.
//
// A single call to elem_to_loc_2_nd<DIM> turns the thread position, `shape`
// and both stride arrays into a pair of offsets: .x is used to index `a` and
// .y to index `b`. The output is written at
// index.x + grid_dim.x * (index.y + grid_dim.y * index.z), where grid_dim is
// the number of threads in the grid.
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto offsets = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  // Each product is widened to size_t first so it is not formed in 32-bit
  // arithmetic.
  const size_t row = index.y + static_cast<size_t>(grid_dim.y) * index.z;
  const size_t out_offset = index.x + static_cast<size_t>(grid_dim.x) * row;
  c[out_offset] = Op()(a[offsets.x], b[offsets.y]);
}
98
// Binary kernel for inputs described by a run-time number of dimensions
// (`ndim`), launched on a 3D grid.
//
// elem_to_loc_2_nd turns the thread position, `shape`, both stride arrays and
// `ndim` into a pair of offsets: .x is used to index `a` and .y to index `b`.
// The output is written at
// index.x + grid_dim.x * (index.y + grid_dim.y * index.z), where grid_dim is
// the number of threads in the grid.
//
// NOTE(review): the elem_to_loc_2_nd declaration referenced by this listing
// returns uint2, so the input offsets appear to remain 32-bit -- confirm
// against utils.h.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  // Widen to size_t before multiplying, matching binary_g_nd3 and
  // binary_g_nd. Without the casts the whole expression is evaluated in
  // 32-bit uint and wraps for outputs with 2^32 or more elements.
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:123
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:135
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:200
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:129
void binary_ss(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:4
void binary_sv(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:13
void binary_g_nd(device const T *a, device const T *b, device U *c, constant const int shape[DIM], constant const size_t a_strides[DIM], constant const size_t b_strides[DIM], uint3 index, uint3 grid_dim)
Definition binary.h:84
void binary_vs(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:22
void binary_g_nd1(device const T *a, device const T *b, device U *c, constant const size_t &a_stride, constant const size_t &b_stride, uint index)
Definition binary.h:40
void binary_g(device const T *a, device const T *b, device U *c, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition binary.h:100
void binary_g_nd2(device const T *a, device const T *b, device U *c, constant const size_t a_strides[2], constant const size_t b_strides[2], uint2 index, uint2 grid_dim)
Definition binary.h:53
void binary_g_nd3(device const T *a, device const T *b, device U *c, constant const size_t a_strides[3], constant const size_t b_strides[3], uint3 index, uint3 grid_dim)
Definition binary.h:68
void binary_vv(device const T *a, device const T *b, device U *c, uint index)
Definition binary.h:31