JIT compile option for binary minimization (#1091)

* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
This commit is contained in:
Awni Hannun
2024-05-22 12:57:13 -07:00
committed by GitHub
parent d568c7ee36
commit 226748b3e7
56 changed files with 3153 additions and 2605 deletions

View File

@@ -0,0 +1,87 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,21 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduction();
const char* gather();
const char* scatter();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,81 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
src_strides,
src_ndim,
slice_sizes,
axes,
idxs,
index,
grid_dim);
}}
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}
[[kernel]] void scatter{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
updates,
out,
upd_shape,
upd_strides,
upd_ndim,
upd_size,
out_shape,
out_strides,
out_ndim,
axes,
idxs,
gid);
}}
)";

View File

@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
constant const size_t c_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
constant const size_t c_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
device const {1}* in,
device {1}* out,
uint index [[thread_position_in_grid]]);
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
device const {1}* in,
device {1}* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]);
)";