MLX
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Elementwise unary kernel for contiguous buffers.
//
// Each thread handles exactly one element: it loads in[index], applies a
// value-initialized `Op` to it, and stores the result at out[index].
//
// There is no bounds check. The dispatched grid therefore must not contain
// more threads than there are elements in `in` / `out`.
// `index` is a 32-bit `uint`, which caps the addressable element count at 2^32.
template <typename T, typename Op>
[[kernel]] void unary_v(
    device const T* in,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  Op op{};
  const T value = in[index];
  out[index] = op(value);
}
10
// Elementwise unary kernel for arbitrarily strided input.
//
// The output is addressed linearly: thread `index` writes out[index].
// The matching input element is read at the offset returned by
// elem_to_loc(index, in_shape, in_strides, ndim). elem_to_loc is declared in
// utils.h; presumably it converts the linear element index into a
// stride-weighted offset over `ndim` dimensions (TODO confirm against utils.h).
//
// Parameters:
//   in         - input buffer, indexed through the computed strided offset
//   out        - output buffer, indexed directly by `index`
//   in_shape   - per-dimension sizes of the input (`ndim` entries assumed)
//   in_strides - per-dimension input strides (`ndim` entries assumed)
//   ndim       - number of dimensions described by in_shape / in_strides
//   index      - this thread's position in the (1-D, 32-bit) grid
//
// There is no bounds check on `index`.
template <typename T, typename Op>
[[kernel]] void unary_g(
    device const T* in,
    device T* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]) {
  const auto src_offset = elem_to_loc(index, in_shape, in_strides, ndim);
  Op op{};
  out[index] = op(in[src_offset]);
}
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:77
void unary_g(device const T *in, device T *out, device const int *in_shape, device const size_t *in_strides, device const int &ndim, uint index)
Definition unary.h:12
void unary_v(device const T *in, device T *out, uint index)
Definition unary.h:4