mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-14 21:36:50 +08:00
dtype is copy assignable (#1436)
This commit is contained in:
parent
195b429d99
commit
afc9c0ec1b
@ -370,7 +370,7 @@ void compile_simplify(
|
|||||||
auto get_scalar_rep = [](const array& a) {
|
auto get_scalar_rep = [](const array& a) {
|
||||||
uint64_t v = 0;
|
uint64_t v = 0;
|
||||||
int dtype;
|
int dtype;
|
||||||
switch (a.dtype().size) {
|
switch (a.dtype().size()) {
|
||||||
case 1:
|
case 1:
|
||||||
v = *a.data<uint8_t>();
|
v = *a.data<uint8_t>();
|
||||||
break;
|
break;
|
||||||
@ -384,7 +384,7 @@ void compile_simplify(
|
|||||||
v = *a.data<uint64_t>();
|
v = *a.data<uint64_t>();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return std::make_pair(v, a.dtype().val);
|
return std::make_pair(v, a.dtype().val());
|
||||||
};
|
};
|
||||||
|
|
||||||
for (auto& a : tape) {
|
for (auto& a : tape) {
|
||||||
|
@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||||
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]);
|
return Dtype(
|
||||||
|
type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Dtype::Kind kindof(const Dtype& t) {
|
Dtype::Kind kindof(const Dtype& t) {
|
||||||
return type_kinds[static_cast<int>(t.val)];
|
return type_kinds[static_cast<int>(t.val())];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
||||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
|
return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
||||||
|
19
mlx/dtype.h
19
mlx/dtype.h
@ -47,12 +47,21 @@ struct Dtype {
|
|||||||
generic
|
generic
|
||||||
};
|
};
|
||||||
|
|
||||||
Val val;
|
constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}
|
||||||
const uint8_t size;
|
|
||||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
|
|
||||||
constexpr operator Val() const {
|
constexpr operator Val() const {
|
||||||
return val;
|
return val_;
|
||||||
}
|
}
|
||||||
|
constexpr Val val() const {
|
||||||
|
return val_;
|
||||||
|
}
|
||||||
|
constexpr uint8_t size() const {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Val val_;
|
||||||
|
uint8_t size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
||||||
@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
|
|||||||
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||||
|
|
||||||
inline uint8_t size_of(const Dtype& t) {
|
inline uint8_t size_of(const Dtype& t) {
|
||||||
return t.size;
|
return t.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
Dtype::Kind kindof(const Dtype& t);
|
Dtype::Kind kindof(const Dtype& t);
|
||||||
|
@ -64,7 +64,7 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
|||||||
memcpy(
|
memcpy(
|
||||||
buffer.raw_ptr(),
|
buffer.raw_ptr(),
|
||||||
tensor->weights_data,
|
tensor->weights_data,
|
||||||
tensor->num_weights * equivalent_dtype.value().size);
|
tensor->num_weights * equivalent_dtype.value().size());
|
||||||
return {buffer, equivalent_dtype.value()};
|
return {buffer, equivalent_dtype.value()};
|
||||||
}
|
}
|
||||||
// Otherwise, we convert to float16.
|
// Otherwise, we convert to float16.
|
||||||
|
@ -120,7 +120,7 @@ void gguf_load_quantized(
|
|||||||
|
|
||||||
std::vector<int> weights_shape = shape;
|
std::vector<int> weights_shape = shape;
|
||||||
weights_shape.back() /= (weights_per_byte * 4);
|
weights_shape.back() /= (weights_per_byte * 4);
|
||||||
auto w_nbytes = uint32.size *
|
auto w_nbytes = uint32.size() *
|
||||||
std::accumulate(weights_shape.begin(),
|
std::accumulate(weights_shape.begin(),
|
||||||
weights_shape.end(),
|
weights_shape.end(),
|
||||||
1,
|
1,
|
||||||
@ -130,7 +130,7 @@ void gguf_load_quantized(
|
|||||||
|
|
||||||
// For scales and bias
|
// For scales and bias
|
||||||
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
|
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
|
||||||
auto sb_nbytes = float16.size *
|
auto sb_nbytes = float16.size() *
|
||||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
|
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
|
||||||
|
|
||||||
array scales(allocator::malloc(sb_nbytes), shape, float16);
|
array scales(allocator::malloc(sb_nbytes), shape, float16);
|
||||||
|
@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
|
|||||||
PrintFormatter global_formatter;
|
PrintFormatter global_formatter;
|
||||||
|
|
||||||
Dtype result_type(const std::vector<array>& arrays) {
|
Dtype result_type(const std::vector<array>& arrays) {
|
||||||
std::vector<Dtype> dtypes(1, bool_);
|
Dtype t = bool_;
|
||||||
for (auto& arr : arrays) {
|
for (auto& arr : arrays) {
|
||||||
dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));
|
t = promote_types(t, arr.dtype());
|
||||||
}
|
}
|
||||||
return dtypes.back();
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> broadcast_shapes(
|
std::vector<int> broadcast_shapes(
|
||||||
|
@ -97,7 +97,8 @@ void init_array(nb::module_& m) {
|
|||||||
See the :ref:`list of types <data_types>` for more details
|
See the :ref:`list of types <data_types>` for more details
|
||||||
on available data types.
|
on available data types.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
|
.def_prop_ro(
|
||||||
|
"size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
|
||||||
.def(
|
.def(
|
||||||
"__repr__",
|
"__repr__",
|
||||||
[](const Dtype& t) {
|
[](const Dtype& t) {
|
||||||
@ -112,7 +113,7 @@ void init_array(nb::module_& m) {
|
|||||||
return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other);
|
return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other);
|
||||||
})
|
})
|
||||||
.def("__hash__", [](const Dtype& t) {
|
.def("__hash__", [](const Dtype& t) {
|
||||||
return static_cast<int64_t>(t.val);
|
return static_cast<int64_t>(t.val());
|
||||||
});
|
});
|
||||||
|
|
||||||
m.attr("bool_") = nb::cast(bool_);
|
m.attr("bool_") = nb::cast(bool_);
|
||||||
|
Loading…
Reference in New Issue
Block a user