dtype is copy assignable (#1436)

This commit is contained in:
Awni Hannun 2024-09-25 12:07:13 -07:00 committed by GitHub
parent 195b429d99
commit afc9c0ec1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 29 additions and 18 deletions

View File

@ -370,7 +370,7 @@ void compile_simplify(
auto get_scalar_rep = [](const array& a) { auto get_scalar_rep = [](const array& a) {
uint64_t v = 0; uint64_t v = 0;
int dtype; int dtype;
switch (a.dtype().size) { switch (a.dtype().size()) {
case 1: case 1:
v = *a.data<uint8_t>(); v = *a.data<uint8_t>();
break; break;
@ -384,7 +384,7 @@ void compile_simplify(
v = *a.data<uint64_t>(); v = *a.data<uint64_t>();
break; break;
} }
return std::make_pair(v, a.dtype().val); return std::make_pair(v, a.dtype().val());
}; };
for (auto& a : tape) { for (auto& a : tape) {

View File

@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = {
} // namespace } // namespace
Dtype promote_types(const Dtype& t1, const Dtype& t2) { Dtype promote_types(const Dtype& t1, const Dtype& t2) {
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]); return Dtype(
type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);
} }
Dtype::Kind kindof(const Dtype& t) { Dtype::Kind kindof(const Dtype& t) {
return type_kinds[static_cast<int>(t.val)]; return type_kinds[static_cast<int>(t.val())];
} }
template <> template <>
@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
} }
bool issubdtype(const Dtype& type, const Dtype::Category& cat) { bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat); return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);
} }
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) { bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {

View File

@ -47,12 +47,21 @@ struct Dtype {
generic generic
}; };
Val val; constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}
const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
constexpr operator Val() const { constexpr operator Val() const {
return val; return val_;
} }
constexpr Val val() const {
return val_;
}
constexpr uint8_t size() const {
return size_;
}
private:
Val val_;
uint8_t size_;
}; };
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
Dtype promote_types(const Dtype& t1, const Dtype& t2); Dtype promote_types(const Dtype& t1, const Dtype& t2);
inline uint8_t size_of(const Dtype& t) { inline uint8_t size_of(const Dtype& t) {
return t.size; return t.size();
} }
Dtype::Kind kindof(const Dtype& t); Dtype::Kind kindof(const Dtype& t);

View File

@ -64,7 +64,7 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
memcpy( memcpy(
buffer.raw_ptr(), buffer.raw_ptr(),
tensor->weights_data, tensor->weights_data,
tensor->num_weights * equivalent_dtype.value().size); tensor->num_weights * equivalent_dtype.value().size());
return {buffer, equivalent_dtype.value()}; return {buffer, equivalent_dtype.value()};
} }
// Otherwise, we convert to float16. // Otherwise, we convert to float16.

View File

@ -120,7 +120,7 @@ void gguf_load_quantized(
std::vector<int> weights_shape = shape; std::vector<int> weights_shape = shape;
weights_shape.back() /= (weights_per_byte * 4); weights_shape.back() /= (weights_per_byte * 4);
auto w_nbytes = uint32.size * auto w_nbytes = uint32.size() *
std::accumulate(weights_shape.begin(), std::accumulate(weights_shape.begin(),
weights_shape.end(), weights_shape.end(),
1, 1,
@ -130,7 +130,7 @@ void gguf_load_quantized(
// For scales and bias // For scales and bias
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block; shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
auto sb_nbytes = float16.size * auto sb_nbytes = float16.size() *
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
array scales(allocator::malloc(sb_nbytes), shape, float16); array scales(allocator::malloc(sb_nbytes), shape, float16);

View File

@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
PrintFormatter global_formatter; PrintFormatter global_formatter;
Dtype result_type(const std::vector<array>& arrays) { Dtype result_type(const std::vector<array>& arrays) {
std::vector<Dtype> dtypes(1, bool_); Dtype t = bool_;
for (auto& arr : arrays) { for (auto& arr : arrays) {
dtypes.push_back(promote_types(dtypes.back(), arr.dtype())); t = promote_types(t, arr.dtype());
} }
return dtypes.back(); return t;
} }
std::vector<int> broadcast_shapes( std::vector<int> broadcast_shapes(

View File

@ -97,7 +97,8 @@ void init_array(nb::module_& m) {
See the :ref:`list of types <data_types>` for more details See the :ref:`list of types <data_types>` for more details
on available data types. on available data types.
)pbdoc") )pbdoc")
.def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") .def_prop_ro(
"size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
.def( .def(
"__repr__", "__repr__",
[](const Dtype& t) { [](const Dtype& t) {
@ -112,7 +113,7 @@ void init_array(nb::module_& m) {
return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other); return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other);
}) })
.def("__hash__", [](const Dtype& t) { .def("__hash__", [](const Dtype& t) {
return static_cast<int64_t>(t.val); return static_cast<int64_t>(t.val());
}); });
m.attr("bool_") = nb::cast(bool_); m.attr("bool_") = nb::cast(bool_);