add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -2848,6 +2848,192 @@ TEST_CASE("test diag") {
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
}
TEST_CASE("test issubdtype") {
const auto cats = {
complexfloating,
floating,
inexact,
signedinteger,
unsignedinteger,
integer,
number,
generic};
const auto types = {
bool_,
uint8,
uint16,
uint32,
uint64,
int8,
int16,
int32,
int64,
float16,
float32,
bfloat16,
complex64};
for (const auto& type : types) {
CHECK(issubdtype(type, type));
CHECK(issubdtype(type, generic));
switch (kindof(type)) {
case Dtype::Kind::b:
CHECK_FALSE(issubdtype(type, complexfloating));
CHECK_FALSE(issubdtype(type, floating));
CHECK_FALSE(issubdtype(type, inexact));
CHECK_FALSE(issubdtype(type, signedinteger));
CHECK_FALSE(issubdtype(type, unsignedinteger));
CHECK_FALSE(issubdtype(type, integer));
CHECK_FALSE(issubdtype(type, number));
CHECK(issubdtype(type, generic));
break;
case Dtype::Kind::u:
CHECK_FALSE(issubdtype(type, complexfloating));
CHECK_FALSE(issubdtype(type, floating));
CHECK_FALSE(issubdtype(type, inexact));
CHECK_FALSE(issubdtype(type, signedinteger));
CHECK(issubdtype(type, unsignedinteger));
CHECK(issubdtype(type, integer));
CHECK(issubdtype(type, number));
CHECK(issubdtype(type, generic));
break;
case Dtype::Kind::i:
CHECK_FALSE(issubdtype(type, complexfloating));
CHECK_FALSE(issubdtype(type, floating));
CHECK_FALSE(issubdtype(type, inexact));
CHECK(issubdtype(type, signedinteger));
CHECK_FALSE(issubdtype(type, unsignedinteger));
CHECK(issubdtype(type, integer));
CHECK(issubdtype(type, number));
CHECK(issubdtype(type, generic));
break;
case Dtype::Kind::f:
CHECK_FALSE(issubdtype(type, complexfloating));
CHECK(issubdtype(type, floating));
CHECK(issubdtype(type, inexact));
CHECK_FALSE(issubdtype(type, signedinteger));
CHECK_FALSE(issubdtype(type, unsignedinteger));
CHECK_FALSE(issubdtype(type, integer));
CHECK(issubdtype(type, number));
CHECK(issubdtype(type, generic));
break;
case Dtype::Kind::c:
CHECK(issubdtype(type, complexfloating));
CHECK_FALSE(issubdtype(type, floating));
CHECK(issubdtype(type, inexact));
CHECK_FALSE(issubdtype(type, signedinteger));
CHECK_FALSE(issubdtype(type, unsignedinteger));
CHECK_FALSE(issubdtype(type, integer));
CHECK(issubdtype(type, number));
CHECK(issubdtype(type, generic));
break;
case Dtype::Kind::V:
CHECK_FALSE(issubdtype(type, complexfloating));
CHECK(issubdtype(type, floating));
CHECK(issubdtype(type, inexact));
CHECK_FALSE(issubdtype(type, signedinteger));
CHECK_FALSE(issubdtype(type, unsignedinteger));
CHECK_FALSE(issubdtype(type, integer));
CHECK(issubdtype(type, number));
CHECK(issubdtype(type, generic));
break;
}
}
for (const auto& type : types) {
CHECK(issubdtype(type, type));
CHECK(issubdtype(type, generic));
for (auto type1 : types) {
CHECK_EQ(issubdtype(type, type1), type == type1);
}
}
for (const auto& cat : cats) {
CHECK(issubdtype(cat, cat));
switch (cat) {
case Dtype::Category::complexfloating:
CHECK(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK_FALSE(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::floating:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK(issubdtype(cat, floating));
CHECK(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK_FALSE(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::inexact:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK_FALSE(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::signedinteger:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK_FALSE(issubdtype(cat, inexact));
CHECK(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::unsignedinteger:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK_FALSE(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK(issubdtype(cat, unsignedinteger));
CHECK(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::integer:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK_FALSE(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::number:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK_FALSE(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK_FALSE(issubdtype(cat, integer));
CHECK(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
case Dtype::Category::generic:
CHECK_FALSE(issubdtype(cat, complexfloating));
CHECK_FALSE(issubdtype(cat, floating));
CHECK_FALSE(issubdtype(cat, inexact));
CHECK_FALSE(issubdtype(cat, signedinteger));
CHECK_FALSE(issubdtype(cat, unsignedinteger));
CHECK_FALSE(issubdtype(cat, integer));
CHECK_FALSE(issubdtype(cat, number));
CHECK(issubdtype(cat, generic));
break;
}
}
}
TEST_CASE("test atleast_1d") {
auto x = array(1);
auto out = atleast_1d(x);