diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 45fc1c1aa..2c3d4484a 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -197,8 +197,10 @@ std::uintptr_t get_function_address(const std::function& fun) { class CompilerCache { public: struct CacheEntry { - CacheEntry(Stream stream) : stream(stream) {}; + CacheEntry(Stream stream, bool shapeless) + : stream(stream), shapeless(shapeless) {}; Stream stream; + bool shapeless; std::vector inputs; std::vector outputs; std::vector tape; @@ -245,6 +247,9 @@ class CompilerCache { if (entry.stream != stream) { continue; } + if (entry.shapeless != shapeless) { + continue; + } // Check the inputs match and return if so if (has_same_shape_and_dtype(inputs, entry.inputs) && @@ -253,7 +258,7 @@ class CompilerCache { } } // Otherwise append a new cache entry - entries.push_back(CacheEntry{stream}); + entries.push_back(CacheEntry{stream, shapeless}); return entries.back(); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 1dbd052c1..1974b9a23 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -675,6 +675,10 @@ class TestCompile(mlx_tests.MLXTestCase): def mean(x): return mx.mean(x, keepdims=True) + cfun = mx.compile(mean) + out = cfun(mx.ones((5, 5))) + self.assertTrue(mx.allclose(out, mx.array(1.0))) + cmean = mx.compile(mean, shapeless=True) x = mx.ones(2)