add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
}
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
} // namespace
@@ -1140,7 +1140,7 @@ array array_equal(
return array(false);
} else {
auto dtype = promote_types(a.dtype(), b.dtype());
equal_nan &= is_floating_point(dtype);
equal_nan &= issubdtype(dtype, inexact);
return all(
array(
a.shape(),
@@ -1153,7 +1153,7 @@ array array_equal(
}
array isnan(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return not_equal(a, a, s);
@@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
@@ -1929,7 +1929,7 @@ array floor_divide(
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_floating_point(dtype)) {
if (issubdtype(dtype, inexact)) {
return floor(divide(a, b, s), s);
}
@@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) {
std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_complex(dtype)) {
if (issubdtype(dtype, complexfloating)) {
throw std::invalid_argument("[divmod] Complex type not supported.");
}
auto inputs =
@@ -2220,7 +2220,7 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results"
@@ -2330,7 +2330,7 @@ array gather(
// Promote indices to the same type
auto dtype = result_type(indices);
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[gather] Got indices with invalid dtype. Indices must be integral.");
}
@@ -2521,7 +2521,7 @@ array scatter(
// Promote indices to the same type
auto dtype = result_type(indices);
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[scatter] Got indices with invalid dtype. Indices must be integral.");
}
@@ -2834,7 +2834,7 @@ inline std::vector<int> conv_out_shape(
}
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "."
<< " Convolution currently only supports floating point types";
@@ -3062,7 +3062,7 @@ array quantized_matmul(
}
auto dtype = result_type(x, scales, biases);
if (!is_floating_point(dtype) || is_complex(dtype)) {
if (!issubdtype(dtype, floating)) {
std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but "
<< "the passed types where x.dtype() == " << x.dtype()
@@ -3364,7 +3364,7 @@ array addmm(
// Type promotion
auto out_type = result_type(a, b, c);
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()