mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2848,6 +2848,192 @@ TEST_CASE("test diag") {
|
||||
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test issubdtype") {
|
||||
const auto cats = {
|
||||
complexfloating,
|
||||
floating,
|
||||
inexact,
|
||||
signedinteger,
|
||||
unsignedinteger,
|
||||
integer,
|
||||
number,
|
||||
generic};
|
||||
const auto types = {
|
||||
bool_,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
float16,
|
||||
float32,
|
||||
bfloat16,
|
||||
complex64};
|
||||
for (const auto& type : types) {
|
||||
CHECK(issubdtype(type, type));
|
||||
CHECK(issubdtype(type, generic));
|
||||
switch (kindof(type)) {
|
||||
case Dtype::Kind::b:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK_FALSE(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK_FALSE(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::u:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK_FALSE(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK(issubdtype(type, unsignedinteger));
|
||||
CHECK(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::i:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK_FALSE(issubdtype(type, inexact));
|
||||
CHECK(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::f:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK(issubdtype(type, floating));
|
||||
CHECK(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::c:
|
||||
CHECK(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::V:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK(issubdtype(type, floating));
|
||||
CHECK(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& type : types) {
|
||||
CHECK(issubdtype(type, type));
|
||||
CHECK(issubdtype(type, generic));
|
||||
for (auto type1 : types) {
|
||||
CHECK_EQ(issubdtype(type, type1), type == type1);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& cat : cats) {
|
||||
CHECK(issubdtype(cat, cat));
|
||||
switch (cat) {
|
||||
case Dtype::Category::complexfloating:
|
||||
CHECK(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::floating:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK(issubdtype(cat, floating));
|
||||
CHECK(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::inexact:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::signedinteger:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::unsignedinteger:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK(issubdtype(cat, unsignedinteger));
|
||||
CHECK(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::integer:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::number:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::generic:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK_FALSE(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_1d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_1d(x);
|
||||
|
Reference in New Issue
Block a user