Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) {
}
auto to_scalar(array& a) {
bool retain_graph = a.is_tracer();
switch (a.dtype()) {
case bool_:
return py::cast(a.item<bool>(retain_graph));
return py::cast(a.item<bool>());
case uint8:
return py::cast(a.item<uint8_t>(retain_graph));
return py::cast(a.item<uint8_t>());
case uint16:
return py::cast(a.item<uint16_t>(retain_graph));
return py::cast(a.item<uint16_t>());
case uint32:
return py::cast(a.item<uint32_t>(retain_graph));
return py::cast(a.item<uint32_t>());
case uint64:
return py::cast(a.item<uint64_t>(retain_graph));
return py::cast(a.item<uint64_t>());
case int8:
return py::cast(a.item<int8_t>(retain_graph));
return py::cast(a.item<int8_t>());
case int16:
return py::cast(a.item<int16_t>(retain_graph));
return py::cast(a.item<int16_t>());
case int32:
return py::cast(a.item<int32_t>(retain_graph));
return py::cast(a.item<int32_t>());
case int64:
return py::cast(a.item<int64_t>(retain_graph));
return py::cast(a.item<int64_t>());
case float16:
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
return py::cast(static_cast<float>(a.item<float16_t>()));
case float32:
return py::cast(a.item<float>(retain_graph));
return py::cast(a.item<float>());
case bfloat16:
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
return py::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64:
return py::cast(a.item<std::complex<float>>(retain_graph));
return py::cast(a.item<std::complex<float>>());
}
}
@@ -74,7 +73,7 @@ py::object tolist(array& a) {
if (a.ndim() == 0) {
return to_scalar(a);
}
a.eval(a.is_tracer());
a.eval();
py::object pl;
switch (a.dtype()) {
case bool_:
@@ -527,7 +526,7 @@ void init_array(py::module_& m) {
.def_buffer([](array& a) {
// Eval if not already evaled
if (!a.is_evaled()) {
eval({a}, a.is_tracer());
a.eval();
}
return pybind11::buffer_info(
a.data<void>(),
@@ -751,7 +750,7 @@ void init_array(py::module_& m) {
"__repr__",
[](array& a) {
if (!a.is_evaled()) {
a.eval(a.is_tracer());
a.eval();
}
std::ostringstream os;
os << a;