2024-05-23 03:57:13 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
|
2024-10-16 07:23:15 +08:00
|
|
|
// Element-wise unary kernel dispatched over a 1D grid.
//
// Each thread handles exactly one element: it loads in[index], applies a
// value-initialized Op to it, and stores the result in out[index].
//
// Template parameters:
//   T  - element type read from `in`.
//   U  - element type written to `out`.
//   Op - functor type; must be callable as Op()(T).
//
// No bounds check is done here, so the launch must not produce an index
// outside the buffers.
template <typename T, typename U, typename Op>
[[kernel]] void unary_v(
    device const T* in,
    device U* out,
    uint index [[thread_position_in_grid]]) {
  const T value = in[index];
  out[index] = Op()(value);
}
|
|
|
|
|
2024-10-16 07:23:15 +08:00
|
|
|
// Element-wise unary kernel dispatched over a 2D grid.
//
// The 2D thread position is flattened to a single linear offset,
//   offset = index.x + grid_dim.x * index.y,
// computed in 64-bit arithmetic so the product cannot wrap at 32 bits.
// The thread then reads in[offset], applies Op, and writes out[offset].
//
// Template parameters:
//   T  - element type read from `in`.
//   U  - element type written to `out`.
//   Op - functor type; must be callable as Op()(T).
//
// No bounds check is done here, so the launch must not produce an offset
// outside the buffers.
template <typename T, typename U, typename Op>
[[kernel]] void unary_v2(
    device const T* in,
    device U* out,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row before multiplying so the whole expression is int64_t.
  const int64_t row = int64_t(index.y);
  const auto offset = index.x + grid_dim.x * row;
  const T value = in[offset];
  out[offset] = Op()(value);
}
|
|
|
|
|
2024-11-19 11:52:00 +08:00
|
|
|
// General (strided-input) unary kernel dispatched over a 3D grid.
//
// Each thread processes up to N consecutive elements along the last axis,
// starting at position N * index.x on that axis. The loop stops early once
// that position reaches in_shape[ndim - 1].
//
// Input addressing: the starting input location comes from elem_to_loc
// (defined outside this block) given the thread's 3D position, the shape and
// the strides; it then advances by the last-axis stride for each element.
// NOTE(review): presumably elem_to_loc maps the grid position to a strided
// element offset -- confirm against its definition.
//
// Output addressing: elements are written at consecutive indices starting at
//   N * index.x + xshape * (index.y + grid_dim.y * index.z),
// where xshape is in_shape[ndim - 1].
//
// Template parameters:
//   T    - element type read from `in`.
//   U    - element type written to `out`.
//   Op   - functor type; must be callable as Op()(T).
//   N    - maximum number of elements handled per thread (default 1).
//   IdxT - integer type used for index arithmetic (default int64_t).
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void unary_g(
    device const T* in,
    device U* out,
    constant const int* in_shape,
    constant const int64_t* in_strides,
    device const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // First last-axis position covered by this thread.
  const auto x_begin = N * index.x;

  // Starting location in the (possibly strided) input.
  auto src = elem_to_loc<IdxT>(
      {x_begin, index.y, index.z}, in_shape, in_strides, ndim);

  // Extent and stride of the last axis.
  const auto last_dim = in_shape[ndim - 1];
  const IdxT last_stride = in_strides[ndim - 1];

  // Starting index in the output for this thread.
  const auto outer = index.y + IdxT(grid_dim.y) * index.z;
  IdxT dst = x_begin + last_dim * outer;

  for (int k = 0; k < N; ++k) {
    // Stop at the end of the last axis.
    if (int(x_begin) + k >= last_dim) {
      break;
    }
    out[dst] = Op()(in[src]);
    ++dst;
    src += last_stride;
  }
}
|