mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Add mx.finfo
and use it when making causal mask (#1726)
* finfo * fixes * docs
This commit is contained in:
@@ -179,6 +179,31 @@ void init_array(nb::module_& m) {
|
||||
.value("number", mx::number)
|
||||
.value("generic", mx::generic)
|
||||
.export_values();
|
||||
|
||||
nb::class_<mx::finfo>(
|
||||
m,
|
||||
"finfo",
|
||||
R"pbdoc(
|
||||
Get information on floating-point types.
|
||||
)pbdoc")
|
||||
.def(nb::init<mx::Dtype>())
|
||||
.def_ro(
|
||||
"min",
|
||||
&mx::finfo::min,
|
||||
R"pbdoc(The smallest representable number.)pbdoc")
|
||||
.def_ro(
|
||||
"max",
|
||||
&mx::finfo::max,
|
||||
R"pbdoc(The largest representable number.)pbdoc")
|
||||
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
||||
.def("__repr__", [](const mx::finfo& f) {
|
||||
std::ostringstream os;
|
||||
os << "finfo("
|
||||
<< "min=" << f.min << ", max=" << f.max << ", dtype=" << f.dtype
|
||||
<< ")";
|
||||
return os.str();
|
||||
});
|
||||
|
||||
nb::class_<ArrayAt>(
|
||||
m,
|
||||
"ArrayAt",
|
||||
|
Reference in New Issue
Block a user