From 90532b1f37c12c1ec2434ecb7334dd5d33ee1eb4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 20 Jan 2025 21:07:10 -0800 Subject: [PATCH] recompile when shapeless is different (#1776) --- mlx/compile.cpp | 9 +++++++-- python/tests/test_compile.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 45fc1c1aa..2c3d4484a 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -197,8 +197,10 @@ std::uintptr_t get_function_address(const std::function& fun) { class CompilerCache { public: struct CacheEntry { - CacheEntry(Stream stream) : stream(stream) {}; + CacheEntry(Stream stream, bool shapeless) + : stream(stream), shapeless(shapeless) {}; Stream stream; + bool shapeless; std::vector inputs; std::vector outputs; std::vector tape; @@ -245,6 +247,9 @@ class CompilerCache { if (entry.stream != stream) { continue; } + if (entry.shapeless != shapeless) { + continue; + } // Check the inputs match and return if so if (has_same_shape_and_dtype(inputs, entry.inputs) && @@ -253,7 +258,7 @@ class CompilerCache { } } // Otherwise append a new cache entry - entries.push_back(CacheEntry{stream}); + entries.push_back(CacheEntry{stream, shapeless}); return entries.back(); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 1dbd052c1..1974b9a23 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -675,6 +675,10 @@ class TestCompile(mlx_tests.MLXTestCase): def mean(x): return mx.mean(x, keepdims=True) + cfun = mx.compile(mean) + out = cfun(mx.ones((5, 5))) + self.assertTrue(mx.allclose(out, mx.array(1.0))) + cmean = mx.compile(mean, shapeless=True) x = mx.ones(2)