mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
							
								
								
									
										6
									
								
								mlx/io/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								mlx/io/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| target_sources( | ||||
|   mlx | ||||
|   PRIVATE | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp | ||||
| ) | ||||
							
								
								
									
										242
									
								
								mlx/io/load.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								mlx/io/load.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,242 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <cstring> | ||||
| #include <fstream> | ||||
| #include <limits> | ||||
| #include <sstream> | ||||
|  | ||||
| #include "mlx/io/load.h" | ||||
| #include "mlx/ops.h" | ||||
| #include "mlx/primitives.h" | ||||
| #include "mlx/utils.h" | ||||
|  | ||||
| // Adapted from | ||||
| // https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| static constexpr uint8_t MAGIC[] = { | ||||
|     0x93, | ||||
|     0x4e, | ||||
|     0x55, | ||||
|     0x4d, | ||||
|     0x50, | ||||
|     0x59, | ||||
| }; | ||||
|  | ||||
| inline bool is_big_endian_() { | ||||
|   union ByteOrder { | ||||
|     int32_t i; | ||||
|     uint8_t c[4]; | ||||
|   }; | ||||
|   ByteOrder b = {0x01234567}; | ||||
|  | ||||
|   return b.c[0] == 0x01; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| /** Save array to out stream in .npy format */ | ||||
| void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check array | ||||
|  | ||||
|   a.eval(retain_graph); | ||||
|  | ||||
|   if (a.nbytes() == 0) { | ||||
|     throw std::invalid_argument("[save] cannot serialize an empty array"); | ||||
|   } | ||||
|  | ||||
|   if (!a.flags().contiguous) { | ||||
|     throw std::invalid_argument( | ||||
|         "[save] cannot serialize a non-contiguous array"); | ||||
|   } | ||||
|  | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check file | ||||
|   if (!out_stream->good() || !out_stream->is_open()) { | ||||
|     throw std::runtime_error("[save] Failed to open " + out_stream->label()); | ||||
|   } | ||||
|  | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Prepare header | ||||
|   std::ostringstream magic_ver_len; | ||||
|   magic_ver_len.write(reinterpret_cast<const char*>(MAGIC), 6); | ||||
|  | ||||
|   std::string fortran_order = a.flags().col_contiguous ? "True" : "False"; | ||||
|   std::ostringstream header; | ||||
|   header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "'," | ||||
|          << " 'fortran_order': " << fortran_order << "," | ||||
|          << " 'shape': ("; | ||||
|   for (auto i : a.shape()) { | ||||
|     header << i << ", "; | ||||
|   } | ||||
|   header << ")}"; | ||||
|  | ||||
|   size_t header_len = static_cast<size_t>(header.tellp()); | ||||
|   bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max(); | ||||
|  | ||||
|   // Pad out magic + version + header_len + header + \n to be divisible by 16 | ||||
|   size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16; | ||||
|  | ||||
|   header << std::string(padding, ' ') << '\n'; | ||||
|  | ||||
|   if (is_v1) { | ||||
|     magic_ver_len << (char)0x01 << (char)0x00; | ||||
|  | ||||
|     uint16_t v1_header_len = header.tellp(); | ||||
|     const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len); | ||||
|  | ||||
|     if (!is_big_endian_()) { | ||||
|       magic_ver_len.write(len_bytes, 2); | ||||
|     } else { | ||||
|       magic_ver_len.write(len_bytes + 1, 1); | ||||
|       magic_ver_len.write(len_bytes, 1); | ||||
|     } | ||||
|   } else { | ||||
|     magic_ver_len << (char)0x02 << (char)0x00; | ||||
|  | ||||
|     uint32_t v2_header_len = header.tellp(); | ||||
|     const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len); | ||||
|  | ||||
|     if (!is_big_endian_()) { | ||||
|       magic_ver_len.write(len_bytes, 4); | ||||
|     } else { | ||||
|       magic_ver_len.write(len_bytes + 3, 1); | ||||
|       magic_ver_len.write(len_bytes + 2, 1); | ||||
|       magic_ver_len.write(len_bytes + 1, 1); | ||||
|       magic_ver_len.write(len_bytes, 1); | ||||
|     } | ||||
|   } | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Serialize array | ||||
|  | ||||
|   out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length()); | ||||
|   out_stream->write(header.str().c_str(), header.str().length()); | ||||
|   out_stream->write(a.data<char>(), a.nbytes()); | ||||
|  | ||||
|   return; | ||||
| } | ||||
|  | ||||
| /** Save array to file in .npy format */ | ||||
| void save(const std::string& file_, array a, bool retain_graph) { | ||||
|   // Open and check file | ||||
|   std::string file = file_; | ||||
|  | ||||
|   // Add .npy to file name if it is not there | ||||
|   if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy") | ||||
|     file += ".npy"; | ||||
|  | ||||
|   // Serialize array | ||||
|   save(std::make_shared<io::FileWriter>(file), a, retain_graph); | ||||
| } | ||||
|  | ||||
| /** Load array from reader in .npy format */ | ||||
| array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Open and check file | ||||
|   if (!in_stream->good() || !in_stream->is_open()) { | ||||
|     throw std::runtime_error("[load] Failed to open " + in_stream->label()); | ||||
|   } | ||||
|  | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Read header and prepare array details | ||||
|  | ||||
|   // Read and check magic | ||||
|   char read_magic_and_ver[8]; | ||||
|   in_stream->read(read_magic_and_ver, 8); | ||||
|   if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) { | ||||
|     throw std::runtime_error("[load] Invalid header in " + in_stream->label()); | ||||
|   } | ||||
|  | ||||
|   // Read and check version | ||||
|   if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) { | ||||
|     throw std::runtime_error( | ||||
|         "[load] Unsupport npy format version in " + in_stream->label()); | ||||
|   } | ||||
|  | ||||
|   // Read header len and header | ||||
|   int header_len_size = read_magic_and_ver[6] == 1 ? 2 : 4; | ||||
|   size_t header_len; | ||||
|  | ||||
|   if (header_len_size == 2) { | ||||
|     uint16_t v1_header_len; | ||||
|     in_stream->read(reinterpret_cast<char*>(&v1_header_len), header_len_size); | ||||
|     header_len = v1_header_len; | ||||
|   } else { | ||||
|     uint32_t v2_header_len; | ||||
|     in_stream->read(reinterpret_cast<char*>(&v2_header_len), header_len_size); | ||||
|     header_len = v2_header_len; | ||||
|   } | ||||
|  | ||||
|   // Read the header | ||||
|   std::vector<char> buffer(header_len + 1); | ||||
|   in_stream->read(&buffer[0], header_len); | ||||
|   buffer[header_len] = 0; | ||||
|   std::string header(&buffer[0]); | ||||
|  | ||||
|   // Read data type from header | ||||
|   std::string dtype_str = header.substr(11, 3); | ||||
|   bool read_is_big_endian = dtype_str[0] == '>'; | ||||
|   Dtype dtype = dtype_from_array_protocol(dtype_str); | ||||
|  | ||||
|   // Read contiguity order | ||||
|   bool col_contiguous = header[34] == 'T'; | ||||
|  | ||||
|   // Read array shape from header | ||||
|   std::vector<int> shape; | ||||
|  | ||||
|   size_t st = header.find_last_of('(') + 1; | ||||
|   size_t ed = header.find_last_of(')'); | ||||
|   std::string shape_str = header.substr(st, ed - st); | ||||
|  | ||||
|   while (!shape_str.empty()) { | ||||
|     // Read current number and get position of comma | ||||
|     size_t pos; | ||||
|     int dim = std::stoi(shape_str, &pos); | ||||
|     shape.push_back(dim); | ||||
|  | ||||
|     // Skip the comma and space and read the next number | ||||
|     if (pos + 2 <= shape_str.length()) | ||||
|       shape_str = shape_str.substr(pos + 2); | ||||
|     else { | ||||
|       shape_str = shape_str.substr(pos); | ||||
|       if (!shape_str.empty() && shape_str != " " && shape_str != ",") { | ||||
|         throw std::runtime_error( | ||||
|             "[load] Unknown error while parsing header in " + | ||||
|             in_stream->label()); | ||||
|       } | ||||
|       shape_str = ""; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Build primitive | ||||
|  | ||||
|   size_t offset = 8 + header_len_size + header.length(); | ||||
|   bool swap_endianness = read_is_big_endian != is_big_endian_(); | ||||
|  | ||||
|   if (col_contiguous) { | ||||
|     std::reverse(shape.begin(), shape.end()); | ||||
|   } | ||||
|   auto loaded_array = array( | ||||
|       shape, | ||||
|       dtype, | ||||
|       std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness), | ||||
|       std::vector<array>{}); | ||||
|   if (col_contiguous) { | ||||
|     loaded_array = transpose(loaded_array, s); | ||||
|   } | ||||
|  | ||||
|   return loaded_array; | ||||
| } | ||||
|  | ||||
| /** Load array from file in .npy format */ | ||||
| array load(const std::string& file, StreamOrDevice s) { | ||||
|   return load(std::make_shared<io::FileReader>(file), s); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
							
								
								
									
										114
									
								
								mlx/io/load.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								mlx/io/load.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <fstream> | ||||
| #include <istream> | ||||
| #include <memory> | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| namespace io { | ||||
|  | ||||
| class Reader { | ||||
|  public: | ||||
|   virtual bool is_open() const = 0; | ||||
|   virtual bool good() const = 0; | ||||
|   virtual size_t tell() const = 0; | ||||
|   virtual void seek( | ||||
|       int64_t off, | ||||
|       std::ios_base::seekdir way = std::ios_base::beg) = 0; | ||||
|   virtual void read(char* data, size_t n) = 0; | ||||
|   virtual std::string label() const = 0; | ||||
| }; | ||||
|  | ||||
| class Writer { | ||||
|  public: | ||||
|   virtual bool is_open() const = 0; | ||||
|   virtual bool good() const = 0; | ||||
|   virtual size_t tell() const = 0; | ||||
|   virtual void seek( | ||||
|       int64_t off, | ||||
|       std::ios_base::seekdir way = std::ios_base::beg) = 0; | ||||
|   virtual void write(const char* data, size_t n) = 0; | ||||
|   virtual std::string label() const = 0; | ||||
| }; | ||||
|  | ||||
| class FileReader : public Reader { | ||||
|  public: | ||||
|   explicit FileReader(const std::shared_ptr<std::ifstream>& is) | ||||
|       : is_(is), label_("stream") {} | ||||
|   explicit FileReader(const std::string& file_path) | ||||
|       : is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)), | ||||
|         label_(file_path) {} | ||||
|  | ||||
|   bool is_open() const override { | ||||
|     return is_->is_open(); | ||||
|   } | ||||
|  | ||||
|   bool good() const override { | ||||
|     return is_->good(); | ||||
|   } | ||||
|  | ||||
|   size_t tell() const override { | ||||
|     return is_->tellg(); | ||||
|   } | ||||
|  | ||||
|   void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) | ||||
|       override { | ||||
|     is_->seekg(off, way); | ||||
|   } | ||||
|  | ||||
|   void read(char* data, size_t n) override { | ||||
|     is_->read(data, n); | ||||
|   } | ||||
|  | ||||
|   std::string label() const override { | ||||
|     return "file " + label_; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   std::shared_ptr<std::ifstream> is_; | ||||
|   std::string label_; | ||||
| }; | ||||
|  | ||||
| class FileWriter : public Writer { | ||||
|  public: | ||||
|   explicit FileWriter(const std::shared_ptr<std::ofstream>& is) | ||||
|       : os_(is), label_("stream") {} | ||||
|   explicit FileWriter(const std::string& file_path) | ||||
|       : os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)), | ||||
|         label_(file_path) {} | ||||
|  | ||||
|   bool is_open() const override { | ||||
|     return os_->is_open(); | ||||
|   } | ||||
|  | ||||
|   bool good() const override { | ||||
|     return os_->good(); | ||||
|   } | ||||
|  | ||||
|   size_t tell() const override { | ||||
|     return os_->tellp(); | ||||
|   } | ||||
|  | ||||
|   void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) | ||||
|       override { | ||||
|     os_->seekp(off, way); | ||||
|   } | ||||
|  | ||||
|   void write(const char* data, size_t n) override { | ||||
|     os_->write(data, n); | ||||
|   } | ||||
|  | ||||
|   std::string label() const override { | ||||
|     return "file " + label_; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   std::shared_ptr<std::ofstream> os_; | ||||
|   std::string label_; | ||||
| }; | ||||
|  | ||||
| } // namespace io | ||||
| } // namespace mlx::core | ||||
							
								
								
									
										189
									
								
								mlx/io/safetensor.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								mlx/io/safetensor.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,189 @@ | ||||
| #include "mlx/io/safetensor.h" | ||||
|  | ||||
| #include <stack> | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| std::string dtype_to_safetensor_str(Dtype t) { | ||||
|   switch (t) { | ||||
|     case float32: | ||||
|       return ST_F32; | ||||
|     case bfloat16: | ||||
|       return ST_BF16; | ||||
|     case float16: | ||||
|       return ST_F16; | ||||
|     case int64: | ||||
|       return ST_I64; | ||||
|     case int32: | ||||
|       return ST_I32; | ||||
|     case int16: | ||||
|       return ST_I16; | ||||
|     case int8: | ||||
|       return ST_I8; | ||||
|     case uint64: | ||||
|       return ST_U64; | ||||
|     case uint32: | ||||
|       return ST_U32; | ||||
|     case uint16: | ||||
|       return ST_U16; | ||||
|     case uint8: | ||||
|       return ST_U8; | ||||
|     case bool_: | ||||
|       return ST_BOOL; | ||||
|     case complex64: | ||||
|       return ST_C64; | ||||
|   } | ||||
| } | ||||
|  | ||||
| Dtype dtype_from_safetensor_str(std::string str) { | ||||
|   if (str == ST_F32) { | ||||
|     return float32; | ||||
|   } else if (str == ST_F16) { | ||||
|     return float16; | ||||
|   } else if (str == ST_BF16) { | ||||
|     return bfloat16; | ||||
|   } else if (str == ST_I64) { | ||||
|     return int64; | ||||
|   } else if (str == ST_I32) { | ||||
|     return int32; | ||||
|   } else if (str == ST_I16) { | ||||
|     return int16; | ||||
|   } else if (str == ST_I8) { | ||||
|     return int8; | ||||
|   } else if (str == ST_U64) { | ||||
|     return uint64; | ||||
|   } else if (str == ST_U32) { | ||||
|     return uint32; | ||||
|   } else if (str == ST_U16) { | ||||
|     return uint16; | ||||
|   } else if (str == ST_U8) { | ||||
|     return uint8; | ||||
|   } else if (str == ST_BOOL) { | ||||
|     return bool_; | ||||
|   } else if (str == ST_C64) { | ||||
|     return complex64; | ||||
|   } else { | ||||
|     throw std::runtime_error("[safetensor] unsupported dtype " + str); | ||||
|   } | ||||
| } | ||||
|  | ||||
| /** Load array from reader in safetensor format */ | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
|     std::shared_ptr<io::Reader> in_stream, | ||||
|     StreamOrDevice s) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Open and check file | ||||
|   if (!in_stream->good() || !in_stream->is_open()) { | ||||
|     throw std::runtime_error( | ||||
|         "[load_safetensors] Failed to open " + in_stream->label()); | ||||
|   } | ||||
|  | ||||
|   uint64_t jsonHeaderLength = 0; | ||||
|   in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8); | ||||
|   if (jsonHeaderLength <= 0) { | ||||
|     throw std::runtime_error( | ||||
|         "[load_safetensors] Invalid json header length " + in_stream->label()); | ||||
|   } | ||||
|   // Load the json metadata | ||||
|   char rawJson[jsonHeaderLength]; | ||||
|   in_stream->read(rawJson, jsonHeaderLength); | ||||
|   auto metadata = json::parse(rawJson, rawJson + jsonHeaderLength); | ||||
|   // Should always be an object on the top-level | ||||
|   if (!metadata.is_object()) { | ||||
|     throw std::runtime_error( | ||||
|         "[load_safetensors] Invalid json metadata " + in_stream->label()); | ||||
|   } | ||||
|   size_t offset = jsonHeaderLength + 8; | ||||
|   // Load the arrays using metadata | ||||
|   std::unordered_map<std::string, array> res; | ||||
|   for (const auto& item : metadata.items()) { | ||||
|     if (item.key() == "__metadata__") { | ||||
|       // ignore metadata for now | ||||
|       continue; | ||||
|     } | ||||
|     std::string dtype = item.value().at("dtype"); | ||||
|     std::vector<int> shape = item.value().at("shape"); | ||||
|     std::vector<size_t> data_offsets = item.value().at("data_offsets"); | ||||
|     Dtype type = dtype_from_safetensor_str(dtype); | ||||
|     auto loaded_array = array( | ||||
|         shape, | ||||
|         type, | ||||
|         std::make_unique<Load>( | ||||
|             to_stream(s), in_stream, offset + data_offsets.at(0), false), | ||||
|         std::vector<array>{}); | ||||
|     res.insert({item.key(), loaded_array}); | ||||
|   } | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| std::unordered_map<std::string, array> load_safetensors( | ||||
|     const std::string& file, | ||||
|     StreamOrDevice s) { | ||||
|   return load_safetensors(std::make_shared<io::FileReader>(file), s); | ||||
| } | ||||
|  | ||||
| /** Save array to out stream in .npy format */ | ||||
| void save_safetensors( | ||||
|     std::shared_ptr<io::Writer> out_stream, | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::optional<bool> retain_graph_) { | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check file | ||||
|   if (!out_stream->good() || !out_stream->is_open()) { | ||||
|     throw std::runtime_error( | ||||
|         "[save_safetensors] Failed to open " + out_stream->label()); | ||||
|   } | ||||
|  | ||||
|   //////////////////////////////////////////////////////// | ||||
|   // Check array map | ||||
|   json parent; | ||||
|   parent["__metadata__"] = json::object({ | ||||
|       {"format", "mlx"}, | ||||
|   }); | ||||
|   size_t offset = 0; | ||||
|   for (auto& [key, arr] : a) { | ||||
|     arr.eval(retain_graph_.value_or(arr.is_tracer())); | ||||
|     if (arr.nbytes() == 0) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] cannot serialize an empty array key: " + key); | ||||
|     } | ||||
|  | ||||
|     if (!arr.flags().contiguous) { | ||||
|       throw std::invalid_argument( | ||||
|           "[save_safetensors] cannot serialize a non-contiguous array key: " + | ||||
|           key); | ||||
|     } | ||||
|     json child; | ||||
|     child["dtype"] = dtype_to_safetensor_str(arr.dtype()); | ||||
|     child["shape"] = arr.shape(); | ||||
|     child["data_offsets"] = std::vector<size_t>{offset, offset + arr.nbytes()}; | ||||
|     parent[key] = child; | ||||
|     offset += arr.nbytes(); | ||||
|   } | ||||
|  | ||||
|   auto header = parent.dump(); | ||||
|   uint64_t header_len = header.length(); | ||||
|   out_stream->write(reinterpret_cast<char*>(&header_len), 8); | ||||
|   out_stream->write(header.c_str(), header_len); | ||||
|   for (auto& [key, arr] : a) { | ||||
|     out_stream->write(arr.data<char>(), arr.nbytes()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void save_safetensors( | ||||
|     const std::string& file_, | ||||
|     std::unordered_map<std::string, array> a, | ||||
|     std::optional<bool> retain_graph) { | ||||
|   // Open and check file | ||||
|   std::string file = file_; | ||||
|  | ||||
|   // Add .safetensors to file name if it is not there | ||||
|   if (file.length() < 12 || | ||||
|       file.substr(file.length() - 12, 12) != ".safetensors") | ||||
|     file += ".safetensors"; | ||||
|  | ||||
|   // Serialize array | ||||
|   save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
							
								
								
									
										32
									
								
								mlx/io/safetensor.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								mlx/io/safetensor.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <json.hpp> | ||||
|  | ||||
| #include "mlx/io/load.h" | ||||
| #include "mlx/ops.h" | ||||
| #include "mlx/primitives.h" | ||||
|  | ||||
| using json = nlohmann::json; | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| #define ST_F16 "F16" | ||||
| #define ST_BF16 "BF16" | ||||
| #define ST_F32 "F32" | ||||
|  | ||||
| #define ST_BOOL "BOOL" | ||||
| #define ST_I8 "I8" | ||||
| #define ST_I16 "I16" | ||||
| #define ST_I32 "I32" | ||||
| #define ST_I64 "I64" | ||||
| #define ST_U8 "U8" | ||||
| #define ST_U16 "U16" | ||||
| #define ST_U32 "U32" | ||||
| #define ST_U64 "U64" | ||||
|  | ||||
| // Note: Complex numbers aren't in the spec yet so this could change - | ||||
| // https://github.com/huggingface/safetensors/issues/389 | ||||
| #define ST_C64 "C64" | ||||
| } // namespace mlx::core | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo