MLX
 
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Elementwise unary kernel over a 1D grid.
//
// Each thread reads the element of `in` at its grid position, applies a
// freshly constructed `Op` to it, and stores the result at the same position
// in `out`.
//
// Template parameters:
//   T  - element type read from `in`.
//   U  - element type written to `out`.
//   Op - functor type; must be default-constructible and callable on a T.
//
// No bounds check is performed here.
// NOTE(review): presumably the launch grid equals the element count - confirm
// against the host-side dispatch.
template <typename T, typename U, typename Op>
[[kernel]] void unary_v(
    device const T* in,
    device U* out,
    uint index [[thread_position_in_grid]]) {
  // Resolve this thread's source and destination slots up front.
  device const T* src = in + index;
  device U* dst = out + index;
  *dst = Op()(*src);
}
10
// Elementwise unary kernel over a 2D grid.
//
// The 2D thread position is flattened to a single offset,
// `index.x + grid_dim.x * index.y`, and that offset is used for both the read
// from `in` and the write to `out`. The flattening is done in 64-bit
// arithmetic, so the offset is not limited to the 32-bit range of the grid
// coordinates.
//
// Template parameters:
//   T  - element type read from `in`.
//   U  - element type written to `out`.
//   Op - functor type; must be default-constructible and callable on a T.
//
// No bounds check is performed here.
template <typename T, typename U, typename Op>
[[kernel]] void unary_v2(
    device const T* in,
    device U* out,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row before multiplying so the product is formed in 64 bits.
  const int64_t row = int64_t(index.y);
  const int64_t flat = row * grid_dim.x + index.x;
  out[flat] = Op()(in[flat]);
}
20
// Elementwise unary kernel for an input addressed through shape/strides.
//
// Each thread handles up to N consecutive elements along the last dimension
// (`in_shape[ndim - 1]`), starting at x-coordinate `N * index.x`:
//   * The input location comes from `elem_to_loc` (defined in utils.h) and is
//     advanced by the last-dimension stride after every element.
//   * The output location is computed densely from the grid position,
//     `N * index.x + xshape * (index.y + grid_dim.y * index.z)`, and is
//     advanced by one after every element.
// The loop stops early once the x-coordinate reaches the last-dimension
// extent, so a trailing partial group of fewer than N elements is handled.
//
// Template parameters:
//   T    - element type read from `in`.
//   U    - element type written to `out`.
//   Op   - functor type; must be default-constructible and callable on a T.
//   N    - maximum number of elements processed per thread (default 1).
//   IdxT - integer type used for offsets (default int64_t).
//
// NOTE(review): `out` is presumably a contiguous buffer laid out to match the
// dense index above - confirm against the host-side dispatch.
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void unary_g(
    device const T* in,
    device U* out,
    constant const int* in_shape,
    constant const int64_t* in_strides,
    device const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Extent and stride of the innermost (last) dimension.
  const auto last_extent = in_shape[ndim - 1];
  const IdxT last_stride = in_strides[ndim - 1];

  // First x-coordinate covered by this thread (same unsigned product the
  // location and output computations both use).
  const auto x_first = N * index.x;

  IdxT src_loc = elem_to_loc<IdxT>(
      {x_first, index.y, index.z}, in_shape, in_strides, ndim);
  IdxT dst_loc =
      x_first + last_extent * (index.y + IdxT(grid_dim.y) * index.z);

  // Walk at most N elements, never stepping past the end of the last
  // dimension.
  const int x_begin = int(x_first);
  int step = 0;
  while (step < N && (x_begin + step) < last_extent) {
    out[dst_loc] = Op()(in[src_loc]);
    ++dst_loc;
    src_loc += last_stride;
    ++step;
  }
}
METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
Definition utils.h:93
void unary_v(device const T *in, device U *out, uint index)
Definition unary.h:4
void unary_v2(device const T *in, device U *out, uint2 index, uint2 grid_dim)
Definition unary.h:12
void unary_g(device const T *in, device U *out, constant const int *in_shape, constant const int64_t *in_strides, device const int &ndim, uint3 index, uint3 grid_dim)
Definition unary.h:27
constexpr int N
Definition neon_fp16_simd.h:9