Compile front-end (#476)

* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
This commit is contained in:
Awni Hannun
2024-01-26 13:45:30 -08:00
committed by GitHub
parent 874b739f3c
commit 8fa6b322b9
13 changed files with 1029 additions and 297 deletions

View File

@@ -47,6 +47,17 @@ array::array(
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
std::move(primitive),
std::move(inputs))) {}
std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
@@ -158,7 +169,22 @@ array::ArrayDesc::ArrayDesc(
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(shape);
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}