Compare commits

...

4 Commits

Author SHA1 Message Date
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
22 changed files with 505 additions and 187 deletions

View File

@@ -17,6 +17,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Enable defining device lambda functions. # Enable defining device lambda functions.
target_compile_options(mlx target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>") PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

View File

@@ -5,9 +5,17 @@
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/gpu/slicing.h"
#if defined(MLX_USE_CUDA)
#include <nvtx3/nvtx3.hpp>
#endif
#include <cassert> #include <cassert>
#if defined(MLX_USE_CUDA)
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
#else
#define MLX_PROFILER_RANGE(message) #define MLX_PROFILER_RANGE(message)
#endif
namespace mlx::core { namespace mlx::core {

View File

@@ -31,13 +31,13 @@ std::string get_kernel_name(
kname = "ss"; kname = "ss";
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
kname = (large ? "sv2" : "sv"); kname = "sv";
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
kname = (large ? "vs2" : "vs"); kname = "vs";
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
kname = (large ? "vv2" : "vv"); kname = "vv";
break; break;
case BinaryOpType::General: case BinaryOpType::General:
kname = "g"; kname = "g";
@@ -51,6 +51,13 @@ std::string get_kernel_name(
} }
break; break;
} }
if (bopt != BinaryOpType::General && bopt != BinaryOpType::ScalarScalar) {
if (large) {
kname += "2";
} else if (work_per_thread > 1) {
kname += "n";
}
}
concatenate(kname, "_", op, type_to_name(a)); concatenate(kname, "_", op, type_to_name(a));
return kname; return kname;
} }
@@ -90,7 +97,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2; work_per_thread = large ? 4 : 2;
} else { } else {
large = out.data_size() > UINT32_MAX; large = out.data_size() > UINT32_MAX;
work_per_thread = get_work_per_thread(a.dtype()); work_per_thread = get_work_per_thread(a.dtype(), out.data_size());
} }
std::string kernel_name = std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);

View File

@@ -278,7 +278,21 @@ void Compiled::eval_gpu(
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ false, /* dynamic_dims = */ false,
/* use_big_index = */ false, /* use_big_index = */ false,
/* work_per_thread = */ work_per_thread); /* work_per_thread = */ 1);
if (work_per_thread > 1) {
build_kernel(
kernel,
kernel_lib_ + "_contiguous_n",
inputs_,
outputs_,
tape_,
is_constant_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
}
build_kernel( build_kernel(
kernel, kernel,
kernel_lib_ + "_contiguous_large", kernel_lib_ + "_contiguous_large",
@@ -358,12 +372,20 @@ void Compiled::eval_gpu(
int ndim = shape.size(); int ndim = shape.size();
bool dynamic = ndim >= 8; bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
int work_per_thread = 1;
if (!contiguous) { if (!contiguous) {
if (dynamic) { if (dynamic) {
kernel_name += "dynamic"; kernel_name += "dynamic";
} else { } else {
kernel_name += std::to_string(shape.size()); kernel_name += std::to_string(shape.size());
} }
work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
} else {
work_per_thread =
get_work_per_thread(outputs[0].dtype(), outputs[0].data_size());
if (work_per_thread > 1 && !large) {
kernel_name += "_n";
}
} }
if (large) { if (large) {
kernel_name += "_large"; kernel_name += "_large";
@@ -420,7 +442,6 @@ void Compiled::eval_gpu(
// Launch the kernel // Launch the kernel
if (contiguous) { if (contiguous) {
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
MTL::Size group_dims( MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
@@ -433,7 +454,6 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1); size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2; int pow2;

View File

@@ -55,10 +55,10 @@ void copy_gpu_inplace(
std::string kernel_name; std::string kernel_name;
switch (ctype) { switch (ctype) {
case CopyType::Scalar: case CopyType::Scalar:
kernel_name = (large ? "s2" : "s"); kernel_name = large ? "s2" : "s";
break; break;
case CopyType::Vector: case CopyType::Vector:
kernel_name = (large ? "v2" : "v"); kernel_name = large ? "v2" : "v";
break; break;
case CopyType::General: case CopyType::General:
kernel_name = "g"; kernel_name = "g";
@@ -85,7 +85,10 @@ void copy_gpu_inplace(
} }
} }
} else { } else {
work_per_thread = get_work_per_thread(in.dtype()); work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
if (work_per_thread > 1) {
kernel_name += "n";
}
} }
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
} }
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
bool large = out.data_size() > UINT32_MAX; bool large = out.data_size() > UINT32_MAX;
int work_per_thread = get_work_per_thread(out.dtype(), out.data_size());
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s");
type_to_name(val) + type_to_name(out); concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, val, out); auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0); compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread); size_t nthreads = ceildiv(out.data_size(), work_per_thread);
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {

View File

@@ -41,7 +41,11 @@ MTL::ComputePipelineState* get_unary_kernel(
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary()); concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source += kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op); get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op);
}
kernel_source += kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition( kernel_source += get_template_definition(
@@ -59,11 +63,8 @@ void append_binary_kernels(
Dtype out_type, Dtype out_type,
const std::string op, const std::string op,
std::string& kernel_source) { std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 7> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"}, {"vs2", "binary_vs2"},
{"sv2", "binary_sv2"}, {"sv2", "binary_sv2"},
{"vv2", "binary_vv2"}, {"vv2", "binary_vv2"},
@@ -78,6 +79,22 @@ void append_binary_kernels(
kernel_source += kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
} }
kernel_source += get_template_definition(
"vs_" + lib_name, "binary_vs", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "binary_sv", in_t, out_t, op, 1);
kernel_source += get_template_definition(
"vv_" + lib_name, "binary_vv", in_t, out_t, op, 1);
if (get_work_per_thread(in_type) > 1) {
kernel_source += get_template_definition(
"vsn_" + lib_name, "binary_vs", in_t, out_t, op);
kernel_source += get_template_definition(
"svn_" + lib_name, "binary_sv", in_t, out_t, op);
kernel_source += get_template_definition(
"vvn_" + lib_name, "binary_vv", in_t, out_t, op);
}
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(
@@ -133,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type); auto t_str = get_type_string(type);
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"}, {"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"}, {"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"}, {"g2large", "ternary_g_nd2"},
@@ -144,6 +160,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source += kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op); get_template_definition(name + "_" + lib_name, func, t_str, op);
} }
if (get_work_per_thread(type) > 1) {
kernel_source +=
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
}
kernel_source +=
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(
@@ -170,15 +193,22 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += metal::copy(); kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype()); auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype()); auto out_type = get_type_string(out.dtype());
kernel_source += kernel_source += get_template_definition(
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); "s_" + lib_name, "copy_s", in_type, out_type, 1);
kernel_source += kernel_source +=
get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type);
kernel_source += kernel_source += get_template_definition(
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); "v_" + lib_name, "copy_v", in_type, out_type, 1);
kernel_source += kernel_source +=
get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type);
if (get_work_per_thread(out.dtype()) > 1) {
kernel_source += get_template_definition(
"sn_" + lib_name, "copy_s", in_type, out_type);
kernel_source += get_template_definition(
"vn_" + lib_name, "copy_v", in_type, out_type);
}
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(

View File

@@ -17,8 +17,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
c[index + i] = Op()(a[0], b[index + i]); for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
} }
} }
@@ -30,8 +36,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
c[index + i] = Op()(a[index + i], b[0]); for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
} }
} }
@@ -43,8 +55,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
c[index + i] = Op()(a[index + i], b[index + i]); for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
} }
} }
@@ -57,8 +75,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
c[offset + i] = Op()(a[0], b[offset + i]); for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
} }
} }
@@ -71,8 +95,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
c[offset + i] = Op()(a[offset + i], b[0]); for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
} }
} }
@@ -85,8 +115,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
c[offset + i] = Op()(a[offset + i], b[offset + i]); for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
} }
} }

View File

@@ -9,11 +9,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op) \
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -26,15 +31,19 @@
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ instantiate_binary_work_per_thread(op, tname, itype, otype)
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ #define instantiate_binary_integer(op) \
instantiate_binary_all(op, int8, int8_t, int8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \ instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \ instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_base(op, int64, int64_t, int64_t)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \
@@ -44,7 +53,7 @@
#define instantiate_binary_types(op) \ #define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_integer(op) \ instantiate_binary_integer(op) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ instantiate_binary_base(op, complex64, complex64_t, complex64_t)\
instantiate_binary_float(op) instantiate_binary_float(op)
#define instantiate_binary_types_bool(op) \ #define instantiate_binary_types_bool(op) \
@@ -52,15 +61,15 @@
instantiate_binary_all(op, uint8, uint8_t, bool) \ instantiate_binary_all(op, uint8, uint8_t, bool) \
instantiate_binary_all(op, uint16, uint16_t, bool) \ instantiate_binary_all(op, uint16, uint16_t, bool) \
instantiate_binary_all(op, uint32, uint32_t, bool) \ instantiate_binary_all(op, uint32, uint32_t, bool) \
instantiate_binary_all(op, uint64, uint64_t, bool) \ instantiate_binary_base(op, uint64, uint64_t, bool) \
instantiate_binary_all(op, int8, int8_t, bool) \ instantiate_binary_all(op, int8, int8_t, bool) \
instantiate_binary_all(op, int16, int16_t, bool) \ instantiate_binary_all(op, int16, int16_t, bool) \
instantiate_binary_all(op, int32, int32_t, bool) \ instantiate_binary_all(op, int32, int32_t, bool) \
instantiate_binary_all(op, int64, int64_t, bool) \ instantiate_binary_base(op, int64, int64_t, bool) \
instantiate_binary_all(op, float16, half, bool) \ instantiate_binary_all(op, float16, half, bool) \
instantiate_binary_all(op, float32, float, bool) \ instantiate_binary_all(op, float32, float, bool) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
instantiate_binary_all(op, complex64, complex64_t, bool) instantiate_binary_base(op, complex64, complex64_t, bool)
instantiate_binary_types(Add) instantiate_binary_types(Add)
instantiate_binary_types(Divide) instantiate_binary_types(Divide)
@@ -71,7 +80,7 @@ instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual) instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual) instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp) instantiate_binary_float(LogAddExp)
instantiate_binary_all(LogAddExp, complex64, complex64_t, complex64_t) instantiate_binary_base(LogAddExp, complex64, complex64_t, complex64_t)
instantiate_binary_types(Maximum) instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum) instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply) instantiate_binary_types(Multiply)
@@ -84,7 +93,7 @@ instantiate_binary_float(ArcTan2)
instantiate_binary_all(NaNEqual, float16, half, bool) instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool) instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) instantiate_binary_base(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_all(LogicalOr, bool_, bool, bool) instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool) instantiate_binary_all(LogicalAnd, bool_, bool, bool)

View File

@@ -21,10 +21,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
auto out = Op()(a[0], b[index + i]); for (int i = 0; index + i < size; ++i) {
c[index + i] = out[0]; auto out = Op()(a[0], b[index + i]);
d[index + i] = out[1]; c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} }
} }
@@ -37,10 +45,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
auto out = Op()(a[index + i], b[0]); for (int i = 0; index + i < size; ++i) {
c[index + i] = out[0]; auto out = Op()(a[index + i], b[0]);
d[index + i] = out[1]; c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} }
} }
@@ -53,10 +69,18 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
auto out = Op()(a[index + i], b[index + i]); for (int i = 0; index + i < size; ++i) {
c[index + i] = out[0]; auto out = Op()(a[index + i], b[index + i]);
d[index + i] = out[1]; c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} }
} }
@@ -69,11 +93,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
auto out = Op()(a[0], b[offset + i]); for (int i = 0; offset + i < size; ++i) {
c[offset + i] = out[0]; auto out = Op()(a[0], b[offset + i]);
d[offset + i] = out[1]; c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} }
} }
@@ -86,11 +118,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
auto out = Op()(a[offset + i], b[0]); for (int i = 0; offset + i < size; ++i) {
c[offset + i] = out[0]; auto out = Op()(a[offset + i], b[0]);
d[offset + i] = out[1]; c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} }
} }
@@ -103,11 +143,19 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
auto out = Op()(a[offset + i], b[offset + i]); for (int i = 0; offset + i < size; ++i) {
c[offset + i] = out[0]; auto out = Op()(a[offset + i], b[offset + i]);
d[offset + i] = out[1]; c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} }
} }

View File

@@ -7,11 +7,16 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h" #include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_work_per_thread(op, tname, itype, otype) \
instantiate_kernel("svn_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vsn_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vvn_" #op #tname, binary_vv, itype, otype, op)
#define instantiate_binary_base(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op, 1) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op, 1) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op, 1) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
@@ -24,22 +29,26 @@
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_binary_base(op, tname, itype, otype) \
instantiate_binary_work_per_thread(op, tname, itype, otype)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \ instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_types(op) \ #define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \ instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ instantiate_binary_base(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \ instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \ instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \ instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \ instantiate_binary_base(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ instantiate_binary_base(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op) instantiate_binary_float(op)
instantiate_binary_types(DivMod) // clang-format on instantiate_binary_types(DivMod) // clang-format on

View File

@@ -1,52 +1,76 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
template <typename T, typename U, int N = WorkPerThread<T>::n> template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s( [[kernel]] void copy_s(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
dst[index + i] = static_cast<U>(src[0]); for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[0]);
}
} }
} }
template <typename T, typename U, int N = WorkPerThread<T>::n> template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v( [[kernel]] void copy_v(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
dst[index + i] = static_cast<U>(src[index + i]); for (int i = 0; index + i < size; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[index + i] = static_cast<U>(src[index + i]);
}
} }
} }
template <typename T, typename U, int N = WorkPerThread<T>::n> template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_s2( [[kernel]] void copy_s2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
dst[offset + i] = static_cast<U>(src[0]); for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[0]);
}
} }
} }
template <typename T, typename U, int N = WorkPerThread<T>::n> template <typename T, typename U, int N = WorkPerThread<U>::n>
[[kernel]] void copy_v2( [[kernel]] void copy_v2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
dst[offset + i] = static_cast<U>(src[offset + i]); for (int i = 0; offset + i < size; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
dst[offset + i] = static_cast<U>(src[offset + i]);
}
} }
} }

View File

@@ -4,9 +4,13 @@
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/copy.h" #include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \ #define instantiate_copy_work_per_thread(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ instantiate_kernel("vn_copy" #tname, copy_v, itype, otype)
#define instantiate_copy_base(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
@@ -18,6 +22,10 @@
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy_base(tname, itype, otype) \
instantiate_copy_work_per_thread(tname, itype, otype)
#define instantiate_copy_same(tname, type) \ #define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
@@ -42,15 +50,15 @@
instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \ instantiate_copy_base(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \ instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \ instantiate_copy_base(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t) instantiate_copy_base(itname ##complex64, itype, complex64_t)
instantiate_copy_itype(bool_, bool) instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t) instantiate_copy_itype(uint8, uint8_t)

View File

@@ -9,8 +9,14 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); for (int i = 0; index + i < size; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
}
} }
} }
@@ -23,9 +29,15 @@ template <typename T, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); for (int i = 0; offset + i < size; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
}
} }
} }

View File

@@ -8,8 +8,8 @@
#include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h" #include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \ #define instantiate_ternary_base(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ instantiate_kernel("v_" #op #tname, ternary_v, type, op, 1) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
@@ -20,19 +20,23 @@
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("vn_" #op #tname, ternary_v, type, op) \
instantiate_ternary_base(op, tname, type)
#define instantiate_ternary_types(op) \ #define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \ instantiate_ternary_all(op, bool_, bool) \
instantiate_ternary_all(op, uint8, uint8_t) \ instantiate_ternary_all(op, uint8, uint8_t) \
instantiate_ternary_all(op, uint16, uint16_t) \ instantiate_ternary_all(op, uint16, uint16_t) \
instantiate_ternary_all(op, uint32, uint32_t) \ instantiate_ternary_all(op, uint32, uint32_t) \
instantiate_ternary_all(op, uint64, uint64_t) \ instantiate_ternary_base(op, uint64, uint64_t) \
instantiate_ternary_all(op, int8, int8_t) \ instantiate_ternary_all(op, int8, int8_t) \
instantiate_ternary_all(op, int16, int16_t) \ instantiate_ternary_all(op, int16, int16_t) \
instantiate_ternary_all(op, int32, int32_t) \ instantiate_ternary_all(op, int32, int32_t) \
instantiate_ternary_all(op, int64, int64_t) \ instantiate_ternary_base(op, int64, int64_t) \
instantiate_ternary_all(op, float16, half) \ instantiate_ternary_all(op, float16, half) \
instantiate_ternary_all(op, float32, float) \ instantiate_ternary_all(op, float32, float) \
instantiate_ternary_all(op, bfloat16, bfloat16_t) \ instantiate_ternary_all(op, bfloat16, bfloat16_t) \
instantiate_ternary_all(op, complex64, complex64_t) // clang-format on instantiate_ternary_base(op, complex64, complex64_t) // clang-format on
instantiate_ternary_types(Select) instantiate_ternary_types(Select)

View File

@@ -7,8 +7,14 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant uint& size, constant uint& size,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
index *= N; index *= N;
for (int i = 0; i < N && (index + i) < size; ++i) { if (N > 1 && index + N > size) {
out[index + i] = Op()(in[index + i]); for (int i = 0; index + i < size; ++i) {
out[index + i] = Op()(in[index + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[index + i] = Op()(in[index + i]);
}
} }
} }
@@ -19,9 +25,15 @@ template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
constant int64_t& size, constant int64_t& size,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
for (int i = 0; i < N && (offset + i) < size; ++i) { if (N > 1 && offset + N > size) {
out[offset + i] = Op()(in[offset + i]); for (int i = 0; offset + i < size; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
} else {
for (int i = 0; i < N; ++i) {
out[offset + i] = Op()(in[offset + i]);
}
} }
} }

View File

@@ -5,31 +5,41 @@
#include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h" #include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ #define instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ instantiate_kernel("vn_" #op #in_tname #out_tname, unary_v, in_type, out_type, op)
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \ #define instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op, 1) \
instantiate_kernel( \ instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) "gn4large_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_base(op, in_tname, out_tname, in_type, out_type) \
instantiate_unary_work_per_thread(op, in_tname, out_tname, in_type, out_type)
#define instantiate_unary_all_same(op, tname, type) \ #define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type) instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_base_same(op, tname, type) \
instantiate_unary_base(op, tname, tname, type, type)
#define instantiate_unary_float(op) \ #define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t) instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_int(op) \ #define instantiate_unary_int(op) \
instantiate_unary_all_same(op, uint8, uint8_t) \ instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \ instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \ instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \ instantiate_unary_base_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \ instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \ instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \ instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t) instantiate_unary_base_same(op, int64, int64_t)
#define instantiate_unary_types(op) \ #define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \ instantiate_unary_all_same(op, bool_, bool) \
@@ -68,29 +78,29 @@ instantiate_unary_float(Tanh)
instantiate_unary_float(Round) instantiate_unary_float(Round)
instantiate_unary_int(BitwiseInvert) instantiate_unary_int(BitwiseInvert)
instantiate_unary_all_same(Abs, complex64, complex64_t) instantiate_unary_base_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(ArcCos, complex64, complex64_t) instantiate_unary_base_same(ArcCos, complex64, complex64_t)
instantiate_unary_all_same(ArcSin, complex64, complex64_t) instantiate_unary_base_same(ArcSin, complex64, complex64_t)
instantiate_unary_all_same(ArcTan, complex64, complex64_t) instantiate_unary_base_same(ArcTan, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t) instantiate_unary_base_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t) instantiate_unary_base_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t) instantiate_unary_base_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t) instantiate_unary_base_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Log, complex64, complex64_t) instantiate_unary_base_same(Log, complex64, complex64_t)
instantiate_unary_all_same(Log1p, complex64, complex64_t) instantiate_unary_base_same(Log1p, complex64, complex64_t)
instantiate_unary_all_same(Log2, complex64, complex64_t) instantiate_unary_base_same(Log2, complex64, complex64_t)
instantiate_unary_all_same(Log10, complex64, complex64_t) instantiate_unary_base_same(Log10, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t) instantiate_unary_base_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t) instantiate_unary_base_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t) instantiate_unary_base_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t) instantiate_unary_base_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Square, complex64, complex64_t) instantiate_unary_base_same(Square, complex64, complex64_t)
instantiate_unary_all_same(Sqrt, complex64, complex64_t) instantiate_unary_base_same(Sqrt, complex64, complex64_t)
instantiate_unary_all_same(Rsqrt, complex64, complex64_t) instantiate_unary_base_same(Rsqrt, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t) instantiate_unary_base_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t) instantiate_unary_base_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t) instantiate_unary_base_same(Round, complex64, complex64_t)
instantiate_unary_all(Real, complex64, float32, complex64_t, float) instantiate_unary_base(Real, complex64, float32, complex64_t, float)
instantiate_unary_all(Imag, complex64, float32, complex64_t, float) instantiate_unary_base(Imag, complex64, float32, complex64_t, float)
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

View File

@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2; work_per_thread = large ? 4 : 2;
} else { } else {
large = out.data_size() > INT32_MAX; large = out.data_size() > INT32_MAX;
work_per_thread = get_work_per_thread(b.dtype()); work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
} }
std::string kernel_name; std::string kernel_name;
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::General) {
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
} }
} else if (large) { } else if (large) {
kernel_name = "v2"; kernel_name = "v2";
} else if (work_per_thread > 1) {
kernel_name = "vn";
} else { } else {
kernel_name = "v"; kernel_name = "v";
} }

View File

@@ -43,8 +43,8 @@ void unary_op_gpu_inplace(
int work_per_thread; int work_per_thread;
std::string kernel_name; std::string kernel_name;
if (contig) { if (contig) {
work_per_thread = get_work_per_thread(in.dtype()); work_per_thread = get_work_per_thread(in.dtype(), in.data_size());
kernel_name = (large ? "v2" : "v"); kernel_name = (large ? "v2" : (work_per_thread > 1 ? "vn" : "v"));
} else { } else {
work_per_thread = large ? 4 : 1; work_per_thread = large ? 4 : 1;
kernel_name = "gn" + std::to_string(work_per_thread); kernel_name = "gn" + std::to_string(work_per_thread);

View File

@@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
inline int get_work_per_thread(Dtype dtype) { inline int get_work_per_thread(Dtype dtype) {
return std::max(1, 8 / dtype.size()); return std::max(1, 8 / dtype.size());
} }
inline int get_work_per_thread(Dtype dtype, size_t size) {
constexpr size_t wpt_threshold = 1 << 16;
return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
}
inline size_t ceildiv(size_t n, size_t m) { inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m; return (n + m - 1) / m;

View File

@@ -193,7 +193,7 @@ class Module(dict):
) )
if len(weights) != 0: if len(weights) != 0:
self.update(tree_unflatten(weights)) self.update(tree_unflatten(weights), strict=False)
return self return self
def save_weights(self, file: str): def save_weights(self, file: str):
@@ -291,7 +291,7 @@ class Module(dict):
return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
def update(self, parameters: dict) -> Module: def update(self, parameters: dict, strict: bool = True) -> Module:
"""Replace the parameters of this Module with the provided ones in the """Replace the parameters of this Module with the provided ones in the
dict of dicts and lists. dict of dicts and lists.
@@ -305,7 +305,9 @@ class Module(dict):
Args: Args:
parameters (dict): A complete or partial dictionary of the modules parameters (dict): A complete or partial dictionary of the modules
parameters. parameters.
strict (bool): If ``True`` checks that ``parameters`` is a
subset of the module's parameters. Default: ``True``.
Returns: Returns:
The module instance after updating the parameters. The module instance after updating the parameters.
""" """
@@ -317,21 +319,29 @@ class Module(dict):
current_value = dst[k] current_value = dst[k]
new_value = parameters[k] new_value = parameters[k]
if isinstance(current_value, mx.array): if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[k] = new_value dst[k] = new_value
elif isinstance(current_value, Module): else:
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(f'Module does not have parameter named "{k}".')
elif isinstance(parameters, list): elif isinstance(parameters, list):
for i in range(len(parameters)): for i in range(len(parameters)):
current_value = dst[i] current_value = dst[i]
new_value = parameters[i] new_value = parameters[i]
if isinstance(current_value, mx.array): if isinstance(current_value, mx.array):
if strict and not isinstance(new_value, mx.array):
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
dst[i] = new_value dst[i] = new_value
elif isinstance(current_value, Module): else:
current_value.update(new_value)
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
apply(self, parameters) apply(self, parameters)
return self return self
@@ -359,7 +369,7 @@ class Module(dict):
self.update(self.filter_and_map(filter_fn, map_fn)) self.update(self.filter_and_map(filter_fn, map_fn))
return self return self
def update_modules(self, modules: dict) -> Module: def update_modules(self, modules: dict, strict: bool = True) -> Module:
"""Replace the child modules of this :class:`Module` instance with the """Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists. provided ones in the dict of dicts and lists.
@@ -368,12 +378,14 @@ class Module(dict):
programmatically swapping layers. programmatically swapping layers.
The passed in parameters dictionary need not be a full dictionary The passed in parameters dictionary need not be a full dictionary
similar to :meth:`parameters`. Only the provided locations will be similar to :meth:`modules`. Only the provided locations will be
updated. updated.
Args: Args:
modules (dict): A complete or partial dictionary of the modules modules (dict): A complete or partial dictionary of the module's
submodules. submodules.
strict (bool): If ``True`` checks that ``modules`` is a
subset of the child modules of this instance. Default: ``True``.
Returns: Returns:
The module instance after updating the submodules. The module instance after updating the submodules.
""" """
@@ -388,6 +400,14 @@ class Module(dict):
dst[k] = new_value dst[k] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list): elif isinstance(modules, list):
for i in range(len(dst)): for i in range(len(dst)):
current_value = dst[i] current_value = dst[i]
@@ -396,6 +416,12 @@ class Module(dict):
dst[i] = new_value dst[i] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules) apply(self, modules)
return self return self

View File

@@ -54,5 +54,9 @@ target_link_libraries(core PRIVATE mlx)
target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
else()
target_link_options(core PRIVATE -Wl,-rpath,\$ORIGIN/lib)
endif()
endif() endif()

View File

@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
x = mx.zeros((3,)) x = mx.zeros((3,))
mx.grad(loss_fn)(model) mx.grad(loss_fn)(model)
def test_update(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
# Updating non-existent parameters
with self.assertRaises(ValueError):
updates = {"layers": [{"value": 0}]}
m.update(updates)
with self.assertRaises(ValueError):
updates = {"layers": ["hello"]}
m.update(updates)
# Wronge type
with self.assertRaises(ValueError):
updates = {"layers": [{"weight": "hi"}]}
m.update(updates)
def test_update_modules(self):
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
# Updating non-existent modules should not be allowed by default
with self.assertRaises(ValueError):
m = m.update_modules({"values": [0, 1]})
# Update wrong types
with self.assertRaises(ValueError):
m = m.update_modules({"layers": [0, 1]})
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.test = mx.array(1.0)
self.list = [mx.array(1.0), mx.array(2.0)]
m = MyModule()
with self.assertRaises(ValueError):
m = m.update_modules({"test": "hi"})
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):