mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix stub generation, change graph exporting for arrows to go to outputs (#455)
This commit is contained in:
parent
6e81c3e164
commit
41cc7bdfdb
@ -15,8 +15,8 @@ namespace mlx::core {
|
|||||||
struct NodeNamer {
|
struct NodeNamer {
|
||||||
std::unordered_map<std::uintptr_t, std::string> names;
|
std::unordered_map<std::uintptr_t, std::string> names;
|
||||||
|
|
||||||
std::string get_name(uintptr_t id) {
|
std::string get_name(const array& x) {
|
||||||
auto it = names.find(id);
|
auto it = names.find(x.id());
|
||||||
if (it == names.end()) {
|
if (it == names.end()) {
|
||||||
// Get the next name in the sequence
|
// Get the next name in the sequence
|
||||||
// [A, B, ..., Z, AA, AB, ...]
|
// [A, B, ..., Z, AA, AB, ...]
|
||||||
@ -27,15 +27,11 @@ struct NodeNamer {
|
|||||||
var_num = (var_num - 1) / 26;
|
var_num = (var_num - 1) / 26;
|
||||||
}
|
}
|
||||||
std::string name(letters.rbegin(), letters.rend());
|
std::string name(letters.rbegin(), letters.rend());
|
||||||
names.insert({id, name});
|
names.insert({x.id(), name});
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_name(const array& x) {
|
|
||||||
return get_name(x.id());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void depth_first_traversal(
|
void depth_first_traversal(
|
||||||
@ -124,15 +120,14 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
|||||||
// Node for primitive
|
// Node for primitive
|
||||||
if (x.has_primitive()) {
|
if (x.has_primitive()) {
|
||||||
os << "{ ";
|
os << "{ ";
|
||||||
os << namer.get_name(x.primitive_id());
|
os << x.primitive_id();
|
||||||
os << " [label =\"";
|
os << " [label =\"";
|
||||||
x.primitive().print(os);
|
x.primitive().print(os);
|
||||||
os << "\", shape=rectangle]";
|
os << "\", shape=rectangle]";
|
||||||
os << "; }" << std::endl;
|
os << "; }" << std::endl;
|
||||||
// Arrows to primitive's inputs
|
// Arrows to primitive's inputs
|
||||||
for (auto& a : x.inputs()) {
|
for (auto& a : x.inputs()) {
|
||||||
os << namer.get_name(x.primitive_id()) << " -> "
|
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl;
|
||||||
<< namer.get_name(a) << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,8 +140,7 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
|||||||
os << namer.get_name(a);
|
os << namer.get_name(a);
|
||||||
os << "; }" << std::endl;
|
os << "; }" << std::endl;
|
||||||
if (x.has_primitive()) {
|
if (x.has_primitive()) {
|
||||||
os << namer.get_name(a) << " -> "
|
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl;
|
||||||
<< namer.get_name(x.primitive_id()) << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -532,25 +532,12 @@ void init_array(py::module_& m) {
|
|||||||
m.attr("bfloat16") = py::cast(bfloat16);
|
m.attr("bfloat16") = py::cast(bfloat16);
|
||||||
m.attr("complex64") = py::cast(complex64);
|
m.attr("complex64") = py::cast(complex64);
|
||||||
|
|
||||||
py::class_<ArrayAt>(
|
auto array_at_class = py::class_<ArrayAt>(
|
||||||
m,
|
m,
|
||||||
"_ArrayAt",
|
"_ArrayAt",
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A helper object to apply updates at specific indices.
|
A helper object to apply updates at specific indices.
|
||||||
)pbdoc")
|
)pbdoc");
|
||||||
.def(
|
|
||||||
py::init([](const array& x) { return ArrayAt(x); }),
|
|
||||||
"x"_a,
|
|
||||||
R"pbdoc(
|
|
||||||
__init__(self, x: array)
|
|
||||||
)pbdoc")
|
|
||||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
|
||||||
.def("add", &ArrayAt::add, "value"_a)
|
|
||||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
|
||||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
|
||||||
.def("divide", &ArrayAt::divide, "value"_a)
|
|
||||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
|
||||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
|
||||||
|
|
||||||
auto array_class = py::class_<array>(
|
auto array_class = py::class_<array>(
|
||||||
m,
|
m,
|
||||||
@ -573,6 +560,21 @@ void init_array(py::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array_at_class
|
||||||
|
.def(
|
||||||
|
py::init([](const array& x) { return ArrayAt(x); }),
|
||||||
|
"x"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
__init__(self, x: array)
|
||||||
|
)pbdoc")
|
||||||
|
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||||
|
.def("add", &ArrayAt::add, "value"_a)
|
||||||
|
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||||
|
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||||
|
.def("divide", &ArrayAt::divide, "value"_a)
|
||||||
|
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||||
|
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||||
|
|
||||||
array_class
|
array_class
|
||||||
.def_buffer([](array& a) {
|
.def_buffer([](array& a) {
|
||||||
// Eval if not already evaled
|
// Eval if not already evaled
|
||||||
@ -680,17 +682,17 @@ void init_array(py::module_& m) {
|
|||||||
|
|
||||||
* - array.at syntax
|
* - array.at syntax
|
||||||
- In-place syntax
|
- In-place syntax
|
||||||
* - ``x = x.at[idx].add(y)``
|
* - ``x = x.at[idx].add(y)``
|
||||||
- ``x[idx] += y``
|
- ``x[idx] += y``
|
||||||
* - ``x = x.at[idx].subtract(y)``
|
* - ``x = x.at[idx].subtract(y)``
|
||||||
- ``x[idx] -= y``
|
- ``x[idx] -= y``
|
||||||
* - ``x = x.at[idx].multiply(y)``
|
* - ``x = x.at[idx].multiply(y)``
|
||||||
- ``x[idx] *= y``
|
- ``x[idx] *= y``
|
||||||
* - ``x = x.at[idx].divide(y)``
|
* - ``x = x.at[idx].divide(y)``
|
||||||
- ``x[idx] /= y``
|
- ``x[idx] /= y``
|
||||||
* - ``x = x.at[idx].maximum(y)``
|
* - ``x = x.at[idx].maximum(y)``
|
||||||
- ``x[idx] = mx.maximum(x[idx], y)``
|
- ``x[idx] = mx.maximum(x[idx], y)``
|
||||||
* - ``x = x.at[idx].minimum(y)``
|
* - ``x = x.at[idx].minimum(y)``
|
||||||
- ``x[idx] = mx.minimum(x[idx], y)``
|
- ``x[idx] = mx.minimum(x[idx], y)``
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(
|
.def(
|
||||||
|
@ -12,6 +12,7 @@ using namespace py::literals;
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
void init_device(py::module_& m) {
|
void init_device(py::module_& m) {
|
||||||
|
auto device_class = py::class_<Device>(m, "Device");
|
||||||
py::enum_<Device::DeviceType>(m, "DeviceType")
|
py::enum_<Device::DeviceType>(m, "DeviceType")
|
||||||
.value("cpu", Device::DeviceType::cpu)
|
.value("cpu", Device::DeviceType::cpu)
|
||||||
.value("gpu", Device::DeviceType::gpu)
|
.value("gpu", Device::DeviceType::gpu)
|
||||||
@ -23,8 +24,7 @@ void init_device(py::module_& m) {
|
|||||||
},
|
},
|
||||||
py::prepend());
|
py::prepend());
|
||||||
|
|
||||||
py::class_<Device>(m, "Device")
|
device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||||
.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
|
||||||
.def_readonly("type", &Device::type)
|
.def_readonly("type", &Device::type)
|
||||||
.def(
|
.def(
|
||||||
"__repr__",
|
"__repr__",
|
||||||
|
@ -438,6 +438,9 @@ auto py_vmap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void init_transforms(py::module_& m) {
|
void init_transforms(py::module_& m) {
|
||||||
|
py::options options;
|
||||||
|
options.disable_function_signatures();
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"eval",
|
"eval",
|
||||||
[](const py::args& args) {
|
[](const py::args& args) {
|
||||||
@ -445,6 +448,8 @@ void init_transforms(py::module_& m) {
|
|||||||
eval(arrays);
|
eval(arrays);
|
||||||
},
|
},
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
eval(*args) -> None
|
||||||
|
|
||||||
Evaluate an :class:`array` or tree of :class:`array`.
|
Evaluate an :class:`array` or tree of :class:`array`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -476,6 +481,9 @@ void init_transforms(py::module_& m) {
|
|||||||
"primals"_a,
|
"primals"_a,
|
||||||
"tangents"_a,
|
"tangents"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
|
||||||
|
|
||||||
|
|
||||||
Compute the Jacobian-vector product.
|
Compute the Jacobian-vector product.
|
||||||
|
|
||||||
This computes the product of the Jacobian of a function ``fun`` evaluated
|
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||||
@ -517,6 +525,8 @@ void init_transforms(py::module_& m) {
|
|||||||
"primals"_a,
|
"primals"_a,
|
||||||
"cotangents"_a,
|
"cotangents"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
|
||||||
|
|
||||||
Compute the vector-Jacobian product.
|
Compute the vector-Jacobian product.
|
||||||
|
|
||||||
Computes the product of the ``cotangents`` with the Jacobian of a
|
Computes the product of the ``cotangents`` with the Jacobian of a
|
||||||
@ -549,6 +559,8 @@ void init_transforms(py::module_& m) {
|
|||||||
"argnums"_a = std::nullopt,
|
"argnums"_a = std::nullopt,
|
||||||
"argnames"_a = std::vector<std::string>{},
|
"argnames"_a = std::vector<std::string>{},
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||||
|
|
||||||
Returns a function which computes the value and gradient of ``fun``.
|
Returns a function which computes the value and gradient of ``fun``.
|
||||||
|
|
||||||
The function passed to :func:`value_and_grad` should return either
|
The function passed to :func:`value_and_grad` should return either
|
||||||
@ -615,6 +627,8 @@ void init_transforms(py::module_& m) {
|
|||||||
"argnums"_a = std::nullopt,
|
"argnums"_a = std::nullopt,
|
||||||
"argnames"_a = std::vector<std::string>{},
|
"argnames"_a = std::vector<std::string>{},
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||||
|
|
||||||
Returns a function which computes the gradient of ``fun``.
|
Returns a function which computes the gradient of ``fun``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -645,6 +659,8 @@ void init_transforms(py::module_& m) {
|
|||||||
"in_axes"_a = 0,
|
"in_axes"_a = 0,
|
||||||
"out_axes"_a = 0,
|
"out_axes"_a = 0,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
|
||||||
|
|
||||||
Returns a vectorized version of ``fun``.
|
Returns a vectorized version of ``fun``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -670,6 +686,8 @@ void init_transforms(py::module_& m) {
|
|||||||
simplify(arrays);
|
simplify(arrays);
|
||||||
},
|
},
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
simplify(*args) -> None
|
||||||
|
|
||||||
Simplify the graph that computes the arrays.
|
Simplify the graph that computes the arrays.
|
||||||
|
|
||||||
Run a few fast graph simplification operations to reuse computation and
|
Run a few fast graph simplification operations to reuse computation and
|
||||||
|
12
setup.py
12
setup.py
@ -135,18 +135,6 @@ class GenerateStubs(Command):
|
|||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"])
|
subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"])
|
||||||
# Note, sed inplace on macos requires a backup prefix, delete the file after its generated
|
|
||||||
# this sed is needed to replace references from py::cpp_function to a generic Callable
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
"sed",
|
|
||||||
"-i",
|
|
||||||
"''",
|
|
||||||
"s/cpp_function/typing.Callable/g",
|
|
||||||
"python/mlx/core/__init__.pyi",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
subprocess.run(["rm", "python/mlx/core/__init__.pyi''"])
|
|
||||||
|
|
||||||
|
|
||||||
# Read the content of README.md
|
# Read the content of README.md
|
||||||
|
Loading…
Reference in New Issue
Block a user