From 3ae6aabe9f28c0116ef85cd71d95348b1bd0d8d8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 9 Sep 2024 14:54:31 -0700 Subject: [PATCH] throw for certain cases of non captured inputs in compile (#1401) --- mlx/compile.cpp | 14 ++++++++++---- mlx/fast.cpp | 2 +- python/src/fast.cpp | 26 +++++++++++++------------- python/tests/test_compile.py | 25 +++++++++++++++++++++++++ python/tests/test_fast.py | 21 +++++++++++++++++++++ 5 files changed, 70 insertions(+), 18 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 06d85be16..279fb0328 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -306,21 +306,27 @@ std::pair, std::vector> compile_trace( // Traverses the graph to build a tape and a map of array ids to their parents std::pair, ParentsMap> compile_dfs( const std::vector& inputs, - const std::vector& outputs) { + const std::vector& outputs, + const std::vector& original_inputs) { std::function recurse; std::vector tape; std::unordered_set input_set; + std::unordered_set original_input_set; std::unordered_map>> parents_map; for (int i = 0; i < inputs.size(); ++i) { - auto in = inputs[i]; - input_set.insert(in.id()); + input_set.insert(inputs[i].id()); + original_input_set.insert(original_inputs[i].id()); } // DFS the graph to build the tape, and log parents and scalars std::unordered_set cache; recurse = [&](const array& a) { auto id = a.id(); + if (original_input_set.find(id) != original_input_set.end()) { + throw std::invalid_argument( + "[compile] Attempting to compile a function with uncaptured inputs is not allowed."); + } if (cache.find(id) != cache.end()) { return; } @@ -833,7 +839,7 @@ std::function(const std::vector&)> compile( std::unordered_map>> parents_map; std::tie(entry.tape, parents_map) = - compile_dfs(entry.inputs, entry.outputs); + compile_dfs(entry.inputs, entry.outputs, inputs); // Simplify the tape if (compile_mode() != CompileMode::no_simplify) { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 1a2afaaa0..6a3b38218 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -972,7 +972,7 @@ void write_signature( {"threadgroups_per_grid", "uint3"}, {"threads_per_grid", "uint3"}, {"threads_per_simdgroup", "uint"}, - {"thread_per_threadgroup", "uint3"}, + {"threads_per_threadgroup", "uint3"}, }; std::vector> attrs; for (const auto& [attr, dtype] : metal_attributes) { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index b07e965e9..92e6e6bdc 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -302,20 +302,20 @@ void init_fast(nb::module_& parent_module) { A jit-compiled custom Metal kernel defined from a source string. Args: - name (str): Name for the kernel. - input_names (List[str]): The parameter names of the inputs in the - function signature. - output_names (List[str]): The parameter names of the outputs in the + name (str): Name for the kernel. + input_names (List[str]): The parameter names of the inputs in the function signature. - source (str): Source code. This is the body of a function in Metal, - the function signature will be automatically generated. - header (str): Header source code to include before the main function. - Useful for helper functions or includes that should live outside of - the main function body. - ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous - before the kernel runs. Default: ``True``. - atomic_outputs (bool): Whether to use atomic outputs in the function signature - e.g. ``device atomic``. Default: ``False``. + output_names (List[str]): The parameter names of the outputs in the + function signature. + source (str): Source code. This is the body of a function in Metal, + the function signature will be automatically generated. + header (str): Header source code to include before the main function. + Useful for helper functions or includes that should live outside of + the main function body. + ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous + before the kernel runs. Default: ``True``. + atomic_outputs (bool): Whether to use atomic outputs in the function signature + e.g. ``device atomic``. Default: ``False``. Returns: Callable ``metal_kernel``. diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index bdc7a1bff..af74cdaa5 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -733,6 +733,31 @@ class TestCompile(mlx_tests.MLXTestCase): expected = fn(x) self.assertTrue(mx.array_equal(expected, out)) + def test_compile_without_captured_inputs(self): + x = mx.array([1, 2, 3]) + 2 + + def fn(a): + y = x + 1 + return a + y + + with self.assertRaises(ValueError): + y = mx.compile(fn)(x) + + x = mx.array([1.0, 2.0]) + mx.array([1.0, 2.0]) + y = None + + def fn(x): + nonlocal y + if y is None: + y = mx.array([1.0, 2.0]) + + y = y + x + return y + + fn(x) + with self.assertRaises(ValueError): + y = mx.compile(fn)(x) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index db107eec1..f989783a2 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -689,6 +689,27 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(out[0], mx.exp(a))) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_attributes(self): + a = mx.zeros(shape=(1, 1)) + kernel = mx.fast.metal_kernel( + name="test_fun", + input_names=["a"], + output_names=["out"], + source=""" + out[0] = threads_per_threadgroup.x; + """, + ) + out = kernel( + inputs=[a], + grid=(2, 1, 1), + threadgroup=(2, 1, 1), + output_shapes=[(1, 1)], + output_dtypes=[mx.uint32], + stream=mx.gpu, + )[0] + self.assertEqual(out.item(), 2) + if __name__ == "__main__": unittest.main()