Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 01:41:17 +08:00

Commit 825124af8f (parent 9c5e7da507)

fix bw for elementwise ops (#2151)

* fix bw for elementwise ops
* add compile
* fix
* fix
* fix
* fix
@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > UINT32_MAX;
-    work_per_thread = 1;
+    work_per_thread = get_work_per_thread(a.dtype());
   }
   std::string kernel_name =
       get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
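
The fixed work_per_thread = 1 in the contiguous branch becomes a per-dtype value. get_work_per_thread is added later in this same diff (in mlx/backend/metal/utils.h) as std::max(1, 8 / dtype.size()), i.e. each thread is sized to cover roughly 8 bytes of output. A minimal standalone sketch of that mapping; the raw byte sizes below are illustrative stand-ins for MLX's Dtype::size():

    #include <algorithm>
    #include <cstdio>

    // Mirror of the helper added in mlx/backend/metal/utils.h:
    // size each thread to cover about 8 bytes of output.
    inline int get_work_per_thread(int dtype_size_bytes) {
      return std::max(1, 8 / dtype_size_bytes);
    }

    int main() {
      // 1 B (int8) -> 8 elements/thread, 2 B (float16) -> 4,
      // 4 B (float32) -> 2, 8 B (complex64) -> 1.
      for (int sz : {1, 2, 4, 8}) {
        std::printf("dtype size %d B -> work_per_thread %d\n", sz,
                    get_work_per_thread(sz));
      }
    }
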
@@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }

     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
+      grid_dims =
+          get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
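
Since each thread now produces work_per_thread elements, the launch shrinks the 1D grid by that factor, and the exact element count is additionally bound as a kernel argument (32- or 64-bit depending on `large`) so the last thread can clamp its loop. A rough sketch of the launch arithmetic, assuming the ceildiv helper added at the end of this diff; data_size and max_group are made-up inputs:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    inline size_t ceildiv(size_t n, size_t m) { return (n + m - 1) / m; }

    int main() {
      size_t data_size = 1000003;  // elements to produce; odd on purpose
      int work_per_thread = 4;     // e.g. a 2-byte dtype: 8 / 2
      size_t max_group = 1024;     // kernel->maxTotalThreadsPerThreadgroup()

      // One thread per chunk of work_per_thread elements.
      size_t nthreads = ceildiv(data_size, work_per_thread);
      size_t group = std::min(max_group, nthreads);
      std::printf("threads %zu, group %zu, tail %zu\n", nthreads, group,
                  data_size % work_per_thread);
      // The kernel also receives data_size itself (int or int64_t depending
      // on `large`) so the final thread stops at the tail.
    }
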
@@ -64,6 +64,7 @@ inline void build_kernel(
         cnt++);
   }

+  std::string idx_type = use_big_index ? "int64_t" : "uint";
   if (add_indices) {
     os += fmt::format(
         " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@@ -83,6 +84,9 @@ inline void build_kernel(
         " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
     os += fmt::format(
         " constant const int* output_shape [[buffer({0})]],\n", cnt++);
+  } else {
+    os += fmt::format(
+        " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
   }
   if (dynamic_dims) {
     os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
@@ -92,13 +96,14 @@ inline void build_kernel(
   os += " uint3 pos [[thread_position_in_grid]],\n";
   os += " uint3 grid [[threads_per_grid]]) {\n";

-  std::string idx_type = use_big_index ? "int64_t" : "uint";
+  os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
   if (contiguous && use_big_index) {
     // This is only used for contiguous kernels which don't have
     // a third grid dimension
-    os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
+    os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
+  } else if (contiguous) {
+    os += " uint index = N_ * pos.x;\n";
   } else if (work_per_thread > 1) {
-    os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
     os += fmt::format(
         " int xshape = output_shape[{0}];\n",
         dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@@ -110,6 +115,9 @@ inline void build_kernel(
         " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
         idx_type);
   }
+  if (work_per_thread > 1 && contiguous) {
+    os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
+  }

   // Read constant / contiguous inputs in tmps
   std::vector<array> nc_inputs;
@@ -193,7 +201,7 @@ inline void build_kernel(
   }

   // Open per-thread loop
-  if (work_per_thread > 1) {
+  if (work_per_thread > 1 && !contiguous) {
     os +=
         " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
   }
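
With the two changes above, the JIT emits a bounds-checked per-thread loop for contiguous compiled kernels too (the pre-existing loop opener is now reserved for the strided work_per_thread > 1 case). Pieced together by hand from the fmt::format fragments, the generated Metal source for a contiguous kernel looks roughly like the sketch below; the kernel name, buffer layout, and op body are placeholders, not the generator's exact output:

    // Sketch of the emitted Metal source for a contiguous compiled kernel.
    [[kernel]] void example_contiguous(
        device const float* in [[buffer(0)]],
        device float* out [[buffer(1)]],
        constant const uint& size [[buffer(2)]],  // int64_t in the large variant
        uint3 pos [[thread_position_in_grid]],
        uint3 grid [[threads_per_grid]]) {
      constexpr int N_ = 2;     // work_per_thread for a 4-byte dtype
      uint index = N_ * pos.x;  // large: N_ * (pos.x + grid.x * int64_t(pos.y))
      for (int i = 0; i < N_ && index < size; ++i) {
        float tmp = in[index];
        out[index] = tmp * tmp;  // placeholder for the compiled expression
        index++;
      }
    }
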
@@ -272,6 +280,7 @@ void Compiled::eval_gpu(
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto lib = d.get_library(kernel_lib_, [&]() {
+    int work_per_thread = get_work_per_thread(outputs_[0].dtype());
     std::string kernel = metal::utils();
     concatenate(
         kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@@ -284,7 +293,9 @@ void Compiled::eval_gpu(
         constant_ids_,
         /* contiguous = */ true,
         /* ndim = */ 0,
-        /* dynamic_dims = */ false);
+        /* dynamic_dims = */ false,
+        /* use_big_index = */ false,
+        /* work_per_thread = */ work_per_thread);
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous_large",
@@ -295,7 +306,8 @@ void Compiled::eval_gpu(
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
-        /* use_big_index = */ true);
+        /* use_big_index = */ true,
+        /* work_per_thread = */ work_per_thread);
     for (int i = 1; i < 8; i++) {
       build_kernel(
           kernel,
@@ -468,6 +480,13 @@ void Compiled::eval_gpu(
   if (!contiguous) {
     compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
+  } else {
+    auto size = outputs[0].data_size();
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(size, cnt++);
+    } else {
+      compute_encoder.set_bytes<int>(size, cnt++);
+    }
   }

   // Put the number of dims in if it is dynamic
@@ -477,12 +496,13 @@ void Compiled::eval_gpu(

   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].data_size();
+    int work_per_thread = get_work_per_thread(outputs[0].dtype());
+    size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);

     MTL::Size grid_dims = large
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        ? get_2d_grid_dims(
+              outputs[0].shape(), outputs[0].strides(), work_per_thread)
         : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
@@ -104,6 +104,8 @@ void copy_gpu_inplace(
             "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
       }
     }
+  } else {
+    work_per_thread = get_work_per_thread(in.dtype());
   }
   concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
   auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -165,13 +167,19 @@ void copy_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<int>(out.data_size(), 2);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
@@ -214,14 +222,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);

+  int work_per_thread = get_work_per_thread(val.dtype());
   auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  size_t nthreads = out.data_size();
+  size_t nthreads = ceildiv(out.data_size(), work_per_thread);
   if (thread_group_size > nthreads) {
     thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                              : MTL::Size(nthreads, 1, 1);
+  MTL::Size grid_dims;
+  if (large) {
+    compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
+    grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+  } else {
+    compute_encoder.set_bytes<int>(out.data_size(), 2);
+    grid_dims = MTL::Size(nthreads, 1, 1);
+  }
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

@@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
   c[index] = Op()(a[0], b[0]);
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[0], b[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[0], b[index + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[0]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[index + i], b[0]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
    uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[index + i], b[index + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[0], b[offset]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[0], b[offset + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[offset], b[0]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[offset + i], b[0]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[offset], b[offset]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[offset + i], b[offset + i]);
+  }
 }

 template <typename T, typename U, typename Op, typename IdxT = int64_t>
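
All six contiguous binary kernels follow the same pattern: scale the thread index by N, then run a bounds-checked loop so the final thread handles a partial chunk. A host-side C++ emulation of the new binary_vv indexing (N plays the role of WorkPerThread<T>::n, and Add stands in for Op), useful for sanity-checking the tail handling:

    #include <cassert>
    #include <vector>

    // One outer iteration per simulated GPU thread; the inner loop is the
    // guarded N-element body from the kernel above.
    template <typename T, int N>
    void binary_vv_emulated(const std::vector<T>& a, const std::vector<T>& b,
                            std::vector<T>& c, unsigned nthreads,
                            unsigned size) {
      for (unsigned t = 0; t < nthreads; ++t) {
        unsigned index = t * N;
        for (int i = 0; i < N && (index + i) < size; ++i) {
          c[index + i] = a[index + i] + b[index + i];
        }
      }
    }

    int main() {
      constexpr int N = 4;  // illustration only; for float it would be 8/4 = 2
      unsigned size = 10;   // deliberately not a multiple of N
      std::vector<float> a(size, 1.f), b(size, 2.f), c(size, 0.f);
      unsigned nthreads = (size + N - 1) / N;  // ceildiv, as in the host code
      binary_vv_emulated<float, N>(a, b, c, nthreads, size);
      for (float x : c) {
        assert(x == 3.f);  // every element written once, tail included
      }
    }
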
@@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
   d[index] = out[1];
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  auto out = Op()(a[0], b[index]);
-  c[index] = out[0];
-  d[index] = out[1];
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    auto out = Op()(a[0], b[index + i]);
+    c[index + i] = out[0];
+    d[index + i] = out[1];
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  auto out = Op()(a[index], b[0]);
-  c[index] = out[0];
-  d[index] = out[1];
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    auto out = Op()(a[index + i], b[0]);
+    c[index + i] = out[0];
+    d[index + i] = out[1];
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  auto out = Op()(a[index], b[index]);
-  c[index] = out[0];
-  d[index] = out[1];
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    auto out = Op()(a[index + i], b[index + i]);
+    c[index + i] = out[0];
+    d[index + i] = out[1];
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv2(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  auto out = Op()(a[0], b[offset]);
-  c[offset] = out[0];
-  d[offset] = out[1];
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    auto out = Op()(a[0], b[offset + i]);
+    c[offset + i] = out[0];
+    d[offset + i] = out[1];
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs2(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  auto out = Op()(a[offset], b[0]);
-  c[offset] = out[0];
-  d[offset] = out[1];
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    auto out = Op()(a[offset + i], b[0]);
+    c[offset + i] = out[0];
+    d[offset + i] = out[1];
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv2(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  auto out = Op()(a[offset], b[offset]);
-  c[offset] = out[0];
-  d[offset] = out[1];
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    auto out = Op()(a[offset + i], b[offset + i]);
+    c[offset + i] = out[0];
+    d[offset + i] = out[1];
+  }
 }

 template <typename T, typename U, typename Op, typename IdxT = int64_t>
@@ -1,39 +1,53 @@
 // Copyright © 2024 Apple Inc.

-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_s(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  dst[index] = static_cast<U>(src[0]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    dst[index + i] = static_cast<U>(src[0]);
+  }
 }

-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_v(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  dst[index] = static_cast<U>(src[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    dst[index + i] = static_cast<U>(src[index + i]);
+  }
 }

-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_s2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  dst[offset] = static_cast<U>(src[0]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    dst[offset + i] = static_cast<U>(src[0]);
+  }
 }

-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_v2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  dst[offset] = static_cast<U>(src[offset]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    dst[offset + i] = static_cast<U>(src[offset + i]);
+  }
 }

 template <typename T, typename U, typename IdxT = int64_t>
@@ -1,25 +1,32 @@
 // Copyright © 2024 Apple Inc.

-template <typename T, typename Op>
+template <typename T, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v(
     device const bool* a,
     device const T* b,
     device const T* c,
     device T* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  d[index] = Op()(a[index], b[index], c[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+  }
 }

-template <typename T, typename Op>
+template <typename T, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v2(
     device const bool* a,
     device const T* b,
     device const T* c,
     device T* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  d[offset] = Op()(a[offset], b[offset], c[offset]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+  }
 }

 template <typename T, typename Op, typename IdxT = int64_t>
@@ -1,21 +1,28 @@
 // Copyright © 2024 Apple Inc.

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void unary_v(
     device const T* in,
     device U* out,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  out[index] = Op()(in[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    out[index + i] = Op()(in[index + i]);
+  }
 }

-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void unary_v2(
     device const T* in,
     device U* out,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  out[offset] = Op()(in[offset]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    out[offset + i] = Op()(in[offset + i]);
+  }
 }

 template <
@@ -15,6 +15,14 @@

 typedef half float16_t;

+// Work per thread values for different types. The values here are expected to
+// match get_work_per_thread in mlx/backend/metal/utils.h
+template <typename U>
+struct WorkPerThread {
+  static_assert(sizeof(U) <= 8, "Type too large");
+  static constexpr int constant n = 8 / sizeof(U);
+};
+
 ///////////////////////////////////////////////////////////////////////////////
 // Type limits utils
 ///////////////////////////////////////////////////////////////////////////////
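
As the comment in this hunk notes, the device-side WorkPerThread<U>::n must agree with the host-side helper. A plain C++ mirror of both rules that asserts the correspondence at compile time; the Metal `constant` address-space qualifier is dropped, and a raw byte size stands in for Dtype::size():

    #include <algorithm>
    #include <cstddef>

    // Device-side rule, as in the struct above.
    template <typename U>
    struct WorkPerThread {
      static_assert(sizeof(U) <= 8, "Type too large");
      static constexpr int n = 8 / sizeof(U);
    };

    // Host-side rule from mlx/backend/metal/utils.h.
    constexpr int get_work_per_thread(size_t dtype_size) {
      return std::max<int>(1, 8 / static_cast<int>(dtype_size));
    }

    // The two must agree for every dtype a kernel is instantiated with.
    static_assert(WorkPerThread<char>::n == get_work_per_thread(sizeof(char)));
    static_assert(WorkPerThread<short>::n == get_work_per_thread(sizeof(short)));
    static_assert(WorkPerThread<float>::n == get_work_per_thread(sizeof(float)));
    static_assert(WorkPerThread<double>::n == get_work_per_thread(sizeof(double)));

    int main() {}
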
@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > INT32_MAX;
-    work_per_thread = 1;
+    work_per_thread = get_work_per_thread(b.dtype());
   }
   std::string kernel_name;
   if (topt == TernaryOpType::General) {
@@ -106,13 +106,19 @@ ternary_op_gpu_inplace(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<int>(out.data_size(), 4);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
@@ -34,18 +34,19 @@ void unary_op_gpu_inplace(
   };
   auto [shape, strides] = maybe_collapse();
   int ndim = shape.size();
-  size_t nthreads = contig ? in.data_size() : in.size();
   bool large;
   if (!contig) {
     large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
   } else {
     large = in.data_size() > UINT32_MAX;
   }
-  int work_per_thread = !contig && large ? 4 : 1;
+  int work_per_thread;
   std::string kernel_name;
   if (contig) {
+    work_per_thread = get_work_per_thread(in.dtype());
     kernel_name = (large ? "v2" : "v");
   } else {
+    work_per_thread = large ? 4 : 1;
     kernel_name = "gn" + std::to_string(work_per_thread);
     if (large) {
       kernel_name += "large";
@@ -75,12 +76,20 @@ void unary_op_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
+    size_t nthreads = ceildiv(in.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }

     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(in.data_size(), 2);
+      grid_dims =
+          get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<int>(in.data_size(), 2);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
@@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) {
   concatenate(acc, args...);
 }

+inline int get_work_per_thread(Dtype dtype) {
+  return std::max(1, 8 / dtype.size());
+}
+
+inline size_t ceildiv(size_t n, size_t m) {
+  return (n + m - 1) / m;
+}
+
 } // namespace mlx::core
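
For reference, the two new helpers compose so that a grid of ceildiv(size, N) threads, each running the bounds-checked N-element loop from the kernels above, touches every output index exactly once. A standalone check over sizes that are not multiples of N:

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    inline int get_work_per_thread(int dtype_size) {
      return std::max(1, 8 / dtype_size);
    }
    inline size_t ceildiv(size_t n, size_t m) { return (n + m - 1) / m; }

    int main() {
      // For every dtype size and output length, the guarded per-thread loop
      // writes each index in [0, size) exactly once.
      for (int dsize : {1, 2, 4, 8}) {
        int N = get_work_per_thread(dsize);
        for (size_t size = 1; size <= 64; ++size) {
          std::vector<int> hits(size, 0);
          for (size_t t = 0; t < ceildiv(size, N); ++t) {
            size_t index = t * N;
            for (int i = 0; i < N && index + i < size; ++i) {
              hits[index + i] += 1;
            }
          }
          for (int h : hits) {
            assert(h == 1);
          }
        }
      }
    }
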