mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-13 20:56:45 +08:00
make a move on compile
This commit is contained in:
parent
c38d43153b
commit
264e9ad57e
@ -5,6 +5,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||||
|
47
mlx/compile.cpp
Normal file
47
mlx/compile.cpp
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
#include <iostream> // TODO
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// class CompilerCache {
|
||||||
|
// std::unordered_map
|
||||||
|
// }
|
||||||
|
|
||||||
|
template <typename T, typename... U>
|
||||||
|
size_t getAddress(std::function<T(U...)> f) {
|
||||||
|
typedef T(fnType)(U...);
|
||||||
|
fnType** fnPointer = f.template target<fnType*>();
|
||||||
|
return (size_t)*fnPointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
int g(int, int) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||||
|
// Not doing too much at the moment
|
||||||
|
// std::cout << getAddress(fun) << std::endl;
|
||||||
|
return [&fun](const std::vector<array>& inputs) {
|
||||||
|
std::cout << getAddress(fun) << std::endl;
|
||||||
|
// getAddress(std::function<int(int, int)>(g));
|
||||||
|
//
|
||||||
|
// std::cout << getAddress(fun) << std::endl;
|
||||||
|
// Step 1 check the cache for the function.
|
||||||
|
// If it's in the cache check the shapes and types
|
||||||
|
// If they match then run the cached function,
|
||||||
|
//
|
||||||
|
// What exactly is the cached function?
|
||||||
|
// The return has to be the outputs of fun(inputs) which point to the
|
||||||
|
// correct inputs So we need to store a tape of primitives -> inputs (shape,
|
||||||
|
// dtype), outputs (shape, dtype) We need a level of indirection id to input
|
||||||
|
// to store the inputs so we can
|
||||||
|
// T
|
||||||
|
// Because eval will just want some pointers to arrays
|
||||||
|
// So you go through and set the
|
||||||
|
return fun(inputs);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -2,10 +2,14 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Compile takes a function and returns a new function
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||||
|
|
||||||
/** Fuse equivalent arrays to avoid duplicate execution. */
|
/** Fuse equivalent arrays to avoid duplicate execution. */
|
||||||
void simplify(const std::vector<array>& outputs);
|
void simplify(const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ target_sources(tests PRIVATE
|
|||||||
arg_reduce_tests.cpp
|
arg_reduce_tests.cpp
|
||||||
autograd_tests.cpp
|
autograd_tests.cpp
|
||||||
blas_tests.cpp
|
blas_tests.cpp
|
||||||
|
compile_tests.cpp
|
||||||
creations_tests.cpp
|
creations_tests.cpp
|
||||||
device_tests.cpp
|
device_tests.cpp
|
||||||
eval_tests.cpp
|
eval_tests.cpp
|
||||||
|
17
tests/compile_tests.cpp
Normal file
17
tests/compile_tests.cpp
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
std::vector<array> simple_fun(const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{inputs[0] + inputs[1]};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_CASE("test simple compile") {
|
||||||
|
auto compfn = compile(simple_fun);
|
||||||
|
auto out = compfn({array(1.0), array(2.0)})[0];
|
||||||
|
CHECK_EQ(out.item<float>(), 3.0f);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user