MLX
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Elementwise unary kernel over a 1D grid.
//
// Each thread reads the input element at its grid position, applies a
// value-constructed `Op` to it, and stores the result at the same position
// in `out`.
//
//   in    - input buffer of `T`
//   out   - output buffer of `U`
//   index - this thread's position in the 1D grid
template <typename T, typename U, typename Op>
[[kernel]] void unary_v(
    device const T* in,
    device U* out,
    uint index [[thread_position_in_grid]]) {
  // Load first, then transform and store.
  const T value = in[index];
  out[index] = Op()(value);
}
10
// Elementwise unary kernel over a 2D grid.
//
// The thread's 2D grid position is flattened into a single linear offset
// (x + grid width * y) that is used for both the read from `in` and the
// write to `out`.
//
//   in       - input buffer of `T`
//   out      - output buffer of `U`
//   index    - this thread's position in the 2D grid
//   grid_dim - total number of threads in the grid along each axis
template <typename T, typename U, typename Op>
[[kernel]] void unary_v2(
    device const T* in,
    device U* out,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen y to size_t before multiplying so the row offset is computed in
  // size_t rather than in 32-bit unsigned arithmetic.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  out[pos] = Op()(in[pos]);
}
20
// Elementwise unary kernel for input described by a shape/strides pair.
//
// Each thread handles up to `N` consecutive positions along the last axis
// (the one described by `in_shape[ndim - 1]` / `in_strides[ndim - 1]`),
// starting at x = N * index.x. The input location of the first element is
// obtained from `elem_to_loc`; subsequent elements are reached by stepping
// the input location by the last-axis stride. Output is written at
// consecutive indices starting from
//   x + in_shape[ndim - 1] * (index.y + grid_dim.y * index.z).
//
//   in         - input buffer of `T`
//   out        - output buffer of `U`
//   in_shape   - input shape, `ndim` entries
//   in_strides - input strides, `ndim` entries
//   ndim       - number of input dimensions
//   index      - this thread's position in the 3D grid
//   grid_dim   - total number of threads in the grid along each axis
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void unary_g(
    device const T* in,
    device U* out,
    constant const int* in_shape,
    constant const size_t* in_strides,
    device const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Extent and stride of the last axis.
  const int last_dim = in_shape[ndim - 1];
  const size_t last_stride = in_strides[ndim - 1];

  // Input location of the first element this thread is responsible for.
  auto src =
      elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);

  // Output index of that same element; the (y, z) part is computed in size_t.
  const size_t dst_row = index.y + size_t(grid_dim.y) * index.z;
  size_t dst = N * index.x + last_dim * dst_row;

  const int x_begin = int(N * index.x);
  for (int k = 0; k < N; ++k) {
    // Stop at the end of the last axis: the final thread along x may have
    // fewer than N elements left to process.
    if (x_begin + k >= last_dim) {
      break;
    }
    out[dst] = Op()(in[src]);
    ++dst;
    src += last_stride;
  }
}
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
void unary_v(device const T *in, device U *out, uint index)
Definition unary.h:4
void unary_v2(device const T *in, device U *out, uint2 index, uint2 grid_dim)
Definition unary.h:12
void unary_g(device const T *in, device U *out, constant const int *in_shape, constant const size_t *in_strides, device const int &ndim, uint3 index, uint3 grid_dim)
Definition unary.h:22