mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-24 20:28:16 +08:00
JIT compile option for binary minimization (#1091)
* try cpp 20 for compile * unary, binary, ternary in jit * nits * fix gather/scatter * fix rebase * reorg compile * add ternary to compile * jit copy * jit compile flag * fix build * use linked function for ternary * some nits * docs + circle min size build * docs + circle min size build * fix extension * fix no cpu build * improve includes
This commit is contained in:
87
mlx/backend/metal/jit/binary.h
Normal file
87
mlx/backend/metal/jit/binary.h
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
||||
98
mlx/backend/metal/jit/binary_two.h
Normal file
98
mlx/backend/metal/jit/binary_two.h
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_two_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
||||
100
mlx/backend/metal/jit/copy.h
Normal file
100
mlx/backend/metal/jit/copy.h
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view copy_kernels = R"(
|
||||
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4_{0}")]] [[kernel]] void
|
||||
copy_g_nd<{1}, {2}, 4>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg4_{0}")]] [[kernel]] void
|
||||
copy_gg_nd<{1}, {2}, 4>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
template [[host_name("g5_{0}")]] [[kernel]] void
|
||||
copy_g_nd<{1}, {2}, 5>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg5_{0}")]] [[kernel]] void
|
||||
copy_gg_nd<{1}, {2}, 5>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg1_{0}")]] [[kernel]] void
|
||||
copy_gg_nd1<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("gg2_{0}")]] [[kernel]] void
|
||||
copy_gg_nd2<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]);
|
||||
template [[host_name("gg3_{0}")]] [[kernel]] void
|
||||
copy_gg_nd3<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
)";
|
||||
21
mlx/backend/metal/jit/includes.h
Normal file
21
mlx/backend/metal/jit/includes.h
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
const char* utils();
|
||||
const char* binary_ops();
|
||||
const char* unary_ops();
|
||||
const char* ternary_ops();
|
||||
const char* reduction();
|
||||
const char* gather();
|
||||
const char* scatter();
|
||||
|
||||
const char* unary();
|
||||
const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* ternary();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
81
mlx/backend/metal/jit/indexing.h
Normal file
81
mlx/backend/metal/jit/indexing.h
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
{4}
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
src_strides,
|
||||
src_ndim,
|
||||
slice_sizes,
|
||||
axes,
|
||||
idxs,
|
||||
index,
|
||||
grid_dim);
|
||||
}}
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
upd_strides,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
)";
|
||||
80
mlx/backend/metal/jit/ternary.h
Normal file
80
mlx/backend/metal/jit/ternary.h
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view ternary_kernels = R"(
|
||||
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const size_t* c_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1_{0}")]] [[kernel]] void
|
||||
ternary_g_nd1<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t& a_strides,
|
||||
constant const size_t& b_strides,
|
||||
constant const size_t& c_strides,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2_{0}")]] [[kernel]] void
|
||||
ternary_g_nd2<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
constant const size_t c_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3_{0}")]] [[kernel]] void
|
||||
ternary_g_nd3<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
constant const size_t c_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g4_{0}")]] [[kernel]] void
|
||||
ternary_g_nd<{1}, {2}, 4>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
constant const size_t c_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5_{0}")]] [[kernel]] void
|
||||
ternary_g_nd<{1}, {2}, 5>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
constant const size_t c_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
||||
16
mlx/backend/metal/jit/unary.h
Normal file
16
mlx/backend/metal/jit/unary.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view unary_kernels = R"(
|
||||
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
|
||||
device const {1}* in,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
|
||||
device const {1}* in,
|
||||
device {1}* out,
|
||||
device const int* in_shape,
|
||||
device const size_t* in_strides,
|
||||
device const int& ndim,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
||||
Reference in New Issue
Block a user