mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	dtype is copy assignable (#1436)
This commit is contained in:
		| @@ -370,7 +370,7 @@ void compile_simplify( | ||||
|   auto get_scalar_rep = [](const array& a) { | ||||
|     uint64_t v = 0; | ||||
|     int dtype; | ||||
|     switch (a.dtype().size) { | ||||
|     switch (a.dtype().size()) { | ||||
|       case 1: | ||||
|         v = *a.data<uint8_t>(); | ||||
|         break; | ||||
| @@ -384,7 +384,7 @@ void compile_simplify( | ||||
|         v = *a.data<uint64_t>(); | ||||
|         break; | ||||
|     } | ||||
|     return std::make_pair(v, a.dtype().val); | ||||
|     return std::make_pair(v, a.dtype().val()); | ||||
|   }; | ||||
|  | ||||
|   for (auto& a : tape) { | ||||
|   | ||||
| @@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = { | ||||
| } // namespace | ||||
|  | ||||
| Dtype promote_types(const Dtype& t1, const Dtype& t2) { | ||||
|   return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]); | ||||
|   return Dtype( | ||||
|       type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]); | ||||
| } | ||||
|  | ||||
| Dtype::Kind kindof(const Dtype& t) { | ||||
|   return type_kinds[static_cast<int>(t.val)]; | ||||
|   return type_kinds[static_cast<int>(t.val())]; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) { | ||||
| } | ||||
|  | ||||
| bool issubdtype(const Dtype& type, const Dtype::Category& cat) { | ||||
|   return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat); | ||||
|   return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat); | ||||
| } | ||||
|  | ||||
| bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) { | ||||
|   | ||||
							
								
								
									
										19
									
								
								mlx/dtype.h
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								mlx/dtype.h
									
									
									
									
									
								
							| @@ -47,12 +47,21 @@ struct Dtype { | ||||
|     generic | ||||
|   }; | ||||
|  | ||||
|   Val val; | ||||
|   const uint8_t size; | ||||
|   constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {} | ||||
|   constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {} | ||||
|  | ||||
|   constexpr operator Val() const { | ||||
|     return val; | ||||
|     return val_; | ||||
|   } | ||||
|   constexpr Val val() const { | ||||
|     return val_; | ||||
|   } | ||||
|   constexpr uint8_t size() const { | ||||
|     return size_; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   Val val_; | ||||
|   uint8_t size_; | ||||
| }; | ||||
|  | ||||
| inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; | ||||
| @@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b); | ||||
| Dtype promote_types(const Dtype& t1, const Dtype& t2); | ||||
|  | ||||
| inline uint8_t size_of(const Dtype& t) { | ||||
|   return t.size; | ||||
|   return t.size(); | ||||
| } | ||||
|  | ||||
| Dtype::Kind kindof(const Dtype& t); | ||||
|   | ||||
| @@ -64,7 +64,7 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) { | ||||
|     memcpy( | ||||
|         buffer.raw_ptr(), | ||||
|         tensor->weights_data, | ||||
|         tensor->num_weights * equivalent_dtype.value().size); | ||||
|         tensor->num_weights * equivalent_dtype.value().size()); | ||||
|     return {buffer, equivalent_dtype.value()}; | ||||
|   } | ||||
|   // Otherwise, we convert to float16. | ||||
|   | ||||
| @@ -120,7 +120,7 @@ void gguf_load_quantized( | ||||
|  | ||||
|   std::vector<int> weights_shape = shape; | ||||
|   weights_shape.back() /= (weights_per_byte * 4); | ||||
|   auto w_nbytes = uint32.size * | ||||
|   auto w_nbytes = uint32.size() * | ||||
|       std::accumulate(weights_shape.begin(), | ||||
|                       weights_shape.end(), | ||||
|                       1, | ||||
| @@ -130,7 +130,7 @@ void gguf_load_quantized( | ||||
|  | ||||
|   // For scales and bias | ||||
|   shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block; | ||||
|   auto sb_nbytes = float16.size * | ||||
|   auto sb_nbytes = float16.size() * | ||||
|       std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()); | ||||
|  | ||||
|   array scales(allocator::malloc(sb_nbytes), shape, float16); | ||||
|   | ||||
| @@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { | ||||
| PrintFormatter global_formatter; | ||||
|  | ||||
| Dtype result_type(const std::vector<array>& arrays) { | ||||
|   std::vector<Dtype> dtypes(1, bool_); | ||||
|   Dtype t = bool_; | ||||
|   for (auto& arr : arrays) { | ||||
|     dtypes.push_back(promote_types(dtypes.back(), arr.dtype())); | ||||
|     t = promote_types(t, arr.dtype()); | ||||
|   } | ||||
|   return dtypes.back(); | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| std::vector<int> broadcast_shapes( | ||||
|   | ||||
| @@ -97,7 +97,8 @@ void init_array(nb::module_& m) { | ||||
|       See the :ref:`list of types <data_types>` for more details | ||||
|       on available data types. | ||||
|       )pbdoc") | ||||
|       .def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") | ||||
|       .def_prop_ro( | ||||
|           "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") | ||||
|       .def( | ||||
|           "__repr__", | ||||
|           [](const Dtype& t) { | ||||
| @@ -112,7 +113,7 @@ void init_array(nb::module_& m) { | ||||
|             return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other); | ||||
|           }) | ||||
|       .def("__hash__", [](const Dtype& t) { | ||||
|         return static_cast<int64_t>(t.val); | ||||
|         return static_cast<int64_t>(t.val()); | ||||
|       }); | ||||
|  | ||||
|   m.attr("bool_") = nb::cast(bool_); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun