diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 29564a707..58425d949 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -433,19 +433,18 @@ struct PyCompiledFun { auto d = nb::cast(obj); constants.push_back(dict_identifier); for (auto item : d) { - auto r = item.first.attr("__hash__"); - constants.push_back(*reinterpret_cast(&r)); + auto r = item.first.attr("__hash__")(); + constants.push_back(nb::cast(r)); recurse(item.second); } } else if (nb::isinstance(obj)) { inputs.push_back(nb::cast(obj)); constants.push_back(array_identifier); } else if (nb::isinstance(obj)) { - auto r = obj.attr("__hash__"); - constants.push_back(*reinterpret_cast(&r)); + auto r = obj.attr("__hash__")(); + constants.push_back(nb::cast(r)); } else if (nb::isinstance(obj)) { - auto r = nb::cast(obj); - constants.push_back(*reinterpret_cast(&r)); + constants.push_back(nb::cast(obj)); } else if (nb::isinstance(obj)) { auto r = nb::cast(obj); constants.push_back(*reinterpret_cast(&r)); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 8e496ab06..feb8e6da6 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -579,6 +579,27 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(counter[0], 2) + y = 1.0 + + @mx.compile + def fun(x, constant): + return x + y + + constant1 = "abc" + out = fun(mx.array(0.0), constant1) + self.assertEqual(out, mx.array(1.0)) + + # new object, same value, no recompilation + y = 2.0 + constant2 = "abc".encode("utf-8").decode("utf-8") + out = fun(mx.array(0.0), constant2) + self.assertEqual(out, mx.array(1.0)) + + # same object, new value, recompilation + constant2 = "xyz" + out = fun(mx.array(0.0), constant2) + self.assertEqual(out, mx.array(2.0)) + def test_compile_inf(self): @mx.compile def fun(x):