| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  | // Copyright © 2023-2024 Apple Inc.
 | 
					
						
							| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  | #include <nanobind/nanobind.h>
 | 
					
						
							|  |  |  | #include <nanobind/stl/optional.h>
 | 
					
						
							|  |  |  | #include <nanobind/stl/string.h>
 | 
					
						
							|  |  |  | #include <nanobind/stl/unordered_map.h>
 | 
					
						
							|  |  |  | #include <nanobind/stl/variant.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-27 05:06:55 -05:00
										 |  |  | #include <optional>
 | 
					
						
							|  |  |  | #include <string>
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | #include <unordered_map>
 | 
					
						
							|  |  |  | #include <variant>
 | 
					
						
							| 
									
										
										
										
											2024-01-19 23:06:05 +01:00
										 |  |  | #include "mlx/io.h"
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 08:45:39 +09:00
										 |  |  | namespace mx = mlx::core; | 
					
						
							| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  | namespace nb = nanobind; | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-19 23:06:05 +01:00
										 |  |  | using LoadOutputTypes = std::variant< | 
					
						
							| 
									
										
										
										
											2024-12-12 08:45:39 +09:00
										 |  |  |     mx::array, | 
					
						
							|  |  |  |     std::unordered_map<std::string, mx::array>, | 
					
						
							|  |  |  |     mx::SafetensorsLoad, | 
					
						
							|  |  |  |     mx::GGUFLoad>; | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 08:45:39 +09:00
										 |  |  | mx::SafetensorsLoad mlx_load_safetensor_helper( | 
					
						
							|  |  |  |     nb::object file, | 
					
						
							|  |  |  |     mx::StreamOrDevice s); | 
					
						
							| 
									
										
										
										
											2024-02-08 22:33:15 -05:00
										 |  |  | void mlx_save_safetensor_helper( | 
					
						
							| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  |     nb::object file, | 
					
						
							|  |  |  |     nb::dict d, | 
					
						
							|  |  |  |     std::optional<nb::dict> m); | 
					
						
							| 
									
										
										
										
											2024-02-08 22:33:15 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-12 08:45:39 +09:00
										 |  |  | mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s); | 
					
						
							| 
									
										
										
										
											2023-12-27 05:06:55 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-19 23:06:05 +01:00
										 |  |  | void mlx_save_gguf_helper( | 
					
						
							| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  |     nb::object file, | 
					
						
							|  |  |  |     nb::dict d, | 
					
						
							|  |  |  |     std::optional<nb::dict> m); | 
					
						
							| 
									
										
										
										
											2024-01-10 16:22:48 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-19 23:06:05 +01:00
										 |  |  | LoadOutputTypes mlx_load_helper( | 
					
						
							| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  |     nb::object file, | 
					
						
							| 
									
										
										
										
											2023-12-27 05:06:55 -05:00
										 |  |  |     std::optional<std::string> format, | 
					
						
							| 
									
										
										
										
											2024-01-19 23:06:05 +01:00
										 |  |  |     bool return_metadata, | 
					
						
							| 
									
										
										
										
											2024-12-12 08:45:39 +09:00
										 |  |  |     mx::StreamOrDevice s); | 
					
						
							|  |  |  | void mlx_save_helper(nb::object file, mx::array a); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | void mlx_savez_helper( | 
					
						
							| 
									
										
										
										
											2024-03-18 20:12:25 -07:00
										 |  |  |     nb::object file, | 
					
						
							|  |  |  |     nb::args args, | 
					
						
							|  |  |  |     const nb::kwargs& kwargs, | 
					
						
							| 
									
										
										
										
											2023-12-21 14:08:24 -08:00
										 |  |  |     bool compressed = false); |