mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 19:28:14 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -488,7 +488,7 @@ void steel_matmul(
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
@@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
@@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
|
@@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() >= 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ namespace mlx::core {
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[softmax] Does not support non-floating point types.");
|
||||
}
|
||||
|
Reference in New Issue
Block a user