mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-13 12:46:40 +08:00
dtype is copy assignable (#1436)
This commit is contained in:
parent
195b429d99
commit
afc9c0ec1b
@ -370,7 +370,7 @@ void compile_simplify(
|
||||
auto get_scalar_rep = [](const array& a) {
|
||||
uint64_t v = 0;
|
||||
int dtype;
|
||||
switch (a.dtype().size) {
|
||||
switch (a.dtype().size()) {
|
||||
case 1:
|
||||
v = *a.data<uint8_t>();
|
||||
break;
|
||||
@ -384,7 +384,7 @@ void compile_simplify(
|
||||
v = *a.data<uint64_t>();
|
||||
break;
|
||||
}
|
||||
return std::make_pair(v, a.dtype().val);
|
||||
return std::make_pair(v, a.dtype().val());
|
||||
};
|
||||
|
||||
for (auto& a : tape) {
|
||||
|
@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = {
|
||||
} // namespace
|
||||
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]);
|
||||
return Dtype(
|
||||
type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);
|
||||
}
|
||||
|
||||
Dtype::Kind kindof(const Dtype& t) {
|
||||
return type_kinds[static_cast<int>(t.val)];
|
||||
return type_kinds[static_cast<int>(t.val())];
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
|
||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
||||
|
19
mlx/dtype.h
19
mlx/dtype.h
@ -47,12 +47,21 @@ struct Dtype {
|
||||
generic
|
||||
};
|
||||
|
||||
Val val;
|
||||
const uint8_t size;
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}
|
||||
|
||||
constexpr operator Val() const {
|
||||
return val;
|
||||
return val_;
|
||||
}
|
||||
constexpr Val val() const {
|
||||
return val_;
|
||||
}
|
||||
constexpr uint8_t size() const {
|
||||
return size_;
|
||||
}
|
||||
|
||||
private:
|
||||
Val val_;
|
||||
uint8_t size_;
|
||||
};
|
||||
|
||||
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
||||
@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||
|
||||
inline uint8_t size_of(const Dtype& t) {
|
||||
return t.size;
|
||||
return t.size();
|
||||
}
|
||||
|
||||
Dtype::Kind kindof(const Dtype& t);
|
||||
|
@ -64,7 +64,7 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
||||
memcpy(
|
||||
buffer.raw_ptr(),
|
||||
tensor->weights_data,
|
||||
tensor->num_weights * equivalent_dtype.value().size);
|
||||
tensor->num_weights * equivalent_dtype.value().size());
|
||||
return {buffer, equivalent_dtype.value()};
|
||||
}
|
||||
// Otherwise, we convert to float16.
|
||||
|
@ -120,7 +120,7 @@ void gguf_load_quantized(
|
||||
|
||||
std::vector<int> weights_shape = shape;
|
||||
weights_shape.back() /= (weights_per_byte * 4);
|
||||
auto w_nbytes = uint32.size *
|
||||
auto w_nbytes = uint32.size() *
|
||||
std::accumulate(weights_shape.begin(),
|
||||
weights_shape.end(),
|
||||
1,
|
||||
@ -130,7 +130,7 @@ void gguf_load_quantized(
|
||||
|
||||
// For scales and bias
|
||||
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
|
||||
auto sb_nbytes = float16.size *
|
||||
auto sb_nbytes = float16.size() *
|
||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
|
||||
|
||||
array scales(allocator::malloc(sb_nbytes), shape, float16);
|
||||
|
@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
|
||||
PrintFormatter global_formatter;
|
||||
|
||||
Dtype result_type(const std::vector<array>& arrays) {
|
||||
std::vector<Dtype> dtypes(1, bool_);
|
||||
Dtype t = bool_;
|
||||
for (auto& arr : arrays) {
|
||||
dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));
|
||||
t = promote_types(t, arr.dtype());
|
||||
}
|
||||
return dtypes.back();
|
||||
return t;
|
||||
}
|
||||
|
||||
std::vector<int> broadcast_shapes(
|
||||
|
@ -97,7 +97,8 @@ void init_array(nb::module_& m) {
|
||||
See the :ref:`list of types <data_types>` for more details
|
||||
on available data types.
|
||||
)pbdoc")
|
||||
.def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
|
||||
.def_prop_ro(
|
||||
"size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Dtype& t) {
|
||||
@ -112,7 +113,7 @@ void init_array(nb::module_& m) {
|
||||
return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other);
|
||||
})
|
||||
.def("__hash__", [](const Dtype& t) {
|
||||
return static_cast<int64_t>(t.val);
|
||||
return static_cast<int64_t>(t.val());
|
||||
});
|
||||
|
||||
m.attr("bool_") = nb::cast(bool_);
|
||||
|
Loading…
Reference in New Issue
Block a user