diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 31bd7903b..a9cc267e1 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -176,8 +176,13 @@ MTL::ComputePipelineState* get_copy_kernel( auto out_type = get_type_string(out.dtype()); kernel_source += get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); + kernel_source += + get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); kernel_source += get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); + kernel_source += + get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); + kernel_source += get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition(