add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -488,7 +488,7 @@ void steel_matmul(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
@@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}