diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 279fb03287..f5082010cc 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -370,7 +370,7 @@ void compile_simplify( auto get_scalar_rep = [](const array& a) { uint64_t v = 0; int dtype; - switch (a.dtype().size) { + switch (a.dtype().size()) { case 1: v = *a.data(); break; @@ -384,7 +384,7 @@ void compile_simplify( v = *a.data(); break; } - return std::make_pair(v, a.dtype().val); + return std::make_pair(v, a.dtype().val()); }; for (auto& a : tape) { diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index cf78a7805a..665512f04b 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = { } // namespace Dtype promote_types(const Dtype& t1, const Dtype& t2) { - return Dtype(type_rules[static_cast(t1.val)][static_cast(t2.val)]); + return Dtype( + type_rules[static_cast(t1.val())][static_cast(t2.val())]); } Dtype::Kind kindof(const Dtype& t) { - return type_kinds[static_cast(t.val)]; + return type_kinds[static_cast(t.val())]; } template <> @@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) { } bool issubdtype(const Dtype& type, const Dtype::Category& cat) { - return issubdtype(type_to_category[static_cast(type.val)], cat); + return issubdtype(type_to_category[static_cast(type.val())], cat); } bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) { diff --git a/mlx/dtype.h b/mlx/dtype.h index 5f9ee27cc1..11d61e378b 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -47,12 +47,21 @@ struct Dtype { generic }; - Val val; - const uint8_t size; - constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {} + constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {} + constexpr operator Val() const { - return val; + return val_; } + constexpr Val val() const { + return val_; + } + constexpr uint8_t size() const { + return size_; + } + + private: + Val val_; + uint8_t size_; }; inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; @@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b); Dtype promote_types(const Dtype& t1, const Dtype& t2); inline uint8_t size_of(const Dtype& t) { - return t.size; + return t.size(); } Dtype::Kind kindof(const Dtype& t); diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index be1f382bae..dffb2aa1d7 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -64,7 +64,7 @@ std::tuple extract_tensor_data(gguf_tensor* tensor) { memcpy( buffer.raw_ptr(), tensor->weights_data, - tensor->num_weights * equivalent_dtype.value().size); + tensor->num_weights * equivalent_dtype.value().size()); return {buffer, equivalent_dtype.value()}; } // Otherwise, we convert to float16. diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 4f89b22784..8e6a5b2f9e 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -120,7 +120,7 @@ void gguf_load_quantized( std::vector weights_shape = shape; weights_shape.back() /= (weights_per_byte * 4); - auto w_nbytes = uint32.size * + auto w_nbytes = uint32.size() * std::accumulate(weights_shape.begin(), weights_shape.end(), 1, @@ -130,7 +130,7 @@ void gguf_load_quantized( // For scales and bias shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block; - auto sb_nbytes = float16.size * + auto sb_nbytes = float16.size() * std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); array scales(allocator::malloc(sb_nbytes), shape, float16); diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 34882884ba..e3c2c72bd7 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { PrintFormatter global_formatter; Dtype result_type(const std::vector& arrays) { - std::vector dtypes(1, bool_); + Dtype t = bool_; for (auto& arr : arrays) { - dtypes.push_back(promote_types(dtypes.back(), arr.dtype())); + t = promote_types(t, arr.dtype()); } - return dtypes.back(); + return t; } std::vector broadcast_shapes( diff --git a/python/src/array.cpp b/python/src/array.cpp index 71ecd658fc..c9bc8eebeb 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -97,7 +97,8 @@ void init_array(nb::module_& m) { See the :ref:`list of types ` for more details on available data types. )pbdoc") - .def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") + .def_prop_ro( + "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") .def( "__repr__", [](const Dtype& t) { @@ -112,7 +113,7 @@ void init_array(nb::module_& m) { return nb::isinstance(other) && t == nb::cast(other); }) .def("__hash__", [](const Dtype& t) { - return static_cast(t.val); + return static_cast(t.val()); }); m.attr("bool_") = nb::cast(bool_);