mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
dtype is copy assignable (#1436)
This commit is contained in:
@@ -370,7 +370,7 @@ void compile_simplify(
|
||||
auto get_scalar_rep = [](const array& a) {
|
||||
uint64_t v = 0;
|
||||
int dtype;
|
||||
switch (a.dtype().size) {
|
||||
switch (a.dtype().size()) {
|
||||
case 1:
|
||||
v = *a.data<uint8_t>();
|
||||
break;
|
||||
@@ -384,7 +384,7 @@ void compile_simplify(
|
||||
v = *a.data<uint64_t>();
|
||||
break;
|
||||
}
|
||||
return std::make_pair(v, a.dtype().val);
|
||||
return std::make_pair(v, a.dtype().val());
|
||||
};
|
||||
|
||||
for (auto& a : tape) {
|
||||
|
@@ -81,11 +81,12 @@ constexpr Dtype::Category type_to_category[num_types] = {
|
||||
} // namespace
|
||||
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||
return Dtype(type_rules[static_cast<int>(t1.val)][static_cast<int>(t2.val)]);
|
||||
return Dtype(
|
||||
type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);
|
||||
}
|
||||
|
||||
Dtype::Kind kindof(const Dtype& t) {
|
||||
return type_kinds[static_cast<int>(t.val)];
|
||||
return type_kinds[static_cast<int>(t.val())];
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -167,7 +168,7 @@ bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
|
||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
||||
|
19
mlx/dtype.h
19
mlx/dtype.h
@@ -47,12 +47,21 @@ struct Dtype {
|
||||
generic
|
||||
};
|
||||
|
||||
Val val;
|
||||
const uint8_t size;
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}
|
||||
|
||||
constexpr operator Val() const {
|
||||
return val;
|
||||
return val_;
|
||||
}
|
||||
constexpr Val val() const {
|
||||
return val_;
|
||||
}
|
||||
constexpr uint8_t size() const {
|
||||
return size_;
|
||||
}
|
||||
|
||||
private:
|
||||
Val val_;
|
||||
uint8_t size_;
|
||||
};
|
||||
|
||||
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
||||
@@ -91,7 +100,7 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||
|
||||
inline uint8_t size_of(const Dtype& t) {
|
||||
return t.size;
|
||||
return t.size();
|
||||
}
|
||||
|
||||
Dtype::Kind kindof(const Dtype& t);
|
||||
|
@@ -64,7 +64,7 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
|
||||
memcpy(
|
||||
buffer.raw_ptr(),
|
||||
tensor->weights_data,
|
||||
tensor->num_weights * equivalent_dtype.value().size);
|
||||
tensor->num_weights * equivalent_dtype.value().size());
|
||||
return {buffer, equivalent_dtype.value()};
|
||||
}
|
||||
// Otherwise, we convert to float16.
|
||||
|
@@ -120,7 +120,7 @@ void gguf_load_quantized(
|
||||
|
||||
std::vector<int> weights_shape = shape;
|
||||
weights_shape.back() /= (weights_per_byte * 4);
|
||||
auto w_nbytes = uint32.size *
|
||||
auto w_nbytes = uint32.size() *
|
||||
std::accumulate(weights_shape.begin(),
|
||||
weights_shape.end(),
|
||||
1,
|
||||
@@ -130,7 +130,7 @@ void gguf_load_quantized(
|
||||
|
||||
// For scales and bias
|
||||
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
|
||||
auto sb_nbytes = float16.size *
|
||||
auto sb_nbytes = float16.size() *
|
||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
|
||||
|
||||
array scales(allocator::malloc(sb_nbytes), shape, float16);
|
||||
|
@@ -58,11 +58,11 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
|
||||
PrintFormatter global_formatter;
|
||||
|
||||
Dtype result_type(const std::vector<array>& arrays) {
|
||||
std::vector<Dtype> dtypes(1, bool_);
|
||||
Dtype t = bool_;
|
||||
for (auto& arr : arrays) {
|
||||
dtypes.push_back(promote_types(dtypes.back(), arr.dtype()));
|
||||
t = promote_types(t, arr.dtype());
|
||||
}
|
||||
return dtypes.back();
|
||||
return t;
|
||||
}
|
||||
|
||||
std::vector<int> broadcast_shapes(
|
||||
|
Reference in New Issue
Block a user