mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
inline array l2_norm(
|
||||
@@ -19,7 +19,7 @@ inline array l2_norm(
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (is_complex(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), complexfloating)) {
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||
} else {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
|
||||
Reference in New Issue
Block a user