add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -488,7 +488,7 @@ void steel_matmul(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
@@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}

View File

@@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round");
} else {
// No-op integer types

View File

@@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() >= 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
}

View File

@@ -12,7 +12,7 @@ namespace mlx::core {
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[softmax] Does not support non-floating point types.");
}