diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 882bf93e06..c001b76a7c 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp diff --git a/mlx/compile.cpp b/mlx/compile.cpp new file mode 100644 index 0000000000..48ce8b8064 --- /dev/null +++ b/mlx/compile.cpp @@ -0,0 +1,47 @@ +// Copyright © 2023 Apple Inc. +#include // TODO +#include "mlx/transforms.h" + +namespace mlx::core { + +// class CompilerCache { +// std::unordered_map +// } + +template +size_t getAddress(std::function f) { + typedef T(fnType)(U...); + fnType** fnPointer = f.template target(); + return (size_t)*fnPointer; +} + +int g(int, int) { + return 2; +} + +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun) { + // Not doing too much at the moment + // std::cout << getAddress(fun) << std::endl; + return [&fun](const std::vector& inputs) { + std::cout << getAddress(fun) << std::endl; + // getAddress(std::function(g)); + // + // std::cout << getAddress(fun) << std::endl; + // Step 1 check the cache for the function. + // If it's in the cache check the shapes and types + // If they match then run the cached function, + // + // What exactly is the cached function? + // The return has to be the outputs of fun(inputs) which point to the + // correct inputs So we need to store a tape of primitives -> inputs (shape, + // dtype), outputs (shape, dtype) We need a level of indirection id to input + // to store the inputs so we can + // T + // Because eval will just want some pointers to arrays + // So you go through and set the + return fun(inputs); + }; +} + +} // namespace mlx::core diff --git a/mlx/transforms.h b/mlx/transforms.h index 813c5f7fd7..539f47efd9 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -2,10 +2,14 @@ #pragma once -#include "array.h" +#include "mlx/array.h" namespace mlx::core { +// Compile takes a function and returns a new function +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun); + /** Fuse equivalent arrays to avoid duplicate execution. */ void simplify(const std::vector& outputs); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dbc4992052..f120b2ee7b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources(tests PRIVATE arg_reduce_tests.cpp autograd_tests.cpp blas_tests.cpp + compile_tests.cpp creations_tests.cpp device_tests.cpp eval_tests.cpp diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp new file mode 100644 index 0000000000..e83f5179e4 --- /dev/null +++ b/tests/compile_tests.cpp @@ -0,0 +1,17 @@ +// Copyright © 2023 Apple Inc. + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +std::vector simple_fun(const std::vector& inputs) { + return std::vector{inputs[0] + inputs[1]}; +}; + +TEST_CASE("test simple compile") { + auto compfn = compile(simple_fun); + auto out = compfn({array(1.0), array(2.0)})[0]; + CHECK_EQ(out.item(), 3.0f); +}