mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
compile and copy
This commit is contained in:
parent
d0ebd18d7d
commit
ba8748b12e
@ -278,7 +278,21 @@ void Compiled::eval_gpu(
|
|||||||
/* ndim = */ 0,
|
/* ndim = */ 0,
|
||||||
/* dynamic_dims = */ false,
|
/* dynamic_dims = */ false,
|
||||||
/* use_big_index = */ false,
|
/* use_big_index = */ false,
|
||||||
/* work_per_thread = */ work_per_thread);
|
/* work_per_thread = */ 1);
|
||||||
|
if (work_per_thread > 1) {
|
||||||
|
build_kernel(
|
||||||
|
kernel,
|
||||||
|
kernel_lib_ + "_contiguous_n",
|
||||||
|
inputs_,
|
||||||
|
outputs_,
|
||||||
|
tape_,
|
||||||
|
is_constant_,
|
||||||
|
/* contiguous = */ true,
|
||||||
|
/* ndim = */ 0,
|
||||||
|
/* dynamic_dims = */ false,
|
||||||
|
/* use_big_index = */ false,
|
||||||
|
/* work_per_thread = */ work_per_thread);
|
||||||
|
}
|
||||||
build_kernel(
|
build_kernel(
|
||||||
kernel,
|
kernel,
|
||||||
kernel_lib_ + "_contiguous_large",
|
kernel_lib_ + "_contiguous_large",
|
||||||
@ -358,12 +372,20 @@ void Compiled::eval_gpu(
|
|||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
bool dynamic = ndim >= 8;
|
bool dynamic = ndim >= 8;
|
||||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||||
|
int work_per_thread = 1;
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
if (dynamic) {
|
if (dynamic) {
|
||||||
kernel_name += "dynamic";
|
kernel_name += "dynamic";
|
||||||
} else {
|
} else {
|
||||||
kernel_name += std::to_string(shape.size());
|
kernel_name += std::to_string(shape.size());
|
||||||
}
|
}
|
||||||
|
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
||||||
|
} else {
|
||||||
|
work_per_thread =
|
||||||
|
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
|
||||||
|
if (work_per_thread > 1 && !large) {
|
||||||
|
kernel_name += "_n";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (large) {
|
if (large) {
|
||||||
kernel_name += "_large";
|
kernel_name += "_large";
|
||||||
@ -420,7 +442,6 @@ void Compiled::eval_gpu(
|
|||||||
|
|
||||||
// Launch the kernel
|
// Launch the kernel
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
int work_per_thread = get_work_per_thread(outputs[0].dtype());
|
|
||||||
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
|
||||||
MTL::Size group_dims(
|
MTL::Size group_dims(
|
||||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||||
@ -433,7 +454,6 @@ void Compiled::eval_gpu(
|
|||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||||
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
|
|
||||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
int pow2;
|
int pow2;
|
||||||
|
@ -55,10 +55,10 @@ void copy_gpu_inplace(
|
|||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
switch (ctype) {
|
switch (ctype) {
|
||||||
case CopyType::Scalar:
|
case CopyType::Scalar:
|
||||||
kernel_name = (large ? "s2" : "s");
|
kernel_name = large ? "s2" : "s";
|
||||||
break;
|
break;
|
||||||
case CopyType::Vector:
|
case CopyType::Vector:
|
||||||
kernel_name = (large ? "v2" : "v");
|
kernel_name = large ? "v2" : "v";
|
||||||
break;
|
break;
|
||||||
case CopyType::General:
|
case CopyType::General:
|
||||||
kernel_name = "g";
|
kernel_name = "g";
|
||||||
@ -85,7 +85,10 @@ void copy_gpu_inplace(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
work_per_thread = get_work_per_thread(in.dtype());
|
work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
|
||||||
|
if (work_per_thread > 1) {
|
||||||
|
kernel_name += "n";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||||
@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
|||||||
}
|
}
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
bool large = out.data_size() > UINT32_MAX;
|
bool large = out.data_size() > UINT32_MAX;
|
||||||
|
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
|
std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
|
||||||
type_to_name(val) + type_to_name(out);
|
concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
|
||||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
|
|||||||
compute_encoder.set_input_array(val, 0);
|
compute_encoder.set_input_array(val, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
|
||||||
int work_per_thread = get_work_per_thread(val.dtype());
|
|
||||||
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
|
@ -180,12 +180,17 @@ MTL::ComputePipelineState* get_copy_kernel(
|
|||||||
kernel_source += metal::copy();
|
kernel_source += metal::copy();
|
||||||
auto in_type = get_type_string(in.dtype());
|
auto in_type = get_type_string(in.dtype());
|
||||||
auto out_type = get_type_string(out.dtype());
|
auto out_type = get_type_string(out.dtype());
|
||||||
|
kernel_source += get_template_definition(
|
||||||
|
"s_" + lib_name, "copy_s", in_type, out_type, 1);
|
||||||
kernel_source +=
|
kernel_source +=
|
||||||
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
|
get_template_definition("sn_" + lib_name, "copy_s", in_type, out_type);
|
||||||
kernel_source +=
|
kernel_source +=
|
||||||
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
|
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
|
||||||
|
kernel_source += get_template_definition(
|
||||||
|
"v_" + lib_name, "copy_v", in_type, out_type, 1);
|
||||||
kernel_source +=
|
kernel_source +=
|
||||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
get_template_definition("vn_" + lib_name, "copy_v", in_type, out_type);
|
||||||
|
|
||||||
kernel_source +=
|
kernel_source +=
|
||||||
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
|
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
|
||||||
|
|
||||||
|
@ -1,52 +1,76 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
template <typename T, typename U, int N = WorkPerThread<U>::n>
|
||||||
[[kernel]] void copy_s(
|
[[kernel]] void copy_s(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
device U* dst [[buffer(1)]],
|
device U* dst [[buffer(1)]],
|
||||||
constant uint& size,
|
constant uint& size,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
index *= N;
|
index *= N;
|
||||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
if (N > 1 && index + N > size) {
|
||||||
dst[index + i] = static_cast<U>(src[0]);
|
for (int i = 0; index + i < size; ++i) {
|
||||||
|
dst[index + i] = static_cast<U>(src[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
dst[index + i] = static_cast<U>(src[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
template <typename T, typename U, int N = WorkPerThread<U>::n>
|
||||||
[[kernel]] void copy_v(
|
[[kernel]] void copy_v(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
device U* dst [[buffer(1)]],
|
device U* dst [[buffer(1)]],
|
||||||
constant uint& size,
|
constant uint& size,
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
index *= N;
|
index *= N;
|
||||||
for (int i = 0; i < N && (index + i) < size; ++i) {
|
if (N > 1 && index + N > size) {
|
||||||
dst[index + i] = static_cast<U>(src[index + i]);
|
for (int i = 0; index + i < size; ++i) {
|
||||||
|
dst[index + i] = static_cast<U>(src[index + i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
dst[index + i] = static_cast<U>(src[index + i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
template <typename T, typename U, int N = WorkPerThread<U>::n>
|
||||||
[[kernel]] void copy_s2(
|
[[kernel]] void copy_s2(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
device U* dst [[buffer(1)]],
|
device U* dst [[buffer(1)]],
|
||||||
constant int64_t& size,
|
constant int64_t& size,
|
||||||
uint2 index [[thread_position_in_grid]],
|
uint2 index [[thread_position_in_grid]],
|
||||||
uint2 grid_dim [[threads_per_grid]]) {
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
if (N > 1 && offset + N > size) {
|
||||||
dst[offset + i] = static_cast<U>(src[0]);
|
for (int i = 0; offset + i < size; ++i) {
|
||||||
|
dst[offset + i] = static_cast<U>(src[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
dst[offset + i] = static_cast<U>(src[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, int N = WorkPerThread<T>::n>
|
template <typename T, typename U, int N = WorkPerThread<U>::n>
|
||||||
[[kernel]] void copy_v2(
|
[[kernel]] void copy_v2(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
device U* dst [[buffer(1)]],
|
device U* dst [[buffer(1)]],
|
||||||
constant int64_t& size,
|
constant int64_t& size,
|
||||||
uint2 index [[thread_position_in_grid]],
|
uint2 index [[thread_position_in_grid]],
|
||||||
uint2 grid_dim [[threads_per_grid]]) {
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
|
||||||
for (int i = 0; i < N && (offset + i) < size; ++i) {
|
if (N > 1 && offset + N > size) {
|
||||||
dst[offset + i] = static_cast<U>(src[offset + i]);
|
for (int i = 0; offset + i < size; ++i) {
|
||||||
|
dst[offset + i] = static_cast<U>(src[offset + i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
dst[offset + i] = static_cast<U>(src[offset + i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,9 +4,13 @@
|
|||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
#include "mlx/backend/metal/kernels/copy.h"
|
#include "mlx/backend/metal/kernels/copy.h"
|
||||||
|
|
||||||
#define instantiate_copy_all(tname, itype, otype) \
|
#define instantiate_copy_work_per_thread(tname, itype, otype) \
|
||||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
|
||||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
|
||||||
|
|
||||||
|
#define instantiate_copy_base(tname, itype, otype) \
|
||||||
|
instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
|
||||||
|
instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
|
||||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||||
@ -18,6 +22,10 @@
|
|||||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
|
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
|
||||||
|
|
||||||
|
#define instantiate_copy_all(tname, itype, otype) \
|
||||||
|
instantiate_copy_base(tname, itype, otype) \
|
||||||
|
instantiate_copy_work_per_thread(tname, itype, otype)
|
||||||
|
|
||||||
#define instantiate_copy_same(tname, type) \
|
#define instantiate_copy_same(tname, type) \
|
||||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
|
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
|
||||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
|
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
|
||||||
@ -42,15 +50,15 @@
|
|||||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||||
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
|
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
|
||||||
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
|
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
|
||||||
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
|
instantiate_copy_base(itname ##uint64, itype, uint64_t) \
|
||||||
instantiate_copy_all(itname ##int8, itype, int8_t) \
|
instantiate_copy_all(itname ##int8, itype, int8_t) \
|
||||||
instantiate_copy_all(itname ##int16, itype, int16_t) \
|
instantiate_copy_all(itname ##int16, itype, int16_t) \
|
||||||
instantiate_copy_all(itname ##int32, itype, int32_t) \
|
instantiate_copy_all(itname ##int32, itype, int32_t) \
|
||||||
instantiate_copy_all(itname ##int64, itype, int64_t) \
|
instantiate_copy_base(itname ##int64, itype, int64_t) \
|
||||||
instantiate_copy_all(itname ##float16, itype, half) \
|
instantiate_copy_all(itname ##float16, itype, half) \
|
||||||
instantiate_copy_all(itname ##float32, itype, float) \
|
instantiate_copy_all(itname ##float32, itype, float) \
|
||||||
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
|
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
|
||||||
instantiate_copy_all(itname ##complex64, itype, complex64_t)
|
instantiate_copy_base(itname ##complex64, itype, complex64_t)
|
||||||
|
|
||||||
instantiate_copy_itype(bool_, bool)
|
instantiate_copy_itype(bool_, bool)
|
||||||
instantiate_copy_itype(uint8, uint8_t)
|
instantiate_copy_itype(uint8, uint8_t)
|
||||||
|
Loading…
Reference in New Issue
Block a user