dtype is copy assignable (#1436)

This commit is contained in:
Awni Hannun 2024-09-25 12:07:13 -07:00 committed by GitHub
parent 195b429d99
commit afc9c0ec1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 29 additions and 18 deletions

View File

@ -370,7 +370,7 @@ void compile_simplify(
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;
int dtype;
switch (a.dtype().size) {
switch (a.dtype().size()) {
case 1:
v = *a.data<uint8_t>();
break;
@ -384,7 +384,7 @@ void compile_simplify(
v = *a.data<uint64_t>();
break;
}
return std::make_pair(v, a.dtype().val);
return std::make_pair(v, a.dtype().val());
};
for (auto& a : tape) {

View File

@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = {
} // namespace
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]);
return Dtype(
type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);
}
Dtype::Kind kindof(const Dtype& t) {
return type_kinds[static_cast<int>(t.val)];
return type_kinds[static_cast<int>(t.val())];
}
template <>
@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
}
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);
}
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {

View File

@ -47,12 +47,21 @@ struct Dtype {
generic
};
Val val;
const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}
constexpr operator Val() const {
return val;
return val_;
}
constexpr Val val() const {
return val_;
}
constexpr uint8_t size() const {
return size_;
}
private:
Val val_;
uint8_t size_;
};
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
Dtype promote_types(const Dtype& t1, const Dtype& t2);
inline uint8_t size_of(const Dtype& t) {
return t.size;
return t.size();
}
Dtype::Kind kindof(const Dtype& t);

View File

@ -64,7 +64,7 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
memcpy(
buffer.raw_ptr(),
tensor->weights_data,
tensor->num_weights * equivalent_dtype.value().size);
tensor->num_weights * equivalent_dtype.value().size());
return {buffer, equivalent_dtype.value()};
}
// Otherwise, we convert to float16.

View File

@ -120,7 +120,7 @@ void gguf_load_quantized(
std::vector<int> weights_shape = shape;
weights_shape.back() /= (weights_per_byte * 4);
auto w_nbytes = uint32.size *
auto w_nbytes = uint32.size() *
std::accumulate(weights_shape.begin(),
weights_shape.end(),
1,
@ -130,7 +130,7 @@ void gguf_load_quantized(
// For scales and bias
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
auto sb_nbytes = float16.size *
auto sb_nbytes = float16.size() *
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
array scales(allocator::malloc(sb_nbytes), shape, float16);

View File

@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
PrintFormatter global_formatter;
Dtype result_type(const std::vector<array>& arrays) {
std::vector<Dtype> dtypes(1, bool_);
Dtype t = bool_;
for (auto& arr : arrays) {
dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));
t = promote_types(t, arr.dtype());
}
return dtypes.back();
return t;
}
std::vector<int> broadcast_shapes(

View File

@ -97,7 +97,8 @@ void init_array(nb::module_& m) {
See the :ref:`list of types <data_types>` for more details
on available data types.
)pbdoc")
.def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
.def_prop_ro(
"size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
.def(
"__repr__",
[](const Dtype& t) {
@ -112,7 +113,7 @@ void init_array(nb::module_& m) {
return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other);
})
.def("__hash__", [](const Dtype& t) {
return static_cast<int64_t>(t.val);
return static_cast<int64_t>(t.val());
});
m.attr("bool_") = nb::cast(bool_);