From 92d7cb71f8dbcee592ad88de30e955e15b99492b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 18 Oct 2024 11:06:40 -0700 Subject: [PATCH] Fix compile (#1501) * fix compile * fix space --- mlx/backend/common/compiled_cpu.cpp | 4 ++-- mlx/backend/metal/compiled.cpp | 21 +++++++++++++-------- mlx/backend/metal/utils.cpp | 4 ++-- mlx/backend/metal/utils.h | 4 ++-- python/tests/test_compile.py | 17 +++++++++++++++++ 5 files changed, 36 insertions(+), 14 deletions(-) diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index 5eb904b45..0923398c7 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -103,8 +103,8 @@ void* compile( source_file.close(); std::ostringstream build_command; - build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared " - << source_file_path << " -o " << shared_lib_path; + build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '" + << source_file_path << "' -o '" << shared_lib_path << "'"; std::string build_command_str = build_command.str(); auto return_code = system(build_command_str.c_str()); if (return_code) { diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 90bc34ee1..add0592f2 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -93,10 +93,10 @@ inline void build_kernel( // a third grid dimension os << " size_t index = pos.x + grid.x * size_t(pos.y);\n"; } else if (work_per_thread > 1) { - os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n" + os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n" << " int xshape = output_shape[" << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n" - << " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n"; + << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n"; } else { os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n"; } @@ -141,11 +141,11 @@ inline void build_kernel( << "in_strides + " << offset << ");\n"; } else if (!dynamic_dims) { int offset = i * ndim; - os << " size_t index_" << xname << " = N * pos.x * in_strides[" + os << " size_t index_" << xname << " = N_ * pos.x * in_strides[" << offset + ndim - 1 << "]" << " + pos.y * in_strides[" << offset + ndim - 2 << "];\n"; } else { - os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * " + os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * " << i << " + ndim - 1]" << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n"; } @@ -172,7 +172,7 @@ inline void build_kernel( // Open per-thread loop if (work_per_thread > 1) { - os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n"; + os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } // Read non-contiguous inputs into tmps @@ -434,10 +434,15 @@ void Compiled::eval_gpu( int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1; dim0 = (dim0 + work_per_thread - 1) / work_per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size != 1024) { - throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + int pow2; + if (thread_group_size == 1024) { + pow2 = 10; + } else if (thread_group_size > 512) { + pow2 = 9; + } else { + throw std::runtime_error("[Metal::compiled] Must use > 512 sized block"); } - auto group_dims = get_block_dims(dim0, dim1, rest); + auto group_dims = get_block_dims(dim0, dim1, rest, pow2); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatchThreads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index bbb80acbd..1242209a1 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -52,7 +52,7 @@ std::string type_to_name(const array& a) { return tname; } -MTL::Size get_block_dims(int dim0, int dim1, int dim2) { +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { int pows[3] = {0, 0, 0}; int sum = 0; while (true) { @@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { pows[2]++; sum++; } - if (sum == presum || sum == 10) { + if (sum == presum || sum == pow2) { break; } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index af85281a9..ad49c52a1 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -30,8 +30,8 @@ std::string type_to_name(const array& a); // Compute the thread block dimensions which fit the given // input dimensions. // - The thread block dimensions will be powers of two -// - The thread block size will be less than 1024 -MTL::Size get_block_dims(int dim0, int dim1, int dim2); +// - The thread block size will be less than 2^pow2 +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); // Computes a 2D grid where each element is < UINT_MAX // Assumes: diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 897c7b486..81e6ccedf 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -772,6 +772,23 @@ class TestCompile(mlx_tests.MLXTestCase): print((out - expected).abs().max()) self.assertTrue(mx.allclose(out, expected)) + def test_compile_many_inputs(self): + inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)] + inputs[0] = inputs[0].T + + @mx.compile + def fun(*inputs): + x = inputs[0] + for y in inputs[1:10]: + x = x + y + a = inputs[10] + for b in inputs[11:]: + a = a + b + return x + a + + out = fun(*inputs) + self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) + if __name__ == "__main__": unittest.main()