diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp
index 06cf734f0..cf2e1c83f 100644
--- a/mlx/backend/metal/binary.cpp
+++ b/mlx/backend/metal/binary.cpp
@@ -6,6 +6,17 @@
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 
+#define BINARY_GPU(func)                                               \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) {  \
+    binary_op_gpu(inputs, out, get_primitive_string(this));            \
+  }
+
+#define BINARY_GPU_MULTI(func)                                          \
+  void func::eval_gpu(                                                  \
+      const std::vector<array>& inputs, std::vector<array>& outputs) {  \
+    binary_op_gpu(inputs, outputs, get_primitive_string(this));         \
+  }
+
 namespace mlx::core {
 
 constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
@@ -61,7 +72,8 @@ void binary_op_gpu_inplace(
 
   auto& d = metal::device(s.device);
 
-  auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
+  auto kernel =
+      get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -188,7 +200,7 @@ void binary_op_gpu_inplace(
 
   auto& d = metal::device(s.device);
 
-  auto kernel = get_binary_kernel(d, kernel_name, a, out);
+  auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_a = a.data_shared_ptr() == nullptr;
@@ -259,102 +271,44 @@ void binary_op_gpu(
   binary_op_gpu(inputs, out, op, s);
 }
 
-void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "add");
-}
-
-void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "arctan2");
-}
+BINARY_GPU(Add)
+BINARY_GPU(ArcTan2)
+BINARY_GPU(Divide)
+BINARY_GPU_MULTI(DivMod)
+BINARY_GPU(Remainder)
+BINARY_GPU(Equal)
+BINARY_GPU(Greater)
+BINARY_GPU(GreaterEqual)
+BINARY_GPU(Less)
+BINARY_GPU(LessEqual)
+BINARY_GPU(LogicalAnd)
+BINARY_GPU(LogicalOr)
+BINARY_GPU(LogAddExp)
+BINARY_GPU(Maximum)
+BINARY_GPU(Minimum)
+BINARY_GPU(Multiply)
+BINARY_GPU(NotEqual)
+BINARY_GPU(Power)
+BINARY_GPU(Subtract)
 
 void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
   switch (op_) {
     case BitwiseBinary::And:
-      binary_op_gpu(inputs, out, "bitwise_and");
+      binary_op_gpu(inputs, out, get_primitive_string(this));
       break;
     case BitwiseBinary::Or:
-      binary_op_gpu(inputs, out, "bitwise_or");
+      binary_op_gpu(inputs, out, get_primitive_string(this));
       break;
     case BitwiseBinary::Xor:
-      binary_op_gpu(inputs, out, "bitwise_xor");
+      binary_op_gpu(inputs, out, get_primitive_string(this));
       break;
    case BitwiseBinary::LeftShift:
-      binary_op_gpu(inputs, out, "left_shift");
+      binary_op_gpu(inputs, out, get_primitive_string(this));
      break;
    case BitwiseBinary::RightShift:
-      binary_op_gpu(inputs, out, "right_shift");
+      binary_op_gpu(inputs, out, get_primitive_string(this));
      break;
  }
 }
 
-void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "div");
-}
-
-void DivMod::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  binary_op_gpu(inputs, outputs, "divmod");
-}
-
-void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "rem");
-}
-
-void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
-}
-
-void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "ge");
-}
-
-void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "geq");
-}
-
-void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "le");
-}
-
-void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "leq");
-}
-
-void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "land");
-}
-
-void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "lor");
-}
-
-void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "lae");
-}
-
-void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "max");
-}
-
-void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "min");
-}
-
-void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "mul");
-}
-
-void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "neq");
-}
-
-void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "pow");
-}
-
-void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op_gpu(inputs, out, "sub");
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 817bf7fe4..5140a8e7a 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -391,16 +391,16 @@ void multi_upload_bluestein_fft(
     std::vector<int> rstrides(in.ndim(), 1);
     rstarts[axis] = in.shape(axis) - back_offset;
     rstrides[axis] = -1;
-    unary_op_gpu({in}, conj_temp, "conj", s);
+    unary_op_gpu({in}, conj_temp, "Conjugate", s);
     slice_gpu(in, slice_temp, rstarts, rstrides, s);
     concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
   } else if (inverse) {
-    unary_op_gpu({in}, temp, "conj", s);
+    unary_op_gpu({in}, temp, "Conjugate", s);
   } else {
     temp.copy_shared_buffer(in);
   }
 
-  binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
+  binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
 
   std::vector<std::pair<int, int>> pads;
   auto padded_shape = out.shape();
@@ -419,7 +419,7 @@ void multi_upload_bluestein_fft(
       /*inplace=*/false,
       s);
 
-  binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
+  binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
 
   fft_op(
       pad_temp,
@@ -437,7 +437,7 @@ void multi_upload_bluestein_fft(
   starts[axis] = plan.bluestein_n - offset - n;
   slice_gpu(pad_temp1, temp, starts, strides, s);
 
-  binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
+  binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
 
   if (real && !inverse) {
     std::vector<int> rstarts(in.ndim(), 0);
@@ -451,11 +451,11 @@ void multi_upload_bluestein_fft(
     copies.push_back(inv_n);
     copy_gpu(temp1, temp_float, CopyType::General, s);
 
-    binary_op_gpu({temp_float, inv_n}, out, "mul", s);
+    binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
   } else if (inverse) {
     auto inv_n = array({1.0f / n}, {1}, complex64);
-    unary_op_gpu({temp1}, temp, "conj", s);
-    binary_op_gpu({temp, inv_n}, out, "mul", s);
+    unary_op_gpu({temp1}, temp, "Conjugate", s);
+    binary_op_gpu({temp, inv_n}, out, "Multiply", s);
     copies.push_back(inv_n);
   } else {
     out.copy_shared_buffer(temp1);
diff --git a/mlx/backend/metal/jit/binary.h b/mlx/backend/metal/jit/binary.h
deleted file mode 100644
index
febc8be6c..000000000 --- a/mlx/backend/metal/jit/binary.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view binary_kernels = R"( -template [[host_name("ss{0}")]] [[kernel]] -void binary_ss<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - uint index [[thread_position_in_grid]]); -template [[host_name("vs{0}")]] [[kernel]] -void binary_vs<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - uint index [[thread_position_in_grid]]); -template [[host_name("sv{0}")]] [[kernel]] -void binary_sv<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - uint index [[thread_position_in_grid]]); -template [[host_name("vv{0}")]] [[kernel]] -void binary_vv<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - uint index [[thread_position_in_grid]]); -template [[host_name("g4{0}")]] [[kernel]] void -binary_g_nd<{1}, {2}, {3}, 4>( - device const {1}* a, - device const {1}* b, - device {2}* c, - constant const int shape[4], - constant const size_t a_strides[4], - constant const size_t b_strides[4], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -template [[host_name("g5{0}")]] [[kernel]] void -binary_g_nd<{1}, {2}, {3}, 5>( - device const {1}* a, - device const {1}* b, - device {2}* c, - constant const int shape[5], - constant const size_t a_strides[5], - constant const size_t b_strides[5], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); - -template [[host_name("g1{0}")]] [[kernel]] void -binary_g_nd1<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - constant const size_t& a_stride, - constant const size_t& b_stride, - uint index [[thread_position_in_grid]]); -template [[host_name("g2{0}")]] [[kernel]] void -binary_g_nd2<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - constant const size_t a_strides[2], - constant const size_t b_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]); -template [[host_name("g3{0}")]] [[kernel]] void -binary_g_nd3<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); - -template [[host_name("gn{0}")]] [[kernel]] -void binary_g<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -)"; diff --git a/mlx/backend/metal/jit/binary_two.h b/mlx/backend/metal/jit/binary_two.h deleted file mode 100644 index 54b0c6296..000000000 --- a/mlx/backend/metal/jit/binary_two.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright © 2024 Apple Inc. 
- -constexpr std::string_view binary_two_kernels = R"( -template [[host_name("ss{0}")]] [[kernel]] -void binary_ss<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - uint index [[thread_position_in_grid]]); -template [[host_name("vs{0}")]] [[kernel]] -void binary_vs<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - uint index [[thread_position_in_grid]]); -template [[host_name("sv{0}")]] [[kernel]] -void binary_sv<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - uint index [[thread_position_in_grid]]); -template [[host_name("vv{0}")]] [[kernel]] -void binary_vv<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - uint index [[thread_position_in_grid]]); - -template [[host_name("g4{0}")]] [[kernel]] void -binary_g_nd<{1}, {2}, {3}, 4>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - constant const int shape[4], - constant const size_t a_strides[4], - constant const size_t b_strides[4], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -template [[host_name("g5{0}")]] [[kernel]] void -binary_g_nd<{1}, {2}, {3}, 5>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - constant const int shape[5], - constant const size_t a_strides[5], - constant const size_t b_strides[5], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); - -template [[host_name("g1{0}")]] [[kernel]] void -binary_g_nd1<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - constant const size_t& a_stride, - constant const size_t& b_stride, - uint index [[thread_position_in_grid]]); -template [[host_name("g2{0}")]] [[kernel]] void -binary_g_nd2<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - constant const size_t a_strides[2], - constant const size_t b_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]); -template [[host_name("g3{0}")]] [[kernel]] void -binary_g_nd3<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); - -template [[host_name("gn{0}")]] [[kernel]] -void binary_g<{1}, {2}, {3}>( - device const {1}* a, - device const {1}* b, - device {2}* c, - device {2}* d, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -)"; diff --git a/mlx/backend/metal/jit/ternary.h b/mlx/backend/metal/jit/ternary.h deleted file mode 100644 index 8b49e1311..000000000 --- a/mlx/backend/metal/jit/ternary.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright © 2024 Apple Inc. 
- -constexpr std::string_view ternary_kernels = R"( -template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - uint index [[thread_position_in_grid]]); - -template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); - -template [[host_name("g1_{0}")]] [[kernel]] void -ternary_g_nd1<{1}, {2}>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - constant const size_t& a_strides, - constant const size_t& b_strides, - constant const size_t& c_strides, - uint index [[thread_position_in_grid]]); -template [[host_name("g2_{0}")]] [[kernel]] void -ternary_g_nd2<{1}, {2}>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - constant const size_t a_strides[2], - constant const size_t b_strides[2], - constant const size_t c_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]); -template [[host_name("g3_{0}")]] [[kernel]] void -ternary_g_nd3<{1}, {2}>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - constant const size_t c_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -template [[host_name("g4_{0}")]] [[kernel]] void -ternary_g_nd<{1}, {2}, 4>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - constant const int shape[4], - constant const size_t a_strides[4], - constant const size_t b_strides[4], - constant const size_t c_strides[4], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -template [[host_name("g5_{0}")]] [[kernel]] void -ternary_g_nd<{1}, {2}, 5>( - device const bool* a, - device const {1}* b, - device const {1}* c, - device {1}* d, - constant const int shape[5], - constant const size_t a_strides[5], - constant const size_t b_strides[5], - constant const size_t c_strides[5], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]); -)"; diff --git a/mlx/backend/metal/jit/unary.h b/mlx/backend/metal/jit/unary.h deleted file mode 100644 index d35957fe4..000000000 --- a/mlx/backend/metal/jit/unary.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view unary_kernels = R"( -template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>( - device const {1}* in, - device {1}* out, - uint index [[thread_position_in_grid]]); - -template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>( - device const {1}* in, - device {1}* out, - device const int* in_shape, - device const size_t* in_strides, - device const int& ndim, - uint index [[thread_position_in_grid]]); -)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 90495cd72..6791c6685 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1,9 +1,8 @@ // Copyright © 2024 Apple Inc. 
+#include
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
-#include "mlx/backend/metal/jit/binary.h"
-#include "mlx/backend/metal/jit/binary_two.h"
 #include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/reduce.h"
@@ -12,8 +11,6 @@
 #include "mlx/backend/metal/jit/sort.h"
 #include "mlx/backend/metal/jit/steel_conv.h"
 #include "mlx/backend/metal/jit/steel_gemm.h"
-#include "mlx/backend/metal/jit/ternary.h"
-#include "mlx/backend/metal/jit/unary.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 
@@ -46,38 +43,76 @@ MTL::ComputePipelineState* get_arange_kernel(
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& out) {
+    Dtype out_type,
+    const std::string op) {
   std::string lib_name = kernel_name.substr(1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
+    auto u_def = get_template_definition(
+        "v" + lib_name, "unary_v", get_type_string(out_type), op);
+    auto g_def = get_template_definition(
+        "g" + lib_name, "unary_g", get_type_string(out_type), op);
     kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
-                  << fmt::format(
-                         unary_kernels,
-                         lib_name,
-                         get_type_string(out.dtype()),
-                         op_name(out));
+                  << u_def << g_def;
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
 }
 
+void add_binary_kernels(
+    const std::string lib_name,
+    Dtype in_type,
+    Dtype out_type,
+    const std::string op,
+    std::ostringstream& kernel_source) {
+  const std::map<std::string, std::string> kernel_types = {
+      {"ss", "binary_ss"},
+      {"vs", "binary_vs"},
+      {"sv", "binary_sv"},
+      {"vv", "binary_vv"},
+      {"g1", "binary_g_nd1"},
+      {"g2", "binary_g_nd2"},
+      {"g3", "binary_g_nd3"},
+      {"g4", "binary_g_nd"},
+      {"g5", "binary_g_nd"},
+      {"gn", "binary_g"},
+  };
+  for (auto [name, func] : kernel_types) {
+    std::string template_def;
+    if (name == "g4" || name == "g5") {
+      int dim = std::stoi(name.substr(1));
+      template_def = get_template_definition(
+          name + lib_name,
+          func,
+          get_type_string(in_type),
+          get_type_string(out_type),
+          op,
+          dim);
+    } else {
+      template_def = get_template_definition(
+          name + lib_name,
+          func,
+          get_type_string(in_type),
+          get_type_string(out_type),
+          op);
+    }
+    kernel_source << template_def;
+  }
+}
+
 MTL::ComputePipelineState* get_binary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& in,
-    const array& out) {
+    Dtype in_type,
+    Dtype out_type,
+    const std::string op) {
   std::string lib_name = kernel_name.substr(2);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
-                  << fmt::format(
-                         binary_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()),
-                         op_name(out));
+    kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
+    add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
     lib = d.get_library(lib_name, kernel_source.str());
  }
   return d.get_kernel(kernel_name, lib);
@@ -86,20 +121,16 @@ MTL::ComputePipelineState* get_binary_kernel(
 MTL::ComputePipelineState* get_binary_two_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& in,
-    const array& out) {
+    Dtype in_type,
+    Dtype out_type,
+    const std::string op) {
   std::string lib_name = kernel_name.substr(2);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::binary_ops()
-                  << metal::binary_two()
-                  << fmt::format(
-                         binary_two_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()),
-                         op_name(out));
+                  << metal::binary_two();
+    add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -108,17 +139,34 @@ MTL::ComputePipelineState* get_binary_two_kernel(
 MTL::ComputePipelineState* get_ternary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& out) {
+    Dtype type,
+    const std::string op) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
-                  << fmt::format(
-                         ternary_kernels,
-                         lib_name,
-                         get_type_string(out.dtype()),
-                         op_name(out));
+    const std::map<std::string, std::string> kernel_types = {
+        {"v", "ternary_v"},
+        {"g", "ternary_g"},
+        {"g1", "ternary_g_nd1"},
+        {"g2", "ternary_g_nd2"},
+        {"g3", "ternary_g_nd3"},
+        {"g4", "ternary_g_nd"},
+        {"g5", "ternary_g_nd"},
+    };
+    kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
+    for (auto [name, func] : kernel_types) {
+      std::string template_def;
+      if (name == "g4" || name == "g5") {
+        int dim = std::stoi(name.substr(1));
+        template_def = get_template_definition(
+            name + "_" + lib_name, func, get_type_string(type), op, dim);
+      } else {
+        template_def = get_template_definition(
+            name + "_" + lib_name, func, get_type_string(type), op);
+      }
+      kernel_source << template_def;
+    }
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index 936ebca24..40f548868 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -15,24 +15,28 @@ MTL::ComputePipelineState* get_arange_kernel(
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& out);
+    Dtype out_type,
+    const std::string op);
 
 MTL::ComputePipelineState* get_binary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& in,
-    const array& out);
+    Dtype in_type,
+    Dtype out_type,
+    const std::string op);
 
 MTL::ComputePipelineState* get_binary_two_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& in,
-    const array& out);
+    Dtype in_type,
+    Dtype out_type,
+    const std::string op);
 
 MTL::ComputePipelineState* get_ternary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array& out);
+    Dtype type,
+    const std::string op);
 
 MTL::ComputePipelineState* get_copy_kernel(
     metal::Device& d,
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index f85f96a53..11e4f7a6d 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -4,148 +4,91 @@
 #include
 
 // clang-format off
+#include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary.h"
 
-#define instantiate_binary(name, itype, otype, op, bopt) \
-  template \
-      [[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
-          device const itype* a, \
-          device const itype* b, \
-          device otype* c, \
-          uint index [[thread_position_in_grid]]);
 
+#define instantiate_binary_all(op, tname,
itype, otype) \ + instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ + instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \ + instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \ + instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5) -#define instantiate_binary_g_dim(name, itype, otype, op, dims) \ - template [[host_name("g" #dims name)]] [[kernel]] void \ - binary_g_nd( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const int shape[dims], \ - constant const size_t a_strides[dims], \ - constant const size_t b_strides[dims], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); +#define instantiate_binary_integer(op) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_all(op, int64, int64_t, int64_t) -#define instantiate_binary_g_nd(name, itype, otype, op) \ - template [[host_name("g1" name)]] [[kernel]] void \ - binary_g_nd1( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const size_t& a_stride, \ - constant const size_t& b_stride, \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("g2" name)]] [[kernel]] void \ - binary_g_nd2( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const size_t a_strides[2], \ - constant const size_t b_strides[2], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name("g3" name)]] [[kernel]] void \ - binary_g_nd3( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const size_t a_strides[3], \ - constant const size_t b_strides[3], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - instantiate_binary_g_dim(name, itype, otype, op, 4) \ - instantiate_binary_g_dim(name, itype, otype, op, 5) +#define instantiate_binary_float(op) \ + instantiate_binary_all(op, float16, half, half) \ + instantiate_binary_all(op, float32, float, float) \ + instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) -#define instantiate_binary_g(name, itype, otype, op) \ - template [[host_name("gn" name)]] [[kernel]] void binary_g( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* b_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); +#define instantiate_binary_types(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_integer(op) \ + instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ + instantiate_binary_float(op) 
-#define instantiate_binary_all(name, tname, itype, otype, op) \ - instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ - instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ - instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ - instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ - instantiate_binary_g(#name #tname, itype, otype, op) \ - instantiate_binary_g_nd(#name #tname, itype, otype, op) +#define instantiate_binary_types_bool(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_all(op, uint8, uint8_t, bool) \ + instantiate_binary_all(op, uint16, uint16_t, bool) \ + instantiate_binary_all(op, uint32, uint32_t, bool) \ + instantiate_binary_all(op, uint64, uint64_t, bool) \ + instantiate_binary_all(op, int8, int8_t, bool) \ + instantiate_binary_all(op, int16, int16_t, bool) \ + instantiate_binary_all(op, int32, int32_t, bool) \ + instantiate_binary_all(op, int64, int64_t, bool) \ + instantiate_binary_all(op, float16, half, bool) \ + instantiate_binary_all(op, float32, float, bool) \ + instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \ + instantiate_binary_all(op, complex64, complex64_t, bool) -#define instantiate_binary_integer(name, op) \ - instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ - instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ - instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \ - instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \ - instantiate_binary_all(name, int8, int8_t, int8_t, op) \ - instantiate_binary_all(name, int16, int16_t, int16_t, op) \ - instantiate_binary_all(name, int32, int32_t, int32_t, op) \ - instantiate_binary_all(name, int64, int64_t, int64_t, op) - -#define instantiate_binary_float(name, op) \ - instantiate_binary_all(name, float16, half, half, op) \ - instantiate_binary_all(name, float32, float, float, op) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) - -#define instantiate_binary_types(name, op) \ - instantiate_binary_all(name, bool_, bool, bool, op) \ - instantiate_binary_integer(name, op) \ - instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ - instantiate_binary_float(name, op) - -#define instantiate_binary_types_bool(name, op) \ - instantiate_binary_all(name, bool_, bool, bool, op) \ - instantiate_binary_all(name, uint8, uint8_t, bool, op) \ - instantiate_binary_all(name, uint16, uint16_t, bool, op) \ - instantiate_binary_all(name, uint32, uint32_t, bool, op) \ - instantiate_binary_all(name, uint64, uint64_t, bool, op) \ - instantiate_binary_all(name, int8, int8_t, bool, op) \ - instantiate_binary_all(name, int16, int16_t, bool, op) \ - instantiate_binary_all(name, int32, int32_t, bool, op) \ - instantiate_binary_all(name, int64, int64_t, bool, op) \ - instantiate_binary_all(name, float16, half, bool, op) \ - instantiate_binary_all(name, float32, float, bool, op) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \ - instantiate_binary_all(name, complex64, complex64_t, bool, op) - -instantiate_binary_types(add, Add) -instantiate_binary_types(div, Divide) -instantiate_binary_types_bool(eq, Equal) -instantiate_binary_types_bool(ge, Greater) -instantiate_binary_types_bool(geq, GreaterEqual) -instantiate_binary_types_bool(le, Less) -instantiate_binary_types_bool(leq, LessEqual) -instantiate_binary_types_bool(neq, NotEqual) -instantiate_binary_float(lae, LogAddExp) -instantiate_binary_types(max, Maximum) -instantiate_binary_types(min, Minimum) -instantiate_binary_types(mul, 
Multiply) -instantiate_binary_types(sub, Subtract) -instantiate_binary_types(pow, Power) -instantiate_binary_types(rem, Remainder) -instantiate_binary_float(arctan2, ArcTan2) +instantiate_binary_types(Add) +instantiate_binary_types(Divide) +instantiate_binary_types_bool(Equal) +instantiate_binary_types_bool(Greater) +instantiate_binary_types_bool(GreaterEqual) +instantiate_binary_types_bool(Less) +instantiate_binary_types_bool(LessEqual) +instantiate_binary_types_bool(NotEqual) +instantiate_binary_float(LogAddExp) +instantiate_binary_types(Maximum) +instantiate_binary_types(Minimum) +instantiate_binary_types(Multiply) +instantiate_binary_types(Subtract) +instantiate_binary_types(Power) +instantiate_binary_types(Remainder) +instantiate_binary_float(ArcTan2) // NaNEqual only needed for floating point types with boolean output -instantiate_binary_all(naneq, float16, half, bool, NaNEqual) -instantiate_binary_all(naneq, float32, float, bool, NaNEqual) -instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual) -instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual) +instantiate_binary_all(NaNEqual, float16, half, bool) +instantiate_binary_all(NaNEqual, float32, float, bool) +instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool) +instantiate_binary_all(NaNEqual, complex64, complex64_t, bool) -instantiate_binary_all(lor, bool_, bool, bool, LogicalOr) -instantiate_binary_all(land, bool_, bool, bool, LogicalAnd) +instantiate_binary_all(LogicalOr, bool_, bool, bool) +instantiate_binary_all(LogicalAnd, bool_, bool, bool) // Bitwise ops only need integer types and bool (except for l/r shift) -instantiate_binary_integer(bitwise_and, BitwiseAnd) -instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd) -instantiate_binary_integer(bitwise_or, BitwiseOr) -instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr) -instantiate_binary_integer(bitwise_xor, BitwiseXor) -instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor) -instantiate_binary_integer(left_shift, LeftShift) -instantiate_binary_integer(right_shift, RightShift) // clang-format on +instantiate_binary_integer(BitwiseAnd) +instantiate_binary_all(BitwiseAnd, bool_, bool, bool) +instantiate_binary_integer(BitwiseOr) +instantiate_binary_all(BitwiseOr, bool_, bool, bool) +instantiate_binary_integer(BitwiseXor) +instantiate_binary_all(BitwiseXor, bool_, bool, bool) +instantiate_binary_integer(LeftShift) +instantiate_binary_integer(RightShift) // clang-format on diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 0f63227d9..275f6a0d8 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,99 +7,34 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary(name, itype, otype, op, bopt) \ - template [[host_name(name)]] [[kernel]] void \ - binary_##bopt( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - uint index [[thread_position_in_grid]]); +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ + instantiate_kernel("g1" #op 
#tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \ + instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \ + instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5) -#define instantiate_binary_g_dim(name, itype, otype, op, dims) \ - template [[host_name("g" #dims name)]] [[kernel]] void \ - binary_g_nd( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const int shape[dims], \ - constant const size_t a_strides[dims], \ - constant const size_t b_strides[dims], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); +#define instantiate_binary_float(op) \ + instantiate_binary_all(op, float16, half, half) \ + instantiate_binary_all(op, float32, float, float) \ + instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t) -#define instantiate_binary_g_nd(name, itype, otype, op) \ - template [[host_name("g1" name)]] [[kernel]] void \ - binary_g_nd1( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const size_t& a_stride, \ - constant const size_t& b_stride, \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("g2" name)]] [[kernel]] void \ - binary_g_nd2( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const size_t a_strides[2], \ - constant const size_t b_strides[2], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name("g3" name)]] [[kernel]] void \ - binary_g_nd3( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const size_t a_strides[3], \ - constant const size_t b_strides[3], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - instantiate_binary_g_dim(name, itype, otype, op, 4) \ - instantiate_binary_g_dim(name, itype, otype, op, 5) +#define instantiate_binary_types(op) \ + instantiate_binary_all(op, bool_, bool, bool) \ + instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ + instantiate_binary_all(op, uint16, uint16_t, uint16_t) \ + instantiate_binary_all(op, uint32, uint32_t, uint32_t) \ + instantiate_binary_all(op, uint64, uint64_t, uint64_t) \ + instantiate_binary_all(op, int8, int8_t, int8_t) \ + instantiate_binary_all(op, int16, int16_t, int16_t) \ + instantiate_binary_all(op, int32, int32_t, int32_t) \ + instantiate_binary_all(op, int64, int64_t, int64_t) \ + instantiate_binary_all(op, complex64, complex64_t, complex64_t) \ + instantiate_binary_float(op) -#define instantiate_binary_g(name, itype, otype, op) \ - template [[host_name("gn" name)]] [[kernel]] void \ - binary_g( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - device otype* d, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* b_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); - -#define instantiate_binary_all(name, tname, itype, otype, op) \ - instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ - instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ - instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ - instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ - instantiate_binary_g(#name 
#tname, itype, otype, op) \ - instantiate_binary_g_nd(#name #tname, itype, otype, op) - -#define instantiate_binary_float(name, op) \ - instantiate_binary_all(name, float16, half, half, op) \ - instantiate_binary_all(name, float32, float, float, op) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) - -#define instantiate_binary_types(name, op) \ - instantiate_binary_all(name, bool_, bool, bool, op) \ - instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ - instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ - instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \ - instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \ - instantiate_binary_all(name, int8, int8_t, int8_t, op) \ - instantiate_binary_all(name, int16, int16_t, int16_t, op) \ - instantiate_binary_all(name, int32, int32_t, int32_t, op) \ - instantiate_binary_all(name, int64, int64_t, int64_t, op) \ - instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ - instantiate_binary_float(name, op) - -instantiate_binary_types(divmod, DivMod) // clang-format on +instantiate_binary_types(DivMod) // clang-format on diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index b9392eb56..97bfdf81c 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -9,96 +9,28 @@ #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_v(name, type, op) \ - template [[host_name("v_" name)]] [[kernel]] void ternary_v( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - uint index [[thread_position_in_grid]]); +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ + instantiate_kernel("g_" #op #tname, ternary_g, type, op) \ + instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \ + instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \ + instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \ + instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \ + instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5) -#define instantiate_ternary_g(name, type, op) \ - template [[host_name("g_" name)]] [[kernel]] void ternary_g( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* b_strides, \ - constant const size_t* c_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); +#define instantiate_ternary_types(op) \ + instantiate_ternary_all(op, bool_, bool) \ + instantiate_ternary_all(op, uint8, uint8_t) \ + instantiate_ternary_all(op, uint16, uint16_t) \ + instantiate_ternary_all(op, uint32, uint32_t) \ + instantiate_ternary_all(op, uint64, uint64_t) \ + instantiate_ternary_all(op, int8, int8_t) \ + instantiate_ternary_all(op, int16, int16_t) \ + instantiate_ternary_all(op, int32, int32_t) \ + instantiate_ternary_all(op, int64, int64_t) \ + instantiate_ternary_all(op, float16, half) \ + instantiate_ternary_all(op, float32, float) \ + instantiate_ternary_all(op, bfloat16, bfloat16_t) \ + instantiate_ternary_all(op, complex64, complex64_t) // clang-format on -#define instantiate_ternary_g_dim(name, type, op, dims) \ - template [[host_name("g" #dims "_" name )]] [[kernel]] void \ - 
ternary_g_nd( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const int shape[dims], \ - constant const size_t a_strides[dims], \ - constant const size_t b_strides[dims], \ - constant const size_t c_strides[dims], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); - -#define instantiate_ternary_g_nd(name, type, op) \ - template [[host_name("g1_" name)]] [[kernel]] void \ - ternary_g_nd1( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const size_t& a_strides, \ - constant const size_t& b_strides, \ - constant const size_t& c_strides, \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("g2_" name)]] [[kernel]] void \ - ternary_g_nd2( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const size_t a_strides[2], \ - constant const size_t b_strides[2], \ - constant const size_t c_strides[2], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name("g3_" name)]] [[kernel]] void \ - ternary_g_nd3( \ - device const bool* a, \ - device const type* b, \ - device const type* c, \ - device type* d, \ - constant const size_t a_strides[3], \ - constant const size_t b_strides[3], \ - constant const size_t c_strides[3], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - instantiate_ternary_g_dim(name, type, op, 4) \ - instantiate_ternary_g_dim(name, type, op, 5) - -#define instantiate_ternary_all(name, tname, type, op) \ - instantiate_ternary_v(#name #tname, type, op) \ - instantiate_ternary_g(#name #tname, type, op) \ - instantiate_ternary_g_nd(#name #tname, type, op) - -#define instantiate_ternary_types(name, op) \ - instantiate_ternary_all(name, bool_, bool, op) \ - instantiate_ternary_all(name, uint8, uint8_t, op) \ - instantiate_ternary_all(name, uint16, uint16_t, op) \ - instantiate_ternary_all(name, uint32, uint32_t, op) \ - instantiate_ternary_all(name, uint64, uint64_t, op) \ - instantiate_ternary_all(name, int8, int8_t, op) \ - instantiate_ternary_all(name, int16, int16_t, op) \ - instantiate_ternary_all(name, int32, int32_t, op) \ - instantiate_ternary_all(name, int64, int64_t, op) \ - instantiate_ternary_all(name, float16, half, op) \ - instantiate_ternary_all(name, float32, float, op) \ - instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \ - instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on - -instantiate_ternary_types(select, Select) +instantiate_ternary_types(Select) diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 002a5a24f..5ad57a15d 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,83 +5,68 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_v(name, type, op) \ - template [[host_name(name)]] [[kernel]] void unary_v( \ - device const type* in, \ - device type* out, \ - uint index [[thread_position_in_grid]]); +#define instantiate_unary_all(op, tname, type) \ + instantiate_kernel("v" #op #tname, unary_v, type, op) \ + instantiate_kernel("g" #op #tname, unary_g, type, op) -#define instantiate_unary_g(name, type, op) \ - template [[host_name(name)]] [[kernel]] void unary_g( \ - device const type* in, \ - device type* out, \ - device const int* in_shape, \ - 
device const size_t* in_strides, \ - device const int& ndim, \ - uint index [[thread_position_in_grid]]); +#define instantiate_unary_float(op) \ + instantiate_unary_all(op, float16, half) \ + instantiate_unary_all(op, float32, float) \ + instantiate_unary_all(op, bfloat16, bfloat16_t) -#define instantiate_unary_all(name, tname, type, op) \ - instantiate_unary_v("v" #name #tname, type, op) \ - instantiate_unary_g("g" #name #tname, type, op) +#define instantiate_unary_types(op) \ + instantiate_unary_all(op, bool_, bool) \ + instantiate_unary_all(op, uint8, uint8_t) \ + instantiate_unary_all(op, uint16, uint16_t) \ + instantiate_unary_all(op, uint32, uint32_t) \ + instantiate_unary_all(op, uint64, uint64_t) \ + instantiate_unary_all(op, int8, int8_t) \ + instantiate_unary_all(op, int16, int16_t) \ + instantiate_unary_all(op, int32, int32_t) \ + instantiate_unary_all(op, int64, int64_t) \ + instantiate_unary_float(op) -#define instantiate_unary_float(name, op) \ - instantiate_unary_all(name, float16, half, op) \ - instantiate_unary_all(name, float32, float, op) \ - instantiate_unary_all(name, bfloat16, bfloat16_t, op) +instantiate_unary_types(Abs) +instantiate_unary_float(ArcCos) +instantiate_unary_float(ArcCosh) +instantiate_unary_float(ArcSin) +instantiate_unary_float(ArcSinh) +instantiate_unary_float(ArcTan) +instantiate_unary_float(ArcTanh) +instantiate_unary_types(Ceil) +instantiate_unary_float(Cos) +instantiate_unary_float(Cosh) +instantiate_unary_float(Exp) +instantiate_unary_float(Expm1) +instantiate_unary_types(Floor) +instantiate_unary_float(Log) +instantiate_unary_float(Log2) +instantiate_unary_float(Log10) +instantiate_unary_float(Log1p) +instantiate_unary_types(Negative) +instantiate_unary_float(Sigmoid) +instantiate_unary_float(Erf) +instantiate_unary_float(ErfInv) +instantiate_unary_types(Sign) +instantiate_unary_float(Sin) +instantiate_unary_float(Sinh) +instantiate_unary_types(Square) +instantiate_unary_float(Sqrt) +instantiate_unary_float(Rsqrt) +instantiate_unary_float(Tan) +instantiate_unary_float(Tanh) +instantiate_unary_float(Round) -#define instantiate_unary_types(name, op) \ - instantiate_unary_all(name, bool_, bool, op) \ - instantiate_unary_all(name, uint8, uint8_t, op) \ - instantiate_unary_all(name, uint16, uint16_t, op) \ - instantiate_unary_all(name, uint32, uint32_t, op) \ - instantiate_unary_all(name, uint64, uint64_t, op) \ - instantiate_unary_all(name, int8, int8_t, op) \ - instantiate_unary_all(name, int16, int16_t, op) \ - instantiate_unary_all(name, int32, int32_t, op) \ - instantiate_unary_all(name, int64, int64_t, op) \ - instantiate_unary_float(name, op) +instantiate_unary_all(Abs, complex64, complex64_t) +instantiate_unary_all(Conjugate, complex64, complex64_t) +instantiate_unary_all(Cos, complex64, complex64_t) +instantiate_unary_all(Cosh, complex64, complex64_t) +instantiate_unary_all(Exp, complex64, complex64_t) +instantiate_unary_all(Negative, complex64, complex64_t) +instantiate_unary_all(Sin, complex64, complex64_t) +instantiate_unary_all(Sinh, complex64, complex64_t) +instantiate_unary_all(Tan, complex64, complex64_t) +instantiate_unary_all(Tanh, complex64, complex64_t) +instantiate_unary_all(Round, complex64, complex64_t) -instantiate_unary_types(abs, Abs) -instantiate_unary_float(arccos, ArcCos) -instantiate_unary_float(arccosh, ArcCosh) -instantiate_unary_float(arcsin, ArcSin) -instantiate_unary_float(arcsinh, ArcSinh) -instantiate_unary_float(arctan, ArcTan) -instantiate_unary_float(arctanh, ArcTanh) -instantiate_unary_types(ceil, Ceil) 
-instantiate_unary_float(cos, Cos)
-instantiate_unary_float(cosh, Cosh)
-instantiate_unary_float(exp, Exp)
-instantiate_unary_float(expm1, Expm1)
-instantiate_unary_types(floor, Floor)
-instantiate_unary_float(log, Log)
-instantiate_unary_float(log2, Log2)
-instantiate_unary_float(log10, Log10)
-instantiate_unary_float(log1p, Log1p)
-instantiate_unary_types(neg, Negative)
-instantiate_unary_float(sigmoid, Sigmoid)
-instantiate_unary_float(erf, Erf)
-instantiate_unary_float(erfinv, ErfInv)
-instantiate_unary_types(sign, Sign)
-instantiate_unary_float(sin, Sin)
-instantiate_unary_float(sinh, Sinh)
-instantiate_unary_types(square, Square)
-instantiate_unary_float(sqrt, Sqrt)
-instantiate_unary_float(rsqrt, Rsqrt)
-instantiate_unary_float(tan, Tan)
-instantiate_unary_float(tanh, Tanh)
-instantiate_unary_float(round, Round)
-
-instantiate_unary_all(abs, complex64, complex64_t, Abs)
-instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
-instantiate_unary_all(cos, complex64, complex64_t, Cos)
-instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
-instantiate_unary_all(exp, complex64, complex64_t, Exp)
-instantiate_unary_all(neg, complex64, complex64_t, Negative)
-instantiate_unary_all(sin, complex64, complex64_t, Sin)
-instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
-instantiate_unary_all(tan, complex64, complex64_t, Tan)
-instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
-instantiate_unary_all(round, complex64, complex64_t, Round)
-
-instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on
+instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index 28ab672e5..ce8a3f582 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -15,30 +15,34 @@ MTL::ComputePipelineState* get_arange_kernel(
 MTL::ComputePipelineState* get_unary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array&) {
+    Dtype,
+    const std::string) {
   return d.get_kernel(kernel_name);
 }
 
 MTL::ComputePipelineState* get_binary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array&,
-    const array&) {
+    Dtype,
+    Dtype,
+    const std::string) {
   return d.get_kernel(kernel_name);
 }
 
 MTL::ComputePipelineState* get_binary_two_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array&,
-    const array&) {
+    Dtype,
+    Dtype,
+    const std::string) {
   return d.get_kernel(kernel_name);
 }
 
 MTL::ComputePipelineState* get_ternary_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const array&) {
+    Dtype,
+    const std::string) {
   return d.get_kernel(kernel_name);
 }
 
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index 51c7b9883..1f12ecf39 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -49,7 +49,7 @@ void ternary_op_gpu_inplace(
 
   auto& d = metal::device(s.device);
 
-  auto kernel = get_ternary_kernel(d, kernel_name, out);
+  auto kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -122,7 +122,7 @@ void ternary_op_gpu(
 }
 
 void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
-  ternary_op_gpu(inputs, out, "select");
+  ternary_op_gpu(inputs, out, get_primitive_string(this));
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index f4d9721fa..2ac01c490 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -5,6 +5,11 @@
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 
+#define UNARY_GPU(func)                                               \
+  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
+    unary_op_gpu(inputs, out, get_primitive_string(this));            \
+  }
+
 namespace mlx::core {
 
 void unary_op_gpu_inplace(
@@ -21,7 +26,7 @@ void unary_op_gpu_inplace(
   auto& d = metal::device(s.device);
 
   std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
-  auto kernel = get_unary_kernel(d, kernel_name, out);
+  auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
 
   size_t nthreads = contig ? in.data_size() : in.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -77,148 +82,57 @@ void unary_op_gpu(
   unary_op_gpu(inputs, out, op, s);
 }
 
-void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "abs");
-}
-
-void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "arccos");
-}
-
-void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "arccosh");
-}
-
-void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "arcsin");
-}
-
-void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "arcsinh");
-}
-
-void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "arctan");
-}
-
-void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "arctanh");
-}
-
-void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == complex64) {
-    unary_op_gpu(inputs, out, "conj");
-  } else {
-    throw std::invalid_argument(
-        "[conjugate] conjugate must be called on complex input.");
-  }
-}
-
-void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "cos");
-}
-
-void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "cosh");
-}
-
-void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "erf");
-}
-
-void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "erfinv");
-}
-
-void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "exp");
-}
-
-void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "expm1");
-}
+UNARY_GPU(Abs)
+UNARY_GPU(ArcCos)
+UNARY_GPU(ArcCosh)
+UNARY_GPU(ArcSin)
+UNARY_GPU(ArcSinh)
+UNARY_GPU(ArcTan)
+UNARY_GPU(ArcTanh)
+UNARY_GPU(Conjugate)
+UNARY_GPU(Cos)
+UNARY_GPU(Cosh)
+UNARY_GPU(Erf)
+UNARY_GPU(ErfInv)
+UNARY_GPU(Exp)
+UNARY_GPU(Expm1)
+UNARY_GPU(Log1p)
+UNARY_GPU(LogicalNot)
+UNARY_GPU(Floor)
+UNARY_GPU(Ceil)
+UNARY_GPU(Negative)
+UNARY_GPU(Sigmoid)
+UNARY_GPU(Sign)
+UNARY_GPU(Sin)
+UNARY_GPU(Sinh)
+UNARY_GPU(Square)
+UNARY_GPU(Sqrt)
+UNARY_GPU(Tan)
+UNARY_GPU(Tanh)
 
 void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
   switch (base_) {
     case Base::e:
-      unary_op_gpu(inputs, out, "log");
+      unary_op_gpu(inputs, out, get_primitive_string(this));
       break;
     case Base::two:
-      unary_op_gpu(inputs, out, "log2");
+      unary_op_gpu(inputs, out, get_primitive_string(this));
       break;
     case Base::ten:
-      unary_op_gpu(inputs, out, "log10");
+      unary_op_gpu(inputs, out, get_primitive_string(this));
       break;
   }
 }
 
-void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "log1p");
-}
-
-void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "lnot");
-}
-
-void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "floor");
-}
-
-void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "ceil");
-}
-
-void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "neg");
-}
-
 void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (issubdtype(in.dtype(), inexact)) {
-    unary_op_gpu(inputs, out, "round");
+    unary_op_gpu(inputs, out, get_primitive_string(this));
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
   }
 }
 
-void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "sigmoid");
-}
-
-void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "sign");
-}
-
-void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "sin");
-}
-
-void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "sinh");
-}
-
-void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "square");
-}
-
-void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
-  if (recip_) {
-    unary_op_gpu(inputs, out, "rsqrt");
-  } else {
-    unary_op_gpu(inputs, out, "sqrt");
-  }
-}
-
-void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "tan");
-}
-
-void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op_gpu(inputs, out, "tanh");
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index ee2b4a07b..6ecb0a095 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -141,6 +141,12 @@ int next_power_of_2(int n) {
   return pow(2, std::ceil(std::log2(n)));
 }
 
+std::string get_primitive_string(Primitive* primitive) {
+  std::ostringstream op_t;
+  primitive->print(op_t);
+  return op_t.str();
+}
+
 } // namespace
 
 } // namespace mlx::core
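Note on the two helpers this patch leans on but does not fully show: `get_primitive_string` (added to utils.h above) turns a primitive's `print()` output into the operation name, which is now used both as the Metal op struct name and as part of the kernel's host name, and `get_template_definition` (presumably provided by `mlx/backend/common/compiled.h`, which jit_kernels.cpp already includes) renders one explicit template instantiation string per kernel variant. The following is a minimal, hypothetical sketch of such a helper, not the actual MLX implementation; the real signature and output format may differ.

// Hypothetical stand-in for get_template_definition(); emits an explicit
// instantiation of a templated Metal kernel under a [[host_name(...)]] tag,
// replacing the hand-written format strings from the deleted jit/*.h headers.
#include <sstream>
#include <string>

template <typename... Args>
std::string make_template_definition(
    const std::string& host_name, // e.g. "vvMultiplyfloat32"
    const std::string& func,      // e.g. "binary_vv"
    Args... args) {               // type strings, op name, optional ndim
  std::ostringstream targs;
  bool first = true;
  auto add = [&](const auto& a) {
    if (!first) {
      targs << ", ";
    }
    first = false;
    targs << a;
  };
  (add(args), ...);
  std::ostringstream out;
  out << "template [[host_name(\"" << host_name << "\")]] [[kernel]] decltype("
      << func << "<" << targs.str() << ">) " << func << "<" << targs.str()
      << ">;\n";
  return out.str();
}

// make_template_definition("vvMultiplyfloat32", "binary_vv", "float", "float", "Multiply")
// would return (wrapped here for readability):
//   template [[host_name("vvMultiplyfloat32")]] [[kernel]]
//   decltype(binary_vv<float, float, Multiply>) binary_vv<float, float, Multiply>;

Under this scheme the .metal instantiations, the JIT path, and the primitives all key off the same primitive name (for example "Multiply" or "Conjugate"), which is why the ad hoc string literals such as "mul" and "conj" are replaced throughout the patch.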