mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Removes the retain_graph flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:
committed by
GitHub
parent
449b43762e
commit
a611b0bc82
@@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) {
|
||||
}
|
||||
|
||||
auto to_scalar(array& a) {
|
||||
bool retain_graph = a.is_tracer();
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return py::cast(a.item<bool>(retain_graph));
|
||||
return py::cast(a.item<bool>());
|
||||
case uint8:
|
||||
return py::cast(a.item<uint8_t>(retain_graph));
|
||||
return py::cast(a.item<uint8_t>());
|
||||
case uint16:
|
||||
return py::cast(a.item<uint16_t>(retain_graph));
|
||||
return py::cast(a.item<uint16_t>());
|
||||
case uint32:
|
||||
return py::cast(a.item<uint32_t>(retain_graph));
|
||||
return py::cast(a.item<uint32_t>());
|
||||
case uint64:
|
||||
return py::cast(a.item<uint64_t>(retain_graph));
|
||||
return py::cast(a.item<uint64_t>());
|
||||
case int8:
|
||||
return py::cast(a.item<int8_t>(retain_graph));
|
||||
return py::cast(a.item<int8_t>());
|
||||
case int16:
|
||||
return py::cast(a.item<int16_t>(retain_graph));
|
||||
return py::cast(a.item<int16_t>());
|
||||
case int32:
|
||||
return py::cast(a.item<int32_t>(retain_graph));
|
||||
return py::cast(a.item<int32_t>());
|
||||
case int64:
|
||||
return py::cast(a.item<int64_t>(retain_graph));
|
||||
return py::cast(a.item<int64_t>());
|
||||
case float16:
|
||||
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
|
||||
return py::cast(static_cast<float>(a.item<float16_t>()));
|
||||
case float32:
|
||||
return py::cast(a.item<float>(retain_graph));
|
||||
return py::cast(a.item<float>());
|
||||
case bfloat16:
|
||||
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
|
||||
return py::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||
case complex64:
|
||||
return py::cast(a.item<std::complex<float>>(retain_graph));
|
||||
return py::cast(a.item<std::complex<float>>());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +73,7 @@ py::object tolist(array& a) {
|
||||
if (a.ndim() == 0) {
|
||||
return to_scalar(a);
|
||||
}
|
||||
a.eval(a.is_tracer());
|
||||
a.eval();
|
||||
py::object pl;
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
@@ -527,7 +526,7 @@ void init_array(py::module_& m) {
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
eval({a}, a.is_tracer());
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
@@ -751,7 +750,7 @@ void init_array(py::module_& m) {
|
||||
"__repr__",
|
||||
[](array& a) {
|
||||
if (!a.is_evaled()) {
|
||||
a.eval(a.is_tracer());
|
||||
a.eval();
|
||||
}
|
||||
std::ostringstream os;
|
||||
os << a;
|
||||
|
||||
Reference in New Issue
Block a user