mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
26
mlx/ops.cpp
26
mlx/ops.cpp
@@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -1140,7 +1140,7 @@ array array_equal(
|
||||
return array(false);
|
||||
} else {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
equal_nan &= is_floating_point(dtype);
|
||||
equal_nan &= issubdtype(dtype, inexact);
|
||||
return all(
|
||||
array(
|
||||
a.shape(),
|
||||
@@ -1153,7 +1153,7 @@ array array_equal(
|
||||
}
|
||||
|
||||
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return not_equal(a, a, s);
|
||||
@@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
@@ -1929,7 +1929,7 @@ array floor_divide(
|
||||
const array& b,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_floating_point(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
return floor(divide(a, b, s), s);
|
||||
}
|
||||
|
||||
@@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) {
|
||||
std::vector<array>
|
||||
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_complex(dtype)) {
|
||||
if (issubdtype(dtype, complexfloating)) {
|
||||
throw std::invalid_argument("[divmod] Complex type not supported.");
|
||||
}
|
||||
auto inputs =
|
||||
@@ -2220,7 +2220,7 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
||||
@@ -2330,7 +2330,7 @@ array gather(
|
||||
|
||||
// Promote indices to the same type
|
||||
auto dtype = result_type(indices);
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather] Got indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
@@ -2521,7 +2521,7 @@ array scatter(
|
||||
|
||||
// Promote indices to the same type
|
||||
auto dtype = result_type(indices);
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[scatter] Got indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
@@ -2834,7 +2834,7 @@ inline std::vector<int> conv_out_shape(
|
||||
}
|
||||
|
||||
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
|
||||
if (!issubdtype(in.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid input array with type " << in.dtype() << "."
|
||||
<< " Convolution currently only supports floating point types";
|
||||
@@ -3062,7 +3062,7 @@ array quantized_matmul(
|
||||
}
|
||||
|
||||
auto dtype = result_type(x, scales, biases);
|
||||
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||
<< "the passed types where x.dtype() == " << x.dtype()
|
||||
@@ -3364,7 +3364,7 @@ array addmm(
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b, c);
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Only real floating point types are supported but "
|
||||
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()
|
||||
|
||||
Reference in New Issue
Block a user