mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 11:16:38 +08:00
make a move on compile
This commit is contained in:
parent
c38d43153b
commit
264e9ad57e
@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
|
47
mlx/compile.cpp
Normal file
47
mlx/compile.cpp
Normal file
@ -0,0 +1,47 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <iostream> // TODO
|
||||
#include "mlx/transforms.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// class CompilerCache {
|
||||
// std::unordered_map
|
||||
// }
|
||||
|
||||
template <typename T, typename... U>
|
||||
size_t getAddress(std::function<T(U...)> f) {
|
||||
typedef T(fnType)(U...);
|
||||
fnType** fnPointer = f.template target<fnType*>();
|
||||
return (size_t)*fnPointer;
|
||||
}
|
||||
|
||||
int g(int, int) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||
// Not doing too much at the moment
|
||||
// std::cout << getAddress(fun) << std::endl;
|
||||
return [&fun](const std::vector<array>& inputs) {
|
||||
std::cout << getAddress(fun) << std::endl;
|
||||
// getAddress(std::function<int(int, int)>(g));
|
||||
//
|
||||
// std::cout << getAddress(fun) << std::endl;
|
||||
// Step 1 check the cache for the function.
|
||||
// If it's in the cache check the shapes and types
|
||||
// If they match then run the cached function,
|
||||
//
|
||||
// What exactly is the cached function?
|
||||
// The return has to be the outputs of fun(inputs) which point to the
|
||||
// correct inputs So we need to store a tape of primitives -> inputs (shape,
|
||||
// dtype), outputs (shape, dtype) We need a level of indirection id to input
|
||||
// to store the inputs so we can
|
||||
// T
|
||||
// Because eval will just want some pointers to arrays
|
||||
// So you go through and set the
|
||||
return fun(inputs);
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,10 +2,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "array.h"
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Compile takes a function and returns a new function
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||
|
||||
/** Fuse equivalent arrays to avoid duplicate execution. */
|
||||
void simplify(const std::vector<array>& outputs);
|
||||
|
||||
|
@ -20,6 +20,7 @@ target_sources(tests PRIVATE
|
||||
arg_reduce_tests.cpp
|
||||
autograd_tests.cpp
|
||||
blas_tests.cpp
|
||||
compile_tests.cpp
|
||||
creations_tests.cpp
|
||||
device_tests.cpp
|
||||
eval_tests.cpp
|
||||
|
17
tests/compile_tests.cpp
Normal file
17
tests/compile_tests.cpp
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
std::vector<array> simple_fun(const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + inputs[1]};
|
||||
};
|
||||
|
||||
TEST_CASE("test simple compile") {
|
||||
auto compfn = compile(simple_fun);
|
||||
auto out = compfn({array(1.0), array(2.0)})[0];
|
||||
CHECK_EQ(out.item<float>(), 3.0f);
|
||||
}
|
Loading…
Reference in New Issue
Block a user