From a4a4b46b8d4137be5f439e7fe242b4bfed1543e5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Jun 2025 11:08:22 -0700 Subject: [PATCH] fix jit --- mlx/backend/metal/jit_kernels.cpp | 43 +++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index b6733c1b7..15e21af6c 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -42,8 +42,10 @@ MTL::ComputePipelineState* get_unary_kernel( concatenate(kernel_source, metal::unary_ops(), metal::unary()); kernel_source += get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op, 1); - kernel_source += - get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); + if (get_work_per_thread(in_type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "unary_v", in_t, out_t, op); + } kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( @@ -61,11 +63,8 @@ void append_binary_kernels( Dtype out_type, const std::string op, std::string& kernel_source) { - const std::array, 10> kernel_types = {{ + const std::array, 7> kernel_types = {{ {"ss", "binary_ss"}, - {"vsn", "binary_vs"}, - {"svn", "binary_sv"}, - {"vvn", "binary_vv"}, {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, @@ -85,7 +84,17 @@ void append_binary_kernels( kernel_source += get_template_definition( "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); kernel_source += get_template_definition( - "sv_" + lib_name, "binary_sv", in_t, out_t, op, 1); + "vv_" + lib_name, "binary_vv", in_t, out_t, op, 1); + + if (get_work_per_thread(in_type) > 1) { + kernel_source += get_template_definition( + "vsn_" + lib_name, "binary_vs", in_t, out_t, op); + kernel_source += get_template_definition( + "svn_" + lib_name, "binary_sv", in_t, out_t, op); + kernel_source += get_template_definition( + "vvn_" + lib_name, "binary_vv", in_t, out_t, op); + } + kernel_source += get_template_definition( "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( @@ -141,8 +150,7 @@ MTL::ComputePipelineState* get_ternary_kernel( auto t_str = get_type_string(type); std::string kernel_source = metal::utils(); concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); - const std::array, 5> kernel_types = {{ - {"vn", "ternary_v"}, + const std::array, 4> kernel_types = {{ {"v2", "ternary_v2"}, {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, @@ -152,6 +160,11 @@ MTL::ComputePipelineState* get_ternary_kernel( kernel_source += get_template_definition(name + "_" + lib_name, func, t_str, op); } + if (get_work_per_thread(type) > 1) { + kernel_source += + get_template_definition("vn_" + lib_name, "ternary_v", t_str, op); + } + kernel_source += get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); kernel_source += get_template_definition( @@ -182,18 +195,20 @@ MTL::ComputePipelineState* get_copy_kernel( auto out_type = get_type_string(out.dtype()); kernel_source += get_template_definition( "s_" + lib_name, "copy_s", in_type, out_type, 1); - kernel_source += - get_template_definition("sn_" + lib_name, "copy_s", in_type, out_type); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); kernel_source += get_template_definition( "v_" + lib_name, "copy_v", in_type, out_type, 1); - kernel_source += - get_template_definition("vn_" + lib_name, "copy_v", in_type, out_type); - kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); + if (get_work_per_thread(out.dtype()) > 1) { + kernel_source += get_template_definition( + "sn_" + lib_name, "copy_s", in_type, out_type); + kernel_source += get_template_definition( + "vn_" + lib_name, "copy_v", in_type, out_type); + } + kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition(