make a move on compile

This commit is contained in:
Awni Hannun 2024-01-11 06:27:44 -08:00
parent c38d43153b
commit 264e9ad57e
5 changed files with 71 additions and 1 deletions

View File

@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp

47
mlx/compile.cpp Normal file
View File

@ -0,0 +1,47 @@
// Copyright © 2023 Apple Inc.
#include <iostream> // TODO
#include "mlx/transforms.h"
namespace mlx::core {
// class CompilerCache {
// std::unordered_map
// }
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
fnType** fnPointer = f.template target<fnType*>();
return (size_t)*fnPointer;
}
int g(int, int) {
return 2;
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
// Not doing too much at the moment
// std::cout << getAddress(fun) << std::endl;
return [&fun](const std::vector<array>& inputs) {
std::cout << getAddress(fun) << std::endl;
// getAddress(std::function<int(int, int)>(g));
//
// std::cout << getAddress(fun) << std::endl;
// Step 1 check the cache for the function.
// If it's in the cache check the shapes and types
// If they match then run the cached function,
//
// What exactly is the cached function?
// The return has to be the outputs of fun(inputs) which point to the
// correct inputs So we need to store a tape of primitives -> inputs (shape,
// dtype), outputs (shape, dtype) We need a level of indirection id to input
// to store the inputs so we can
// T
// Because eval will just want some pointers to arrays
// So you go through and set the
return fun(inputs);
};
}
} // namespace mlx::core

View File

@ -2,10 +2,14 @@
#pragma once
#include "array.h"
#include "mlx/array.h"
namespace mlx::core {
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Fuse equivalent arrays to avoid duplicate execution. */
void simplify(const std::vector<array>& outputs);

View File

@ -20,6 +20,7 @@ target_sources(tests PRIVATE
arg_reduce_tests.cpp
autograd_tests.cpp
blas_tests.cpp
compile_tests.cpp
creations_tests.cpp
device_tests.cpp
eval_tests.cpp

17
tests/compile_tests.cpp Normal file
View File

@ -0,0 +1,17 @@
// Copyright © 2023 Apple Inc.
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
std::vector<array> simple_fun(const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + inputs[1]};
};
TEST_CASE("test simple compile") {
auto compfn = compile(simple_fun);
auto out = compfn({array(1.0), array(2.0)})[0];
CHECK_EQ(out.item<float>(), 3.0f);
}