Add mx.finfo and use it when making causal mask (#1726)

* finfo

* fixes

* docs
This commit is contained in:
Awni Hannun
2024-12-19 14:52:41 -08:00
committed by GitHub
parent e03f0372b1
commit c3628eea49
9 changed files with 154 additions and 3 deletions

View File

@@ -179,6 +179,31 @@ void init_array(nb::module_& m) {
.value("number", mx::number)
.value("generic", mx::generic)
.export_values();
nb::class_<mx::finfo>(
m,
"finfo",
R"pbdoc(
Get information on floating-point types.
)pbdoc")
.def(nb::init<mx::Dtype>())
.def_ro(
"min",
&mx::finfo::min,
R"pbdoc(The smallest representable number.)pbdoc")
.def_ro(
"max",
&mx::finfo::max,
R"pbdoc(The largest representable number.)pbdoc")
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
.def("__repr__", [](const mx::finfo& f) {
std::ostringstream os;
os << "finfo("
<< "min=" << f.min << ", max=" << f.max << ", dtype=" << f.dtype
<< ")";
return os.str();
});
nb::class_<ArrayAt>(
m,
"ArrayAt",