mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
iinfo and scalar overflow detection (#2009)
This commit is contained in:
@@ -380,4 +380,43 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_iinfo_limits(int64_t& min, uint64_t& max) {
|
||||
min = std::numeric_limits<T>::min();
|
||||
max = std::numeric_limits<T>::max();
|
||||
}
|
||||
|
||||
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
set_iinfo_limits<int8_t>(min, max);
|
||||
break;
|
||||
case uint8:
|
||||
set_iinfo_limits<uint8_t>(min, max);
|
||||
break;
|
||||
case int16:
|
||||
set_iinfo_limits<int16_t>(min, max);
|
||||
break;
|
||||
case uint16:
|
||||
set_iinfo_limits<uint16_t>(min, max);
|
||||
break;
|
||||
case int32:
|
||||
set_iinfo_limits<int32_t>(min, max);
|
||||
break;
|
||||
case uint32:
|
||||
set_iinfo_limits<uint32_t>(min, max);
|
||||
break;
|
||||
case int64:
|
||||
set_iinfo_limits<int64_t>(min, max);
|
||||
break;
|
||||
case uint64:
|
||||
set_iinfo_limits<uint64_t>(min, max);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream msg;
|
||||
msg << "[iinfo] dtype " << dtype << " is not integral.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -67,6 +67,14 @@ struct finfo {
|
||||
double max;
|
||||
};
|
||||
|
||||
/** Holds information about integral types. */
|
||||
struct iinfo {
|
||||
explicit iinfo(Dtype dtype);
|
||||
Dtype dtype;
|
||||
int64_t min;
|
||||
uint64_t max;
|
||||
};
|
||||
|
||||
/** The type from promoting the arrays' types with one another. */
|
||||
inline Dtype result_type(const array& a, const array& b) {
|
||||
return promote_types(a.dtype(), b.dtype());
|
||||
|
||||
Reference in New Issue
Block a user