mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
add test
This commit is contained in:
@@ -294,6 +294,11 @@ class array {
|
|||||||
return array_desc_->siblings;
|
return array_desc_->siblings;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** The array's position in the sibling list. */
|
||||||
|
int sibling_position() const {
|
||||||
|
return array_desc_->position;
|
||||||
|
}
|
||||||
|
|
||||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||||
array_desc_->siblings = std::move(siblings);
|
array_desc_->siblings = std::move(siblings);
|
||||||
array_desc_->position = position;
|
array_desc_->position = position;
|
||||||
|
|||||||
@@ -414,55 +414,58 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::vector<array>& original_inputs) {
|
const std::vector<array>& original_inputs) {
|
||||||
std::function<void(const array&)> recurse;
|
|
||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
std::unordered_set<std::uintptr_t> input_set;
|
|
||||||
std::unordered_set<std::uintptr_t> original_input_set;
|
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||||
parents_map;
|
parents_map;
|
||||||
for (int i = 0; i < inputs.size(); ++i) {
|
{
|
||||||
input_set.insert(inputs[i].id());
|
std::function<void(const array&)> recurse;
|
||||||
original_input_set.insert(original_inputs[i].id());
|
std::unordered_set<std::uintptr_t> input_set;
|
||||||
}
|
std::unordered_set<std::uintptr_t> original_input_set;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
input_set.insert(inputs[i].id());
|
||||||
|
original_input_set.insert(original_inputs[i].id());
|
||||||
|
}
|
||||||
|
|
||||||
// DFS the graph to build the tape, and log parents and scalars
|
// DFS the graph to build the tape, and log parents and scalars
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
recurse = [&](const array& a) {
|
recurse = [&](const array& a) {
|
||||||
auto id = a.id();
|
auto id = a.id();
|
||||||
if (original_input_set.find(id) != original_input_set.end()) {
|
if (original_input_set.find(id) != original_input_set.end()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[compile] Attempting to compile a function with uncaptured inputs is not allowed.");
|
"[compile] Attempting to compile a function with uncaptured inputs is not allowed.");
|
||||||
}
|
}
|
||||||
if (cache.find(id) != cache.end()) {
|
if (cache.find(id) != cache.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < a.inputs().size(); i++) {
|
for (int i = 0; i < a.inputs().size(); i++) {
|
||||||
auto& in = a.inputs()[i];
|
auto& in = a.inputs()[i];
|
||||||
parents_map[in.id()].push_back({a, i});
|
parents_map[in.id()].push_back({a, i});
|
||||||
|
for (auto& s : a.siblings()) {
|
||||||
|
parents_map[in.id()].push_back({s, i});
|
||||||
|
}
|
||||||
|
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||||
|
// of future optimizations)
|
||||||
|
if (input_set.find(a.id()) == input_set.end()) {
|
||||||
|
recurse(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache.insert(id);
|
||||||
for (auto& s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
parents_map[in.id()].push_back({s, i});
|
cache.insert(s.id());
|
||||||
}
|
|
||||||
// Don't recurse on inputs (but add them to the tape for the purpose
|
|
||||||
// of future optimizations)
|
|
||||||
if (input_set.find(a.id()) == input_set.end()) {
|
|
||||||
recurse(in);
|
|
||||||
}
|
}
|
||||||
|
tape.push_back(a);
|
||||||
|
};
|
||||||
|
for (auto& a : outputs) {
|
||||||
|
recurse(a);
|
||||||
}
|
}
|
||||||
cache.insert(id);
|
|
||||||
for (auto& s : a.siblings()) {
|
|
||||||
cache.insert(s.id());
|
|
||||||
}
|
|
||||||
tape.push_back(a);
|
|
||||||
};
|
|
||||||
for (auto& a : outputs) {
|
|
||||||
recurse(a);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deep copy the tape and parents map
|
// Deep copy the tape and parents map while preserving inputs and outputs
|
||||||
std::vector<array> new_tape;
|
std::vector<array> new_tape;
|
||||||
std::unordered_set<uintptr_t> io_set;
|
std::unordered_set<uintptr_t> io_set;
|
||||||
std::unordered_map<uintptr_t, array> old_to_new;
|
std::unordered_map<uintptr_t, array> old_to_new;
|
||||||
for (auto& o : outputs) {
|
for (auto& o : outputs) {
|
||||||
|
old_to_new.insert({o.id(), o});
|
||||||
io_set.insert(o.id());
|
io_set.insert(o.id());
|
||||||
}
|
}
|
||||||
for (auto& i : inputs) {
|
for (auto& i : inputs) {
|
||||||
@@ -483,7 +486,6 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
inputs.push_back(old_to_new.find(i.id())->second);
|
inputs.push_back(old_to_new.find(i.id())->second);
|
||||||
}
|
}
|
||||||
if (arr.siblings().size() > 0) {
|
if (arr.siblings().size() > 0) {
|
||||||
// use make_arrays
|
|
||||||
std::vector<Dtype> types;
|
std::vector<Dtype> types;
|
||||||
std::vector<Shape> shapes;
|
std::vector<Shape> shapes;
|
||||||
auto out = arr.outputs();
|
auto out = arr.outputs();
|
||||||
@@ -496,8 +498,7 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
for (int i = 0; i < out.size(); ++i) {
|
for (int i = 0; i < out.size(); ++i) {
|
||||||
old_to_new.insert({out[i].id(), as[i]});
|
old_to_new.insert({out[i].id(), as[i]});
|
||||||
}
|
}
|
||||||
// TODO maybe need to preserve position of sibling that is in tape
|
new_tape.push_back(as[arr.sibling_position()]);
|
||||||
new_tape.push_back(as[0]);
|
|
||||||
} else {
|
} else {
|
||||||
auto a = array(
|
auto a = array(
|
||||||
arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));
|
arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));
|
||||||
@@ -505,7 +506,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
new_tape.push_back(a);
|
new_tape.push_back(a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
io_set.clear();
|
||||||
for (auto& o : outputs) {
|
for (auto& o : outputs) {
|
||||||
|
if (!(io_set.insert(o.id()).second)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
for (auto& i : o.inputs()) {
|
for (auto& i : o.inputs()) {
|
||||||
i = old_to_new.find(i.id())->second;
|
i = old_to_new.find(i.id())->second;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1134,6 +1134,30 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
a = fun2(mx.array(-1.0))
|
a = fun2(mx.array(-1.0))
|
||||||
self.assertEqual(a.item(), 1.0)
|
self.assertEqual(a.item(), 1.0)
|
||||||
|
|
||||||
|
def test_multiple_compile_same_capture(self):
|
||||||
|
def fun(do_compile):
|
||||||
|
t = mx.ones((10,))
|
||||||
|
u = (1.0 - t) * 0.0 + t * 3.0
|
||||||
|
|
||||||
|
o = mx.ones((6,))
|
||||||
|
b = o[:, None] * u
|
||||||
|
|
||||||
|
c = b * mx.ones_like(u)
|
||||||
|
|
||||||
|
a = mx.ones((6,))
|
||||||
|
if do_compile:
|
||||||
|
d = mx.compile(lambda x: x @ b)(a)
|
||||||
|
e = mx.compile(lambda x: x @ c.T)(d)
|
||||||
|
else:
|
||||||
|
d = a @ b
|
||||||
|
e = d @ c.T
|
||||||
|
return e
|
||||||
|
|
||||||
|
out = fun(True)
|
||||||
|
mx.eval(out)
|
||||||
|
expected = fun(False)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user