Add mx.finfo and use it when making causal mask (#1726)

* finfo

* fixes

* docs
This commit is contained in:
Awni Hannun
2024-12-19 14:52:41 -08:00
committed by GitHub
parent e03f0372b1
commit c3628eea49
9 changed files with 154 additions and 3 deletions

View File

@@ -53,6 +53,14 @@ struct PrintFormatter {
PrintFormatter& get_global_formatter();
/** Holds information about floating-point types. */
struct finfo {
explicit finfo(Dtype dtype);
Dtype dtype;
float min;
float max;
};
/** The type from promoting the arrays' types with one another. */
inline Dtype result_type(const array& a, const array& b) {
return promote_types(a.dtype(), b.dtype());