mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Optionally specify names for arrays when exporting (#1749)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							058d6ce683
						
					
				
				
					commit
					25b3a3e541
				
			| @@ -241,14 +241,20 @@ void init_export(nb::module_& m) { | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "export_to_dot", | ||||
|       [](nb::object file, const nb::args& args) { | ||||
|         std::vector<mx::array> arrays = tree_flatten(args); | ||||
|       [](nb::object file, const nb::args& args, const nb::kwargs& kwargs) { | ||||
|         std::vector<mx::array> arrays = | ||||
|             tree_flatten(nb::make_tuple(args, kwargs)); | ||||
|         mx::NodeNamer namer; | ||||
|         for (const auto& n : kwargs) { | ||||
|           namer.set_name( | ||||
|               nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first)); | ||||
|         } | ||||
|         if (nb::isinstance<nb::str>(file)) { | ||||
|           std::ofstream out(nb::cast<std::string>(file)); | ||||
|           mx::export_to_dot(out, arrays); | ||||
|           mx::export_to_dot(out, std::move(namer), arrays); | ||||
|         } else if (nb::hasattr(file, "write")) { | ||||
|           std::ostringstream out; | ||||
|           mx::export_to_dot(out, arrays); | ||||
|           mx::export_to_dot(out, std::move(namer), arrays); | ||||
|           auto write = file.attr("write"); | ||||
|           write(out.str()); | ||||
|         } else { | ||||
| @@ -259,19 +265,25 @@ void init_export(nb::module_& m) { | ||||
|       }, | ||||
|       "file"_a, | ||||
|       "args"_a, | ||||
|       "kwargs"_a, | ||||
|       R"pbdoc( | ||||
|         Export a graph to DOT format for visualization. | ||||
|  | ||||
|         A variable number of output arrays can be provided for exporting | ||||
|         The graph exported will recursively include all enevaluated inputs of | ||||
|         The graph exported will recursively include all unevaluated inputs of | ||||
|         the provided outputs. | ||||
|  | ||||
|         Args: | ||||
|             file (str): The file path to export to. | ||||
|             *args (array): The output arrays. | ||||
|             **kwargs (dict[str, array]): Provide some names for arrays in the | ||||
|               graph to make the result easier to parse. | ||||
|  | ||||
|         Example: | ||||
|           >>> a = mx.array(1) + mx.array(2) | ||||
|           >>> mx.export_to_dot("graph.dot", a) | ||||
|           >>> x = mx.array(1) | ||||
|           >>> y = mx.array(2) | ||||
|           >>> mx.export_to_dot("graph.dot", x + y, x=x, y=y) | ||||
|       )pbdoc"); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user