From 881615b07269f41a8ef57e0063e343d875e522ab Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 14 Oct 2024 12:45:38 -0700
Subject: [PATCH] Faster metal compiled kernels + some fixes (#1486)

* bump mac tests to use py39

* work per thread for compiled kernels

* fix for large arrays

* fix

---
 mlx/backend/metal/binary.cpp               |   3 +-
 mlx/backend/metal/compiled.cpp             | 179 +++++++++++++--------
 mlx/backend/metal/copy.cpp                 |   2 +-
 mlx/backend/metal/jit_kernels.cpp          |  52 +++---
 mlx/backend/metal/kernels/binary.metal     |   1 -
 mlx/backend/metal/kernels/binary_two.metal |   1 -
 mlx/backend/metal/kernels/copy.metal       |   2 -
 mlx/backend/metal/kernels/ternary.metal    |   3 +-
 mlx/backend/metal/kernels/unary.metal      |   3 +-
 mlx/backend/metal/ternary.cpp              |   3 +-
 mlx/backend/metal/unary.cpp                |   2 +-
 python/tests/test_compile.py               |  13 ++
 12 files changed, 156 insertions(+), 108 deletions(-)

diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp
index 101810628..c87f98272 100644
--- a/mlx/backend/metal/binary.cpp
+++ b/mlx/backend/metal/binary.cpp
@@ -84,8 +84,7 @@ void binary_op_gpu_inplace(
   bool use_2d = out.data_size() > UINT32_MAX;
   auto ndim = shape.size();
-  int work_per_thread =
-      (bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
+  int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
   std::string kernel_name =
       get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
   auto& d = metal::device(s.device);
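With the `shape[ndim - 1] > 4` guard gone, every general binary kernel now processes four elements per thread and bounds-checks its own tail, so the host only has to shrink the grid along the innermost axis. A minimal sketch of that launch arithmetic (the helper names are illustrative, not MLX's actual dispatch code):

```cpp
#include <cstddef>

// Ceil-division: how many threads cover dim0 when each handles `b` items.
constexpr size_t ceil_div(size_t a, size_t b) {
  return (a + b - 1) / b;
}

// Hypothetical helper mirroring the grid sizing implied by this patch:
// a 37-element row with work_per_thread = 4 launches 10 threads, and the
// generated kernel's tail check keeps the last thread in bounds.
inline size_t threads_along_x(size_t dim0, int work_per_thread) {
  return ceil_div(dim0, static_cast<size_t>(work_per_thread));
}
```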
diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index eb9e54f8c..90bc34ee1 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -13,6 +13,8 @@

 namespace mlx::core {

+constexpr int WORK_PER_THREAD = 4;
+
 inline void build_kernel(
     std::ostream& os,
     const std::string& kernel_name,
@@ -23,7 +25,8 @@ inline void build_kernel(
     bool contiguous,
     int ndim,
     bool dynamic_dims,
-    bool use_big_index = false) {
+    bool use_big_index = false,
+    int work_per_thread = 1) {
   // All outputs should have the exact same shape and will be row contiguous
   auto output_shape = outputs[0].shape();
   auto output_strides = outputs[0].strides();
@@ -38,8 +41,8 @@ inline void build_kernel(
   int cnt = 0;

   // Start the kernel
-  os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
-     << "[[kernel]] void " << kernel_name << "(" << std::endl;
+  os << "[[host_name(\"" << kernel_name << "\")]]\n"
+     << "[[kernel]] void " << kernel_name << "(\n";

   // Add the input arguments
   for (auto& x : inputs) {
@@ -53,11 +56,11 @@ inline void build_kernel(
     // Scalars and contiguous need no strides
     if (is_scalar(x) || contiguous) {
       os << "  device const " << get_type_string(x.dtype()) << "* " << xname
-         << " [[buffer(" << cnt++ << ")]]," << std::endl;
+         << " [[buffer(" << cnt++ << ")]],\n";
     } else {
       add_indices = true;
       os << "  device const " << get_type_string(x.dtype()) << "* " << xname
-         << " [[buffer(" << cnt++ << ")]]," << std::endl;
+         << " [[buffer(" << cnt++ << ")]],\n";
     }
   }
@@ -69,58 +72,37 @@ inline void build_kernel(
   // Add the output arguments
   for (auto& x : outputs) {
     os << "  device " << get_type_string(x.dtype()) << "* "
-       << namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
+       << namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
   }
   // Add output strides and shape to extract the indices.
   if (!contiguous) {
     os << "  constant const size_t* output_strides [[buffer(" << cnt++
-       << ")]]," << std::endl
-       << "  constant const int* output_shape [[buffer(" << cnt++ << ")]],"
-       << std::endl;
+       << ")]],\n"
+       << "  constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
   }
   if (dynamic_dims) {
-    os << "  constant const int& ndim [[buffer(" << cnt++ << ")]],"
-       << std::endl;
+    os << "  constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
   }

   // The thread index in the whole grid
-  os << "  uint3 pos [[thread_position_in_grid]]," << std::endl
-     << "  uint3 grid [[threads_per_grid]]) {" << std::endl;
+  os << "  uint3 pos [[thread_position_in_grid]],\n"
+     << "  uint3 grid [[threads_per_grid]]) {\n";
+
   if (use_big_index) {
     // This is only used for contiguous kernels which don't have
     // a third grid dimension
-    os << "  size_t index = pos.x + grid.x * size_t(pos.y);";
+    os << "  size_t index = pos.x + grid.x * size_t(pos.y);\n";
+  } else if (work_per_thread > 1) {
+    os << "  constexpr int N = " << std::to_string(work_per_thread) << ";\n"
+       << "  int xshape = output_shape["
+       << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
+       << "  size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
   } else {
-    os << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
-  }
-  os << std::endl;
-
-  // Extract the indices per axis to individual uints if we have arrays that
-  // are broadcasted or transposed
-  if (add_indices) {
-    if (!dynamic_dims) {
-      if (ndim == 1) {
-        os << "  uint index_0 = pos.x;" << std::endl;
-      } else if (ndim == 2) {
-        os << "  uint index_0 = pos.y;" << std::endl
-           << "  uint index_1 = pos.x;" << std::endl;
-      } else if (ndim == 3) {
-        os << "  uint index_0 = pos.z;" << std::endl
-           << "  uint index_1 = pos.y;" << std::endl
-           << "  uint index_2 = pos.x;" << std::endl;
-      } else {
-        for (int i = 0; i < ndim - 2; i++) {
-          os << "  uint index_" << i << " = (index / uint(output_strides[" << i
-             << "])) % output_shape[" << i << "];" << std::endl;
-        }
-        os << "  uint index_" << ndim - 2 << " = pos.y;" << std::endl
-           << "  uint index_" << ndim - 1 << " = pos.x;" << std::endl;
-      }
-    }
+    os << "  size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
   }

-  // Read the inputs in tmps
-  int nc_in_count = 0;
+  // Read constant / contiguous inputs in tmps
+  std::vector<array> nc_inputs;
   for (int i = 0; i < inputs.size(); ++i) {
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
@@ -130,56 +112,117 @@ inline void build_kernel(
       os << "  auto tmp_" << xname << " = static_cast<"
          << get_type_string(x.dtype()) << ">(";
       print_constant(os, x);
-      os << ");" << std::endl;
+      os << ");\n";
     } else if (is_scalar(x)) {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[0];" << std::endl;
+         << xname << "[0];\n";
     } else if (contiguous) {
       os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[index];" << std::endl;
-    } else if (!dynamic_dims) {
-      int offset = nc_in_count * ndim;
-      os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[";
-      os << "index_0 * " << "in_strides[" << offset << "]";
-      for (int i = 1; i < ndim; i++) {
-        os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
-      }
-      os << "];" << std::endl;
-      nc_in_count++;
+         << xname << "[index];\n";
     } else {
-      os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
-         << xname << "[elem_to_loc(index, output_shape, in_strides + "
-         << nc_in_count * ndim << ", ndim)];" << std::endl;
-      nc_in_count++;
+      nc_inputs.push_back(x);
    }
  }
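The rewritten loop above sorts inputs into two groups: constants, scalars, and contiguous arrays are read immediately with the flat `index`, while strided inputs are collected in `nc_inputs` so their offsets can be set up once and then incremented inside the per-thread loop that follows. The emission itself is plain `ostringstream` string building; a self-contained sketch of the idiom (the helper name is invented for illustration):

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Illustrative codegen helper in the style of build_kernel: the host
// composes Metal source as text, one statement per line.
void emit_contiguous_read(
    std::ostream& os,
    const std::string& type,
    const std::string& xname) {
  os << "  " << type << " tmp_" << xname << " = " << xname << "[index];\n";
}

int main() {
  std::ostringstream os;
  emit_contiguous_read(os, "float", "a");
  std::cout << os.str(); // prints:   float tmp_a = a[index];
  return 0;
}
```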
+
+  // Initialize the indices for non-contiguous inputs
+  for (int i = 0; i < nc_inputs.size(); ++i) {
+    auto& xname = namer.get_name(nc_inputs[i]);
+    if (ndim == 1) {
+      int offset = i * ndim;
+      os << "  size_t index_" << xname << " = elem_to_loc_1(pos.x, "
+         << "in_strides[" << offset << "]);\n";
+    } else if (ndim == 2) {
+      int offset = i * ndim;
+      os << "  size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
+         << "in_strides + " << offset << ");\n";
+    } else if (ndim == 3) {
+      int offset = i * ndim;
+      os << "  size_t index_" << xname << " = elem_to_loc_3(pos, "
+         << "in_strides + " << offset << ");\n";
+    } else if (!dynamic_dims) {
+      int offset = i * ndim;
+      os << "  size_t index_" << xname << " = N * pos.x * in_strides["
+         << offset + ndim - 1 << "]"
+         << " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
+    } else {
+      os << "  size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
+         << i << " + ndim - 1]"
+         << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
+    }
+  }
+
+  if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
+    os << "  uint zpos = pos.z;\n";
+    if (dynamic_dims) {
+      os << "  for (int d = ndim - 3; d >= 0; --d) {\n";
+    } else {
+      os << "  for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
+    }
+    os << "    uint l = zpos % output_shape[d];\n";
+    for (int i = 0; i < nc_inputs.size(); ++i) {
+      auto& xname = namer.get_name(nc_inputs[i]);
+      os << "    index_" << xname << " += ";
+      if (dynamic_dims) {
+        os << "l * in_strides[" << i << " * ndim + d];\n";
+      } else {
+        os << "l * in_strides[" << i * ndim << " + d];\n";
+      }
+    }
+    os << "    zpos /= output_shape[d];\n  }\n";
+  }
+
+  // Open per-thread loop
+  if (work_per_thread > 1) {
+    os << "  for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
+  }
+
+  // Read non-contiguous inputs into tmps
+  for (int i = 0; i < nc_inputs.size(); ++i) {
+    auto& x = nc_inputs[i];
+    auto& xname = namer.get_name(x);
+    os << "  " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
+       << xname << "[index_" << xname << "];\n";
+  }
+
   // Actually write the computation
   for (auto& x : tape) {
     os << "  " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
        << " = ";
     if (is_static_cast(x.primitive())) {
       os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
-         << namer.get_name(x.inputs()[0]) << ");" << std::endl;
+         << namer.get_name(x.inputs()[0]) << ");\n";
     } else {
       x.primitive().print(os);
       os << "()(";
       for (int i = 0; i < x.inputs().size() - 1; i++) {
         os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
       }
-      os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
+      os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
     }
   }

   // Write the outputs from tmps
   for (auto& x : outputs) {
     os << "  " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
-       << ";" << std::endl;
+       << ";\n";
   }
+
+  // Increment indices and close per thread loop
+  if (work_per_thread > 1) {
+    for (int i = 0; i < nc_inputs.size(); ++i) {
+      auto& x = nc_inputs[i];
+      auto& xname = namer.get_name(x);
+      if (!dynamic_dims) {
+        os << "  index_" << xname << " += "
+           << "in_strides[" << i * ndim + ndim - 1 << "];\n";
+      } else {
+        os << "  index_" << xname << " += "
+           << "in_strides[" << i << " * ndim + ndim - 1];\n";
+      }
+    }
+    os << "  index++;\n  }\n";
+  }

   // Finish the kernel
-  os << "}" << std::endl;
+  os << "}\n";

   if (cnt > 31) {
     std::ostringstream msg;
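Putting the pieces together, here is roughly what `build_kernel` emits for one strided input, one output, `ndim = 4`, and `work_per_thread = 4`. This is reconstructed by hand from the codegen above, with a trivial tape and an assumed buffer order, so treat it as a sketch rather than actual compiler output:

```cpp
// Hand-written approximation of the generated Metal source.
const char* example_kernel_source = R"METAL(
[[host_name("example")]]
[[kernel]] void example(
  device const float* a [[buffer(0)]],
  constant const size_t* in_strides [[buffer(1)]],
  device float* b [[buffer(2)]],
  constant const size_t* output_strides [[buffer(3)]],
  constant const int* output_shape [[buffer(4)]],
  uint3 pos [[thread_position_in_grid]],
  uint3 grid [[threads_per_grid]]) {
  constexpr int N = 4;
  int xshape = output_shape[3];
  size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);
  // Offset of `a` for the first element this thread touches.
  size_t index_a = N * pos.x * in_strides[3] + pos.y * in_strides[2];
  uint zpos = pos.z;
  for (int d = 1; d >= 0; --d) {
    uint l = zpos % output_shape[d];
    index_a += l * in_strides[d];
    zpos /= output_shape[d];
  }
  // Per-thread loop: N elements, guarded against the ragged tail.
  for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {
    float tmp_a = a[index_a];
    float tmp_b = tmp_a; // the compiled tape would be inlined here
    b[index] = tmp_b;
    index_a += in_strides[3]; // step along the innermost axis of `a`
    index++;
  }
}
)METAL";
```

The expensive multi-dimensional index setup runs once per thread; the inner loop only adds one stride per input and one flat increment per element.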
@@ -237,7 +280,9 @@ void Compiled::eval_gpu(
           constant_ids_,
           /* contiguous = */ false,
           /* ndim = */ i,
-          /* dynamic_dims = */ false);
+          /* dynamic_dims = */ false,
+          /* use_big_index = */ false,
+          /* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
     }
     build_kernel(
         kernel,
@@ -248,7 +293,9 @@ void Compiled::eval_gpu(
         constant_ids_,
         /* contiguous = */ false,
         /* ndim = */ 0,
-        /* dynamic_dims = */ true);
+        /* dynamic_dims = */ true,
+        /* use_big_index = */ false,
+        /* work_per_thread = */ WORK_PER_THREAD);
     return kernel.str();
   });
@@ -384,6 +431,8 @@ void Compiled::eval_gpu(
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
     size_t rest = outputs[0].size() / (dim0 * dim1);
+    int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
+    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index a58e4c467..49a09483a 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -98,7 +98,7 @@ void copy_gpu_inplace(
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
     if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
       kname << shape.size();
-    } else if (shape[ndim - 1] >= 4) {
+    } else {
       work_per_thread = 4;
       kname << "n4";
     }
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index f9a998c5d..d0229f4cb 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -50,8 +50,6 @@ MTL::ComputePipelineState* get_unary_kernel(
         "v_" + lib_name, "unary_v", get_type_string(out_type), op);
     kernel_source << get_template_definition(
         "v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
-    kernel_source << get_template_definition(
-        "g_" + lib_name, "unary_g", get_type_string(out_type), op);
     kernel_source << get_template_definition(
         "gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
     return kernel_source.str();
@@ -65,7 +63,7 @@ void add_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::ostringstream& kernel_source) {
-  const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
+  const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
      {"ss", "binary_ss"},
      {"vs", "binary_vs"},
      {"sv", "binary_sv"},
      {"vv", "binary_vv"},
      {"vs2", "binary_vs2"},
      {"sv2", "binary_sv2"},
      {"vv2", "binary_vv2"},
      {"g1", "binary_g_nd1"},
      {"g2", "binary_g_nd2"},
      {"g3", "binary_g_nd3"},
-     {"gn", "binary_g"},
  }};
   for (auto& [name, func] : kernel_types) {
     std::string template_def;
@@ -138,10 +135,9 @@ MTL::ComputePipelineState* get_ternary_kernel(
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
-    const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
+    const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
        {"v", "ternary_v"},
        {"v2", "ternary_v2"},
-       {"g", "ternary_g"},
        {"g1", "ternary_g_nd1"},
        {"g2", "ternary_g_nd2"},
        {"g3", "ternary_g_nd3"},
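On the host side, only the innermost grid dimension is divided by the work per thread; `dim1` and `rest` are unchanged, so each thread walks consecutive elements of the last axis. A small standalone illustration with made-up shape values:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical 4-D output, so work_per_thread kicks in (ndim > 3).
  const size_t shape[] = {5, 6, 7, 8};
  const int ndim = 4;
  const int WORK_PER_THREAD = 4;

  size_t dim0 = shape[ndim - 1];                          // 8
  size_t dim1 = shape[ndim - 2];                          // 7
  size_t rest = shape[0] * shape[1];                      // 30
  int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
  dim0 = (dim0 + work_per_thread - 1) / work_per_thread;  // 2

  std::printf("grid = (%zu, %zu, %zu)\n", dim0, dim1, rest);
  return 0;
}
```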
get_template_definition("g_" + lib_name, "copy_g", in_type, out_type) - << get_template_definition( - "gn4_" + lib_name, "copy_g", in_type, out_type, 4) - << get_template_definition( - "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type) - << get_template_definition( - "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type) - << get_template_definition( - "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type) - << get_template_definition( - "gg_" + lib_name, "copy_gg", in_type, out_type) - << get_template_definition( - "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4); + kernel_source << metal::utils() << metal::copy() + << get_template_definition( + "s_" + lib_name, "copy_s", in_type, out_type) + << get_template_definition( + "v_" + lib_name, "copy_v", in_type, out_type) + << get_template_definition( + "g1_" + lib_name, "copy_g_nd1", in_type, out_type) + << get_template_definition( + "g2_" + lib_name, "copy_g_nd2", in_type, out_type) + << get_template_definition( + "g3_" + lib_name, "copy_g_nd3", in_type, out_type) + << get_template_definition( + "gn4_" + lib_name, "copy_g", in_type, out_type, 4) + << get_template_definition( + "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type) + << get_template_definition( + "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type) + << get_template_definition( + "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type) + << get_template_definition( + "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 5c437bd2a..db5de5b59 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -17,7 +17,6 @@ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \ instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index f062439ec..6f5e48c0e 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -15,7 +15,6 @@ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \ instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index a631183b7..7036d4b81 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -16,9 +16,7 @@ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \ - instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \ instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \ - 
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \ instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4) #define instantiate_copy_itype(itname, itype) \ diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index 79e427775..dacafadef 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -12,11 +12,10 @@ #define instantiate_ternary_all(op, tname, type) \ instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ - instantiate_kernel("g_" #op #tname, ternary_g, type, op) \ instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \ instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \ instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \ - instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \ + instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) #define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 0c1b5d9e1..f301dce60 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -8,8 +8,7 @@ #define instantiate_unary_all(op, tname, type) \ instantiate_kernel("v_" #op #tname, unary_v, type, op) \ instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \ - instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) \ - instantiate_kernel("g_" #op #tname, unary_g, type, op) + instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) #define instantiate_unary_float(op) \ instantiate_unary_all(op, float16, half) \ diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 357065bdc..d353dda5e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -38,8 +38,7 @@ void ternary_op_gpu_inplace( bool use_2d = out.data_size() > UINT_MAX; auto ndim = shape.size(); - int work_per_thread = - (topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1; + int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1; std::string kernel_name; { std::ostringstream kname; diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index eb4af03ec..a3903b89c 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -35,7 +35,7 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - int work_per_thread = (!contig && shape[ndim - 1] > 4) ? 4 : 1; + int work_per_thread = !contig ? 4 : 1; size_t nthreads = contig ? in.data_size() : in.size(); bool use_2d = nthreads > UINT32_MAX; std::string kernel_name; diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index af74cdaa5..897c7b486 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -758,6 +758,20 @@ class TestCompile(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): y = mx.compile(fn)(x) + def test_compile_dynamic_dims(self): + a = mx.random.uniform(shape=(2,) * 10) + b = mx.random.uniform(shape=(2,) * 10) + a = a.T + mx.eval(a, b) + + def fn(a, b): + return mx.abs(a + b) + + out = mx.compile(fn)(a, b) + expected = fn(a, b) + print((out - expected).abs().max()) + self.assertTrue(mx.allclose(out, expected)) + if __name__ == "__main__": unittest.main()