diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 8b7f24c91..ced958b13 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -922,7 +922,7 @@ std::function(const std::vector&)> compile( } std::function(const std::vector&)> compile( - std::vector(fun)(const std::vector&), + std::vector (*fun)(const std::vector&), bool shapeless /* = false */) { if (detail::skip_compile()) { return fun; diff --git a/mlx/compile.h b/mlx/compile.h index 90b6a9c2a..a076cfbca 100644 --- a/mlx/compile.h +++ b/mlx/compile.h @@ -14,7 +14,7 @@ std::function(const std::vector&)> compile( bool shapeless = false); std::function(const std::vector&)> compile( - std::vector(fun)(const std::vector&), + std::vector (*fun)(const std::vector&), bool shapeless = false); // Convert capture-less lambdas to function pointers.