mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix stub generation, change graph exporting for arrows to go to outputs (#455)
This commit is contained in:
parent
6e81c3e164
commit
41cc7bdfdb
@ -15,8 +15,8 @@ namespace mlx::core {
|
||||
struct NodeNamer {
|
||||
std::unordered_map<std::uintptr_t, std::string> names;
|
||||
|
||||
std::string get_name(uintptr_t id) {
|
||||
auto it = names.find(id);
|
||||
std::string get_name(const array& x) {
|
||||
auto it = names.find(x.id());
|
||||
if (it == names.end()) {
|
||||
// Get the next name in the sequence
|
||||
// [A, B, ..., Z, AA, AB, ...]
|
||||
@ -27,15 +27,11 @@ struct NodeNamer {
|
||||
var_num = (var_num - 1) / 26;
|
||||
}
|
||||
std::string name(letters.rbegin(), letters.rend());
|
||||
names.insert({id, name});
|
||||
names.insert({x.id(), name});
|
||||
return name;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::string get_name(const array& x) {
|
||||
return get_name(x.id());
|
||||
}
|
||||
};
|
||||
|
||||
void depth_first_traversal(
|
||||
@ -124,15 +120,14 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
||||
// Node for primitive
|
||||
if (x.has_primitive()) {
|
||||
os << "{ ";
|
||||
os << namer.get_name(x.primitive_id());
|
||||
os << x.primitive_id();
|
||||
os << " [label =\"";
|
||||
x.primitive().print(os);
|
||||
os << "\", shape=rectangle]";
|
||||
os << "; }" << std::endl;
|
||||
// Arrows to primitive's inputs
|
||||
for (auto& a : x.inputs()) {
|
||||
os << namer.get_name(x.primitive_id()) << " -> "
|
||||
<< namer.get_name(a) << std::endl;
|
||||
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,8 +140,7 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
||||
os << namer.get_name(a);
|
||||
os << "; }" << std::endl;
|
||||
if (x.has_primitive()) {
|
||||
os << namer.get_name(a) << " -> "
|
||||
<< namer.get_name(x.primitive_id()) << std::endl;
|
||||
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -532,25 +532,12 @@ void init_array(py::module_& m) {
|
||||
m.attr("bfloat16") = py::cast(bfloat16);
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
py::class_<ArrayAt>(
|
||||
auto array_at_class = py::class_<ArrayAt>(
|
||||
m,
|
||||
"_ArrayAt",
|
||||
R"pbdoc(
|
||||
A helper object to apply updates at specific indices.
|
||||
)pbdoc")
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayAt(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||
.def("add", &ArrayAt::add, "value"_a)
|
||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||
.def("divide", &ArrayAt::divide, "value"_a)
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
)pbdoc");
|
||||
|
||||
auto array_class = py::class_<array>(
|
||||
m,
|
||||
@ -573,6 +560,21 @@ void init_array(py::module_& m) {
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
array_at_class
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayAt(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||
.def("add", &ArrayAt::add, "value"_a)
|
||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||
.def("divide", &ArrayAt::divide, "value"_a)
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
|
||||
array_class
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
|
@ -12,6 +12,7 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_device(py::module_& m) {
|
||||
auto device_class = py::class_<Device>(m, "Device");
|
||||
py::enum_<Device::DeviceType>(m, "DeviceType")
|
||||
.value("cpu", Device::DeviceType::cpu)
|
||||
.value("gpu", Device::DeviceType::gpu)
|
||||
@ -23,8 +24,7 @@ void init_device(py::module_& m) {
|
||||
},
|
||||
py::prepend());
|
||||
|
||||
py::class_<Device>(m, "Device")
|
||||
.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
.def_readonly("type", &Device::type)
|
||||
.def(
|
||||
"__repr__",
|
||||
|
@ -438,6 +438,9 @@ auto py_vmap(
|
||||
}
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args) {
|
||||
@ -445,6 +448,8 @@ void init_transforms(py::module_& m) {
|
||||
eval(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
eval(*args) -> None
|
||||
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
Args:
|
||||
@ -476,6 +481,9 @@ void init_transforms(py::module_& m) {
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
R"pbdoc(
|
||||
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||
@ -517,6 +525,8 @@ void init_transforms(py::module_& m) {
|
||||
"primals"_a,
|
||||
"cotangents"_a,
|
||||
R"pbdoc(
|
||||
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
Compute the vector-Jacobian product.
|
||||
|
||||
Computes the product of the ``cotangents`` with the Jacobian of a
|
||||
@ -549,6 +559,8 @@ void init_transforms(py::module_& m) {
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
R"pbdoc(
|
||||
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||
|
||||
Returns a function which computes the value and gradient of ``fun``.
|
||||
|
||||
The function passed to :func:`value_and_grad` should return either
|
||||
@ -615,6 +627,8 @@ void init_transforms(py::module_& m) {
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
R"pbdoc(
|
||||
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||
|
||||
Returns a function which computes the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
@ -645,6 +659,8 @@ void init_transforms(py::module_& m) {
|
||||
"in_axes"_a = 0,
|
||||
"out_axes"_a = 0,
|
||||
R"pbdoc(
|
||||
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
|
||||
|
||||
Returns a vectorized version of ``fun``.
|
||||
|
||||
Args:
|
||||
@ -670,6 +686,8 @@ void init_transforms(py::module_& m) {
|
||||
simplify(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
simplify(*args) -> None
|
||||
|
||||
Simplify the graph that computes the arrays.
|
||||
|
||||
Run a few fast graph simplification operations to reuse computation and
|
||||
|
12
setup.py
12
setup.py
@ -135,18 +135,6 @@ class GenerateStubs(Command):
|
||||
|
||||
def run(self) -> None:
|
||||
subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"])
|
||||
# Note, sed inplace on macos requires a backup prefix, delete the file after its generated
|
||||
# this sed is needed to replace references from py::cpp_function to a generic Callable
|
||||
subprocess.run(
|
||||
[
|
||||
"sed",
|
||||
"-i",
|
||||
"''",
|
||||
"s/cpp_function/typing.Callable/g",
|
||||
"python/mlx/core/__init__.pyi",
|
||||
]
|
||||
)
|
||||
subprocess.run(["rm", "python/mlx/core/__init__.pyi''"])
|
||||
|
||||
|
||||
# Read the content of README.md
|
||||
|
Loading…
Reference in New Issue
Block a user