mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -39,34 +39,33 @@ py::list to_list(array& a, size_t index, int dim) {
|
||||
}
|
||||
|
||||
auto to_scalar(array& a) {
|
||||
bool retain_graph = a.is_tracer();
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return py::cast(a.item<bool>(retain_graph));
|
||||
return py::cast(a.item<bool>());
|
||||
case uint8:
|
||||
return py::cast(a.item<uint8_t>(retain_graph));
|
||||
return py::cast(a.item<uint8_t>());
|
||||
case uint16:
|
||||
return py::cast(a.item<uint16_t>(retain_graph));
|
||||
return py::cast(a.item<uint16_t>());
|
||||
case uint32:
|
||||
return py::cast(a.item<uint32_t>(retain_graph));
|
||||
return py::cast(a.item<uint32_t>());
|
||||
case uint64:
|
||||
return py::cast(a.item<uint64_t>(retain_graph));
|
||||
return py::cast(a.item<uint64_t>());
|
||||
case int8:
|
||||
return py::cast(a.item<int8_t>(retain_graph));
|
||||
return py::cast(a.item<int8_t>());
|
||||
case int16:
|
||||
return py::cast(a.item<int16_t>(retain_graph));
|
||||
return py::cast(a.item<int16_t>());
|
||||
case int32:
|
||||
return py::cast(a.item<int32_t>(retain_graph));
|
||||
return py::cast(a.item<int32_t>());
|
||||
case int64:
|
||||
return py::cast(a.item<int64_t>(retain_graph));
|
||||
return py::cast(a.item<int64_t>());
|
||||
case float16:
|
||||
return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
|
||||
return py::cast(static_cast<float>(a.item<float16_t>()));
|
||||
case float32:
|
||||
return py::cast(a.item<float>(retain_graph));
|
||||
return py::cast(a.item<float>());
|
||||
case bfloat16:
|
||||
return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
|
||||
return py::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||
case complex64:
|
||||
return py::cast(a.item<std::complex<float>>(retain_graph));
|
||||
return py::cast(a.item<std::complex<float>>());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +73,7 @@ py::object tolist(array& a) {
|
||||
if (a.ndim() == 0) {
|
||||
return to_scalar(a);
|
||||
}
|
||||
a.eval(a.is_tracer());
|
||||
a.eval();
|
||||
py::object pl;
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
@@ -527,7 +526,7 @@ void init_array(py::module_& m) {
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
eval({a}, a.is_tracer());
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
@@ -751,7 +750,7 @@ void init_array(py::module_& m) {
|
||||
"__repr__",
|
||||
[](array& a) {
|
||||
if (!a.is_evaled()) {
|
||||
a.eval(a.is_tracer());
|
||||
a.eval();
|
||||
}
|
||||
std::ostringstream os;
|
||||
os << a;
|
||||
|
@@ -345,19 +345,15 @@ class PyFileWriter : public io::Writer {
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(
|
||||
py::object file,
|
||||
array a,
|
||||
std::optional<bool> retain_graph_) {
|
||||
bool retain_graph = retain_graph_.value_or(a.is_tracer());
|
||||
void mlx_save_helper(py::object file, array a) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a, retain_graph);
|
||||
save(py::cast<std::string>(file), a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a, retain_graph);
|
||||
save(writer, a);
|
||||
}
|
||||
|
||||
return;
|
||||
@@ -414,26 +410,23 @@ void mlx_savez_helper(
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save(writer, a, /*retain_graph=*/a.is_tracer());
|
||||
save(writer, a);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<bool> retain_graph) {
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
save_safetensors(writer, arrays_map, retain_graph);
|
||||
save_safetensors(writer, arrays_map);
|
||||
}
|
||||
|
||||
return;
|
||||
|
@@ -17,19 +17,13 @@ using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_helper(
|
||||
py::object file,
|
||||
array a,
|
||||
std::optional<bool> retain_graph = std::nullopt);
|
||||
void mlx_save_helper(py::object file, array a);
|
||||
void mlx_savez_helper(
|
||||
py::object file,
|
||||
py::args args,
|
||||
|
@@ -2902,20 +2902,14 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_helper,
|
||||
"file"_a,
|
||||
"arr"_a,
|
||||
py::pos_only(),
|
||||
"retain_graph"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
save(file: str, arr: array, / , retain_graph: Optional[bool] = None)
|
||||
save(file: str, arr: array)
|
||||
|
||||
Save the array to a binary file in ``.npy`` format.
|
||||
|
||||
Args:
|
||||
file (str): File to which the array is saved
|
||||
arr (array): Array to be saved.
|
||||
retain_graph (bool, optional): Whether or not to retain the graph
|
||||
during array evaluation. If left unspecified the graph is retained
|
||||
only if saving is done in a function transformation. Default: ``None``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"savez",
|
||||
@@ -2999,11 +2993,8 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_safetensor_helper,
|
||||
"file"_a,
|
||||
"arrays"_a,
|
||||
py::pos_only(),
|
||||
"retain_graph"_a = std::nullopt,
|
||||
py::kw_only(),
|
||||
R"pbdoc(
|
||||
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
|
||||
save_safetensors(file: str, arrays: Dict[str, array])
|
||||
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
@@ -3012,9 +3003,6 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved>
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||
retain_graph (bool, optional): Whether or not to retain the graph
|
||||
during array evaluation. If left unspecified the graph is retained
|
||||
only if saving is done in a function transformation. Default: ``None``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"where",
|
||||
|
@@ -440,11 +440,10 @@ auto py_vmap(
|
||||
void init_transforms(py::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args, bool retain_graph) {
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
eval(arrays, retain_graph);
|
||||
eval(arrays);
|
||||
},
|
||||
"retain_graph"_a = false,
|
||||
R"pbdoc(
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
@@ -453,9 +452,6 @@ void init_transforms(py::module_& m) {
|
||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
||||
an :class:`array`.
|
||||
retain_graph (bool): Indicate that the graph structure should be
|
||||
preserved. This option is intended to enable function transforms
|
||||
which contain control flow based on the value of an array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
|
@@ -259,6 +259,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))
|
||||
|
||||
def test_update_state(self):
|
||||
y = mx.array([1.0])
|
||||
state = mx.zeros((2,))
|
||||
|
||||
def fn(y, x):
|
||||
nonlocal state
|
||||
x = y * x
|
||||
state = state + x
|
||||
return x.sum()
|
||||
|
||||
x = mx.ones((2,))
|
||||
mx.grad(fn)(y, x)
|
||||
mx.eval(state)
|
||||
self.assertTrue(mx.allclose(state, mx.ones((2,))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -15,18 +15,13 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
|
||||
|
||||
def test_retain_graph(self):
|
||||
def fun(x, retain_graph):
|
||||
def fun(x):
|
||||
y = 3 * x
|
||||
mx.eval(y, retain_graph=retain_graph)
|
||||
mx.eval(y)
|
||||
return 2 * y
|
||||
|
||||
dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
|
||||
dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dfun_dx_1(mx.array(1.0))
|
||||
|
||||
y = dfun_dx_2(mx.array(1.0))
|
||||
dfun_dx = mx.grad(fun)
|
||||
y = dfun_dx(mx.array(1.0))
|
||||
self.assertEqual(y.item(), 6.0)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user