iinfo and scalar overflow detection (#2009)

This commit is contained in:
Awni Hannun
2025-03-27 19:54:56 -07:00
committed by GitHub
parent bc62932984
commit 5580b47291
6 changed files with 112 additions and 0 deletions

View File

@@ -380,4 +380,43 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
}
}
template <typename T>
void set_iinfo_limits(int64_t& min, uint64_t& max) {
min = std::numeric_limits<T>::min();
max = std::numeric_limits<T>::max();
}
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
switch (dtype) {
case int8:
set_iinfo_limits<int8_t>(min, max);
break;
case uint8:
set_iinfo_limits<uint8_t>(min, max);
break;
case int16:
set_iinfo_limits<int16_t>(min, max);
break;
case uint16:
set_iinfo_limits<uint16_t>(min, max);
break;
case int32:
set_iinfo_limits<int32_t>(min, max);
break;
case uint32:
set_iinfo_limits<uint32_t>(min, max);
break;
case int64:
set_iinfo_limits<int64_t>(min, max);
break;
case uint64:
set_iinfo_limits<uint64_t>(min, max);
break;
default:
std::ostringstream msg;
msg << "[iinfo] dtype " << dtype << " is not integral.";
throw std::invalid_argument(msg.str());
}
}
} // namespace mlx::core

View File

@@ -67,6 +67,14 @@ struct finfo {
double max;
};
/** Holds information about integral types. */
struct iinfo {
explicit iinfo(Dtype dtype);
Dtype dtype;
int64_t min;
uint64_t max;
};
/** The type from promoting the arrays' types with one another. */
inline Dtype result_type(const array& a, const array& b) {
return promote_types(a.dtype(), b.dtype());