mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
bfb5bad4f0
commit
479051ce1c
@ -58,6 +58,7 @@ are the CPU and GPU.
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
python/array
|
python/array
|
||||||
|
python/data_types
|
||||||
python/devices_and_streams
|
python/devices_and_streams
|
||||||
python/ops
|
python/ops
|
||||||
python/random
|
python/random
|
||||||
|
@ -19,7 +19,6 @@ Array
|
|||||||
array.ndim
|
array.ndim
|
||||||
array.shape
|
array.shape
|
||||||
array.size
|
array.size
|
||||||
Dtype
|
|
||||||
array.abs
|
array.abs
|
||||||
array.all
|
array.all
|
||||||
array.any
|
array.any
|
||||||
@ -32,7 +31,6 @@ Array
|
|||||||
array.cumsum
|
array.cumsum
|
||||||
array.diag
|
array.diag
|
||||||
array.diagonal
|
array.diagonal
|
||||||
array.dtype
|
|
||||||
array.exp
|
array.exp
|
||||||
array.flatten
|
array.flatten
|
||||||
array.log
|
array.log
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
.. _data_types:
|
.. _data_types:
|
||||||
|
|
||||||
:orphan:
|
|
||||||
|
|
||||||
Data Types
|
Data Types
|
||||||
==========
|
==========
|
||||||
|
|
||||||
@ -56,3 +54,15 @@ The default floating point type is ``float32`` and the default integer type is
|
|||||||
* - ``complex64``
|
* - ``complex64``
|
||||||
- 8
|
- 8
|
||||||
- 64-bit complex float
|
- 64-bit complex float
|
||||||
|
|
||||||
|
|
||||||
|
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||||
|
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||||
|
``dtype`` (or category) is a subtype of another category.
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Dtype
|
||||||
|
DtypeCategory
|
||||||
|
issubdtype
|
||||||
|
@ -30,6 +30,7 @@ Module
|
|||||||
Module.named_modules
|
Module.named_modules
|
||||||
Module.parameters
|
Module.parameters
|
||||||
Module.save_weights
|
Module.save_weights
|
||||||
|
Module.set_dtype
|
||||||
Module.train
|
Module.train
|
||||||
Module.trainable_parameters
|
Module.trainable_parameters
|
||||||
Module.unfreeze
|
Module.unfreeze
|
||||||
|
@ -62,10 +62,10 @@ Operations
|
|||||||
identity
|
identity
|
||||||
inner
|
inner
|
||||||
isclose
|
isclose
|
||||||
isnan
|
|
||||||
isposinf
|
|
||||||
isneginf
|
|
||||||
isinf
|
isinf
|
||||||
|
isnan
|
||||||
|
isneginf
|
||||||
|
isposinf
|
||||||
less
|
less
|
||||||
less_equal
|
less_equal
|
||||||
linspace
|
linspace
|
||||||
|
@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
set_unary_output_data(in, out);
|
set_unary_output_data(in, out);
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
} else if (is_floating_point(out.dtype())) {
|
} else if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -355,7 +355,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
vvlog1pf(
|
vvlog1pf(
|
||||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
} else if (is_floating_point(out.dtype())) {
|
} else if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
if (is_floating_point(out.dtype())) {
|
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == float32) {
|
||||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||||
} else if (out.dtype() == float16) {
|
} else if (out.dtype() == float16) {
|
||||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||||
} else if (out.dtype() == bfloat16) {
|
} else if (out.dtype() == bfloat16) {
|
||||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||||
} else {
|
} else if (issubdtype(out.dtype(), inexact)) {
|
||||||
std::ostringstream err;
|
std::ostringstream err;
|
||||||
err << "[logaddexp] Does not support " << out.dtype();
|
err << "[logaddexp] Does not support " << out.dtype();
|
||||||
throw std::invalid_argument(err.str());
|
throw std::invalid_argument(err.str());
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||||
|
@ -22,7 +22,7 @@ namespace mlx::core {
|
|||||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (is_unsigned(in.dtype())) {
|
if (issubdtype(in.dtype(), unsignedinteger)) {
|
||||||
// No-op for unsigned types
|
// No-op for unsigned types
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::ArcCos());
|
unary_fp(in, out, detail::ArcCos());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::ArcCosh());
|
unary_fp(in, out, detail::ArcCosh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::ArcSin());
|
unary_fp(in, out, detail::ArcSin());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::ArcSinh());
|
unary_fp(in, out, detail::ArcSinh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::ArcTan());
|
unary_fp(in, out, detail::ArcTan());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::ArcTanh());
|
unary_fp(in, out, detail::ArcTanh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (issubdtype(in.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Ceil());
|
unary_fp(in, out, detail::Ceil());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Cos());
|
unary_fp(in, out, detail::Cos());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Cosh());
|
unary_fp(in, out, detail::Cosh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Exp());
|
unary_fp(in, out, detail::Exp());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -362,7 +362,7 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (issubdtype(in.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Floor());
|
unary_fp(in, out, detail::Floor());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
@ -388,7 +388,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
switch (base_) {
|
switch (base_) {
|
||||||
case Base::e:
|
case Base::e:
|
||||||
unary_fp(in, out, detail::Log());
|
unary_fp(in, out, detail::Log());
|
||||||
@ -410,7 +410,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Log1p());
|
unary_fp(in, out, detail::Log1p());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -597,7 +597,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (issubdtype(in.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Round());
|
unary_fp(in, out, detail::Round());
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
@ -608,7 +608,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Sigmoid());
|
unary_fp(in, out, detail::Sigmoid());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -630,7 +630,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Sin());
|
unary_fp(in, out, detail::Sin());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -642,7 +642,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Sinh());
|
unary_fp(in, out, detail::Sinh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -850,7 +850,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Tan());
|
unary_fp(in, out, detail::Tan());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -862,7 +862,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (is_floating_point(out.dtype())) {
|
if (issubdtype(out.dtype(), inexact)) {
|
||||||
unary_fp(in, out, detail::Tanh());
|
unary_fp(in, out, detail::Tanh());
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -222,7 +222,7 @@ void scan_dispatch(
|
|||||||
}
|
}
|
||||||
case Scan::Min: {
|
case Scan::Min: {
|
||||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||||
auto init = (is_floating_point(input.dtype()))
|
auto init = (issubdtype(input.dtype(), floating))
|
||||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||||
: std::numeric_limits<U>::max();
|
: std::numeric_limits<U>::max();
|
||||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
@ -232,7 +232,7 @@ void scan_dispatch(
|
|||||||
}
|
}
|
||||||
case Scan::Max: {
|
case Scan::Max: {
|
||||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||||
auto init = (is_floating_point(input.dtype()))
|
auto init = (issubdtype(input.dtype(), floating))
|
||||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||||
: std::numeric_limits<U>::max();
|
: std::numeric_limits<U>::max();
|
||||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||||
|
@ -488,7 +488,7 @@ void steel_matmul(
|
|||||||
|
|
||||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
if (!is_floating_point(out.dtype())) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[matmul] Does not yet support non-floating point types.");
|
"[matmul] Does not yet support non-floating point types.");
|
||||||
}
|
}
|
||||||
@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 3);
|
assert(inputs.size() == 3);
|
||||||
if (!is_floating_point(out.dtype())) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[matmul] Does not yet support non-floating point types.");
|
"[matmul] Does not yet support non-floating point types.");
|
||||||
}
|
}
|
||||||
|
@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
if (not is_integral(in.dtype())) {
|
if (issubdtype(in.dtype(), inexact)) {
|
||||||
unary_op(inputs, out, "round");
|
unary_op(inputs, out, "round");
|
||||||
} else {
|
} else {
|
||||||
// No-op integer types
|
// No-op integer types
|
||||||
|
@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out) {
|
array& out) {
|
||||||
assert(inputs.size() >= 3);
|
assert(inputs.size() >= 3);
|
||||||
if (!is_floating_point(out.dtype())) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ namespace mlx::core {
|
|||||||
|
|
||||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
if (!is_floating_point(out.dtype())) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[softmax] Does not support non-floating point types.");
|
"[softmax] Does not support non-floating point types.");
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
@ -12,6 +12,7 @@ namespace mlx::core {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr int num_types = 13;
|
constexpr int num_types = 13;
|
||||||
|
constexpr int num_cats = 8;
|
||||||
|
|
||||||
constexpr Dtype::Kind type_kinds[num_types] = {
|
constexpr Dtype::Kind type_kinds[num_types] = {
|
||||||
Dtype::Kind::b, // bool_,
|
Dtype::Kind::b, // bool_,
|
||||||
@ -49,6 +50,35 @@ constexpr Dtype type_rules[num_types][num_types] = {
|
|||||||
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
|
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
constexpr bool subcategory_to_category[num_cats][num_cats] = {
|
||||||
|
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
|
||||||
|
{true, false, true, false, false, false, true, true}, // complexfloating
|
||||||
|
{false, true, true, false, false, false, true, true}, // floating
|
||||||
|
{false, false, true, false, false, false, true, true}, // inexact
|
||||||
|
{false, false, false, true, false, true, true, true}, // signedinteger
|
||||||
|
{false, false, false, false, true, true, true, true}, // unsignedinteger
|
||||||
|
{false, false, false, false, false, true, true, true}, // integer
|
||||||
|
{false, false, false, false, false, false, true, true}, // number
|
||||||
|
{false, false, false, false, false, false, false, true}, // generic
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr Dtype::Category type_to_category[num_types] = {
|
||||||
|
Dtype::Category::generic, // bool_,
|
||||||
|
Dtype::Category::unsignedinteger, // uint8,
|
||||||
|
Dtype::Category::unsignedinteger, // uint16,
|
||||||
|
Dtype::Category::unsignedinteger, // uint32,
|
||||||
|
Dtype::Category::unsignedinteger, // uint64,
|
||||||
|
Dtype::Category::signedinteger, // int8,
|
||||||
|
Dtype::Category::signedinteger, // int16,
|
||||||
|
Dtype::Category::signedinteger, // int32,
|
||||||
|
Dtype::Category::signedinteger, // int64,
|
||||||
|
Dtype::Category::floating, // float16,
|
||||||
|
Dtype::Category::floating, // float32,
|
||||||
|
Dtype::Category::floating, // bfloat16,
|
||||||
|
Dtype::Category::complexfloating, // complex64,
|
||||||
|
};
|
||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
inline bool is_big_endian() {
|
inline bool is_big_endian() {
|
||||||
@ -141,6 +171,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
|
|||||||
return complex64;
|
return complex64;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool issubdtype(const Dtype& a, const Dtype& b) {
|
||||||
|
return a == b;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
||||||
|
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
||||||
|
return subcategory_to_category[static_cast<uint32_t>(a)]
|
||||||
|
[static_cast<uint32_t>(b)];
|
||||||
|
}
|
||||||
|
|
||||||
// Array protocol typestring for Dtype
|
// Array protocol typestring for Dtype
|
||||||
std::string dtype_to_array_protocol(const Dtype& t) {
|
std::string dtype_to_array_protocol(const Dtype& t) {
|
||||||
std::ostringstream r;
|
std::ostringstream r;
|
||||||
|
46
mlx/dtype.h
46
mlx/dtype.h
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@ -38,6 +38,17 @@ struct Dtype {
|
|||||||
V, /* void - used for brain float */
|
V, /* void - used for brain float */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum class Category {
|
||||||
|
complexfloating,
|
||||||
|
floating,
|
||||||
|
inexact,
|
||||||
|
signedinteger,
|
||||||
|
unsignedinteger,
|
||||||
|
integer,
|
||||||
|
number,
|
||||||
|
generic
|
||||||
|
};
|
||||||
|
|
||||||
Val val;
|
Val val;
|
||||||
const uint8_t size;
|
const uint8_t size;
|
||||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
||||||
@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
|
|||||||
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
|
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
|
||||||
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
|
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
|
||||||
|
|
||||||
|
inline constexpr Dtype::Category complexfloating =
|
||||||
|
Dtype::Category::complexfloating;
|
||||||
|
inline constexpr Dtype::Category floating = Dtype::Category::floating;
|
||||||
|
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
|
||||||
|
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
|
||||||
|
inline constexpr Dtype::Category unsignedinteger =
|
||||||
|
Dtype::Category::unsignedinteger;
|
||||||
|
inline constexpr Dtype::Category integer = Dtype::Category::integer;
|
||||||
|
inline constexpr Dtype::Category number = Dtype::Category::number;
|
||||||
|
inline constexpr Dtype::Category generic = Dtype::Category::generic;
|
||||||
|
|
||||||
|
bool issubdtype(const Dtype& a, const Dtype& b);
|
||||||
|
bool issubdtype(const Dtype::Category& a, const Dtype& b);
|
||||||
|
bool issubdtype(const Dtype& a, const Dtype::Category& b);
|
||||||
|
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
|
||||||
|
|
||||||
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||||
|
|
||||||
inline uint8_t size_of(const Dtype& t) {
|
inline uint8_t size_of(const Dtype& t) {
|
||||||
@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {
|
|||||||
|
|
||||||
Dtype::Kind kindof(const Dtype& t);
|
Dtype::Kind kindof(const Dtype& t);
|
||||||
|
|
||||||
inline bool is_unsigned(const Dtype& t) {
|
|
||||||
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_floating_point(const Dtype& t) {
|
|
||||||
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
|
|
||||||
kindof(t) == Dtype::Kind::c;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_complex(const Dtype& t) {
|
|
||||||
return kindof(t) == Dtype::Kind::c;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_integral(const Dtype& t) {
|
|
||||||
return !(is_floating_point(t));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct TypeToDtype {
|
struct TypeToDtype {
|
||||||
operator Dtype();
|
operator Dtype();
|
||||||
|
@ -64,7 +64,7 @@ array rms_norm(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
auto out_type = result_type(x, weight);
|
auto out_type = result_type(x, weight);
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
@ -128,7 +128,7 @@ array layer_norm(
|
|||||||
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
||||||
: result_type(x, *weight))
|
: result_type(x, *weight))
|
||||||
: x.dtype();
|
: x.dtype();
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[layer_norm] Received unsupported type " << out_type << ".";
|
msg << "[layer_norm] Received unsupported type " << out_type << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
@ -319,7 +319,7 @@ array scaled_dot_product_attention(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto final_type = result_type(queries, keys, values);
|
auto final_type = result_type(queries, keys, values);
|
||||||
if (!is_floating_point(final_type) || is_complex(final_type)) {
|
if (!issubdtype(final_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] Received unsupported type "
|
msg << "[scaled_dot_product_attention] Received unsupported type "
|
||||||
<< final_type << ".";
|
<< final_type << ".";
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline array l2_norm(
|
inline array l2_norm(
|
||||||
@ -19,7 +19,7 @@ inline array l2_norm(
|
|||||||
const std::vector<int>& axis,
|
const std::vector<int>& axis,
|
||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (is_complex(a.dtype())) {
|
if (issubdtype(a.dtype(), complexfloating)) {
|
||||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||||
} else {
|
} else {
|
||||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||||
|
26
mlx/ops.cpp
26
mlx/ops.cpp
@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -1140,7 +1140,7 @@ array array_equal(
|
|||||||
return array(false);
|
return array(false);
|
||||||
} else {
|
} else {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
equal_nan &= is_floating_point(dtype);
|
equal_nan &= issubdtype(dtype, inexact);
|
||||||
return all(
|
return all(
|
||||||
array(
|
array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
@ -1153,7 +1153,7 @@ array array_equal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (is_integral(a.dtype())) {
|
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||||
return full(a.shape(), false, bool_, s);
|
return full(a.shape(), false, bool_, s);
|
||||||
}
|
}
|
||||||
return not_equal(a, a, s);
|
return not_equal(a, a, s);
|
||||||
@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (is_integral(a.dtype())) {
|
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||||
return full(a.shape(), false, bool_, s);
|
return full(a.shape(), false, bool_, s);
|
||||||
}
|
}
|
||||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (is_integral(a.dtype())) {
|
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||||
return full(a.shape(), false, bool_, s);
|
return full(a.shape(), false, bool_, s);
|
||||||
}
|
}
|
||||||
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
@ -1929,7 +1929,7 @@ array floor_divide(
|
|||||||
const array& b,
|
const array& b,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
if (is_floating_point(dtype)) {
|
if (issubdtype(dtype, inexact)) {
|
||||||
return floor(divide(a, b, s), s);
|
return floor(divide(a, b, s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) {
|
|||||||
std::vector<array>
|
std::vector<array>
|
||||||
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
if (is_complex(dtype)) {
|
if (issubdtype(dtype, complexfloating)) {
|
||||||
throw std::invalid_argument("[divmod] Complex type not supported.");
|
throw std::invalid_argument("[divmod] Complex type not supported.");
|
||||||
}
|
}
|
||||||
auto inputs =
|
auto inputs =
|
||||||
@ -2220,7 +2220,7 @@ array matmul(
|
|||||||
}
|
}
|
||||||
// Type promotion
|
// Type promotion
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[matmul] Only real floating point types are supported but "
|
msg << "[matmul] Only real floating point types are supported but "
|
||||||
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
||||||
@ -2330,7 +2330,7 @@ array gather(
|
|||||||
|
|
||||||
// Promote indices to the same type
|
// Promote indices to the same type
|
||||||
auto dtype = result_type(indices);
|
auto dtype = result_type(indices);
|
||||||
if (!is_integral(dtype)) {
|
if (issubdtype(dtype, inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[gather] Got indices with invalid dtype. Indices must be integral.");
|
"[gather] Got indices with invalid dtype. Indices must be integral.");
|
||||||
}
|
}
|
||||||
@ -2521,7 +2521,7 @@ array scatter(
|
|||||||
|
|
||||||
// Promote indices to the same type
|
// Promote indices to the same type
|
||||||
auto dtype = result_type(indices);
|
auto dtype = result_type(indices);
|
||||||
if (!is_integral(dtype)) {
|
if (issubdtype(dtype, inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[scatter] Got indices with invalid dtype. Indices must be integral.");
|
"[scatter] Got indices with invalid dtype. Indices must be integral.");
|
||||||
}
|
}
|
||||||
@ -2834,7 +2834,7 @@ inline std::vector<int> conv_out_shape(
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||||
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
|
if (!issubdtype(in.dtype(), floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[conv] Invalid input array with type " << in.dtype() << "."
|
msg << "[conv] Invalid input array with type " << in.dtype() << "."
|
||||||
<< " Convolution currently only supports floating point types";
|
<< " Convolution currently only supports floating point types";
|
||||||
@ -3062,7 +3062,7 @@ array quantized_matmul(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto dtype = result_type(x, scales, biases);
|
auto dtype = result_type(x, scales, biases);
|
||||||
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
if (!issubdtype(dtype, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||||
<< "the passed types where x.dtype() == " << x.dtype()
|
<< "the passed types where x.dtype() == " << x.dtype()
|
||||||
@ -3364,7 +3364,7 @@ array addmm(
|
|||||||
|
|
||||||
// Type promotion
|
// Type promotion
|
||||||
auto out_type = result_type(a, b, c);
|
auto out_type = result_type(a, b, c);
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[addmm] Only real floating point types are supported but "
|
msg << "[addmm] Only real floating point types are supported but "
|
||||||
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()
|
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()
|
||||||
|
@ -97,7 +97,7 @@ array uniform(
|
|||||||
Dtype dtype /* = float32 */,
|
Dtype dtype /* = float32 */,
|
||||||
const std::optional<array>& key /*= nullopt */,
|
const std::optional<array>& key /*= nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!is_floating_point(dtype) && !is_complex(dtype)) {
|
if (!issubdtype(dtype, floating)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Can only generate uniform numbers with real floating point type.");
|
"Can only generate uniform numbers with real floating point type.");
|
||||||
}
|
}
|
||||||
@ -179,7 +179,7 @@ array randint(
|
|||||||
Dtype dtype /* = int32 */,
|
Dtype dtype /* = int32 */,
|
||||||
const std::optional<array>& key /*= nullopt */,
|
const std::optional<array>& key /*= nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!is_integral(dtype)) {
|
if (issubdtype(dtype, inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[randint] randint only accepts integer dtypes and bool.");
|
"[randint] randint only accepts integer dtypes and bool.");
|
||||||
}
|
}
|
||||||
@ -192,7 +192,7 @@ array bernoulli(
|
|||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
const std::optional<array>& key /*= nullopt */,
|
const std::optional<array>& key /*= nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!is_floating_point(p.dtype())) {
|
if (!issubdtype(p.dtype(), floating)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[bernoulli] bernoulli probability `p` must be a float type.");
|
"[bernoulli] bernoulli probability `p` must be a float type.");
|
||||||
}
|
}
|
||||||
@ -228,7 +228,7 @@ array truncated_normal(
|
|||||||
// Same as
|
// Same as
|
||||||
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
|
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
|
||||||
|
|
||||||
if (!is_floating_point(dtype)) {
|
if (!issubdtype(dtype, floating)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[trunc_normal] trunc_normal only accepts floating point dtypes.");
|
"[trunc_normal] trunc_normal only accepts floating point dtypes.");
|
||||||
}
|
}
|
||||||
|
@ -578,3 +578,26 @@ class Module(dict):
|
|||||||
See :func:`train`.
|
See :func:`train`.
|
||||||
"""
|
"""
|
||||||
self.train(False)
|
self.train(False)
|
||||||
|
|
||||||
|
def set_dtype(
|
||||||
|
self,
|
||||||
|
dtype: mx.Dtype,
|
||||||
|
predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype(
|
||||||
|
x, mx.floating
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Set the dtype of the module's parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype (Dtype): The new dtype.
|
||||||
|
predicate (typing.Callable, optional): A predicate to select
|
||||||
|
parameters to cast. By default, only parameters of type
|
||||||
|
:attr:`floating` will be updated to avoid casting integer
|
||||||
|
parameters to the new dtype.
|
||||||
|
"""
|
||||||
|
if predicate is None:
|
||||||
|
|
||||||
|
def predicate(_):
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||||
|
@ -254,7 +254,7 @@ array array_from_list(
|
|||||||
std::vector<uint32_t> vals;
|
std::vector<uint32_t> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return array(vals.begin(), shape, dtype);
|
||||||
} else if (is_floating_point(dtype)) {
|
} else if (issubdtype(dtype, inexact)) {
|
||||||
std::vector<float> vals;
|
std::vector<float> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return array(vals.begin(), shape, dtype);
|
||||||
@ -439,6 +439,54 @@ void init_array(nb::module_& m) {
|
|||||||
m.attr("float32") = nb::cast(float32);
|
m.attr("float32") = nb::cast(float32);
|
||||||
m.attr("bfloat16") = nb::cast(bfloat16);
|
m.attr("bfloat16") = nb::cast(bfloat16);
|
||||||
m.attr("complex64") = nb::cast(complex64);
|
m.attr("complex64") = nb::cast(complex64);
|
||||||
|
nb::class_<Dtype::Category>(
|
||||||
|
m,
|
||||||
|
"DtypeCategory",
|
||||||
|
R"pbdoc(
|
||||||
|
Type to hold categories of :class:`dtypes <Dtype>`.
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.generic`
|
||||||
|
|
||||||
|
* :ref:`bool_ <data_types>`
|
||||||
|
* :attr:`~mlx.core.number`
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.integer`
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.unsignedinteger`
|
||||||
|
|
||||||
|
* :ref:`uint8 <data_types>`
|
||||||
|
* :ref:`uint16 <data_types>`
|
||||||
|
* :ref:`uint32 <data_types>`
|
||||||
|
* :ref:`uint64 <data_types>`
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.signedinteger`
|
||||||
|
|
||||||
|
* :ref:`int8 <data_types>`
|
||||||
|
* :ref:`int32 <data_types>`
|
||||||
|
* :ref:`int64 <data_types>`
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.inexact`
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.floating`
|
||||||
|
|
||||||
|
* :ref:`float16 <data_types>`
|
||||||
|
* :ref:`bfloat16 <data_types>`
|
||||||
|
* :ref:`float32 <data_types>`
|
||||||
|
|
||||||
|
* :attr:`~mlx.core.complexfloating`
|
||||||
|
|
||||||
|
* :ref:`complex128 <data_types>`
|
||||||
|
|
||||||
|
See also :func:`~mlx.core.issubdtype`.
|
||||||
|
)pbdoc");
|
||||||
|
m.attr("complexfloating") = nb::cast(complexfloating);
|
||||||
|
m.attr("floating") = nb::cast(floating);
|
||||||
|
m.attr("inexact") = nb::cast(inexact);
|
||||||
|
m.attr("signedinteger") = nb::cast(signedinteger);
|
||||||
|
m.attr("unsignedinteger") = nb::cast(unsignedinteger);
|
||||||
|
m.attr("integer") = nb::cast(integer);
|
||||||
|
m.attr("number") = nb::cast(number);
|
||||||
|
m.attr("generic") = nb::cast(generic);
|
||||||
|
|
||||||
nb::class_<ArrayAt>(
|
nb::class_<ArrayAt>(
|
||||||
m,
|
m,
|
||||||
@ -700,7 +748,7 @@ void init_array(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"__itruediv__",
|
"__itruediv__",
|
||||||
[](array& a, const ScalarOrArray v) -> array& {
|
[](array& a, const ScalarOrArray v) -> array& {
|
||||||
if (!is_floating_point(a.dtype())) {
|
if (!issubdtype(a.dtype(), inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"In place division cannot cast to non-floating point type.");
|
"In place division cannot cast to non-floating point type.");
|
||||||
}
|
}
|
||||||
@ -852,7 +900,7 @@ void init_array(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"__invert__",
|
"__invert__",
|
||||||
[](const array& a) {
|
[](const array& a) {
|
||||||
if (is_floating_point(a.dtype())) {
|
if (issubdtype(a.dtype(), inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Floating point types not allowed with or bitwise inversion.");
|
"Floating point types not allowed with or bitwise inversion.");
|
||||||
}
|
}
|
||||||
@ -866,7 +914,8 @@ void init_array(nb::module_& m) {
|
|||||||
"__and__",
|
"__and__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
auto b = to_array(v, a.dtype());
|
auto b = to_array(v, a.dtype());
|
||||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
if (issubdtype(a.dtype(), inexact) ||
|
||||||
|
issubdtype(b.dtype(), inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Floating point types not allowed with bitwise and.");
|
"Floating point types not allowed with bitwise and.");
|
||||||
}
|
}
|
||||||
@ -881,7 +930,8 @@ void init_array(nb::module_& m) {
|
|||||||
"__iand__",
|
"__iand__",
|
||||||
[](array& a, const ScalarOrArray v) -> array& {
|
[](array& a, const ScalarOrArray v) -> array& {
|
||||||
auto b = to_array(v, a.dtype());
|
auto b = to_array(v, a.dtype());
|
||||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
if (issubdtype(a.dtype(), inexact) ||
|
||||||
|
issubdtype(b.dtype(), inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Floating point types not allowed with bitwise and.");
|
"Floating point types not allowed with bitwise and.");
|
||||||
}
|
}
|
||||||
@ -898,7 +948,8 @@ void init_array(nb::module_& m) {
|
|||||||
"__or__",
|
"__or__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
auto b = to_array(v, a.dtype());
|
auto b = to_array(v, a.dtype());
|
||||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
if (issubdtype(a.dtype(), inexact) ||
|
||||||
|
issubdtype(b.dtype(), inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Floating point types not allowed with or bitwise or.");
|
"Floating point types not allowed with or bitwise or.");
|
||||||
}
|
}
|
||||||
@ -913,7 +964,8 @@ void init_array(nb::module_& m) {
|
|||||||
"__ior__",
|
"__ior__",
|
||||||
[](array& a, const ScalarOrArray v) -> array& {
|
[](array& a, const ScalarOrArray v) -> array& {
|
||||||
auto b = to_array(v, a.dtype());
|
auto b = to_array(v, a.dtype());
|
||||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
if (issubdtype(a.dtype(), inexact) ||
|
||||||
|
issubdtype(b.dtype(), inexact)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Floating point types not allowed with or bitwise or.");
|
"Floating point types not allowed with or bitwise or.");
|
||||||
}
|
}
|
||||||
|
@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array or list(array): An array or list of arrays with at least three dimensions.
|
array or list(array): An array or list of arrays with at least three dimensions.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"issubdtype",
|
||||||
|
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
||||||
|
""_a,
|
||||||
|
""_a,
|
||||||
|
R"pbdoc(
|
||||||
|
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||||
|
of another.
|
||||||
|
|
||||||
|
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||||
|
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||||
|
True
|
||||||
|
>>> mx.issubdtype(ints.dtype, mx.floating)
|
||||||
|
False
|
||||||
|
|
||||||
|
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
||||||
|
>>> mx.issubdtype(floats.dtype, mx.integer)
|
||||||
|
False
|
||||||
|
>>> mx.issubdtype(floats.dtype, mx.floating)
|
||||||
|
True
|
||||||
|
|
||||||
|
Similar types of different sizes are not subdtypes of each other:
|
||||||
|
|
||||||
|
>>> mx.issubdtype(mx.float64, mx.float32)
|
||||||
|
False
|
||||||
|
>>> mx.issubdtype(mx.float32, mx.float64)
|
||||||
|
False
|
||||||
|
|
||||||
|
but both are subtypes of `floating`:
|
||||||
|
|
||||||
|
>>> mx.issubdtype(mx.float64, mx.floating)
|
||||||
|
True
|
||||||
|
>>> mx.issubdtype(mx.float32, mx.floating)
|
||||||
|
True
|
||||||
|
|
||||||
|
For convenience, dtype-like objects are allowed too:
|
||||||
|
|
||||||
|
>>> mx.issubdtype(mx.float32, mx.inexact)
|
||||||
|
True
|
||||||
|
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||||
|
False
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"issubdtype",
|
||||||
|
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
||||||
|
""_a,
|
||||||
|
""_a);
|
||||||
|
m.def(
|
||||||
|
"issubdtype",
|
||||||
|
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
||||||
|
""_a,
|
||||||
|
""_a);
|
||||||
|
m.def(
|
||||||
|
"issubdtype",
|
||||||
|
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
||||||
|
&issubdtype),
|
||||||
|
""_a,
|
||||||
|
""_a);
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ inline array to_array(
|
|||||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||||
auto out_t = dtype.value_or(float32);
|
auto out_t = dtype.value_or(float32);
|
||||||
return array(
|
return array(
|
||||||
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
|
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
return array(static_cast<complex64_t>(*pv), complex64);
|
return array(static_cast<complex64_t>(*pv), complex64);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_set_dtype(self):
|
||||||
|
def assert_dtype(layer, dtype):
|
||||||
|
for k, v in tree_flatten(layer.parameters()):
|
||||||
|
self.assertEqual(v.dtype, dtype, f"dtype mismatch for {k}")
|
||||||
|
|
||||||
|
layer = nn.Linear(input_dims=4, output_dims=8, bias=True)
|
||||||
|
assert_dtype(layer, mx.float32)
|
||||||
|
|
||||||
|
layer.set_dtype(mx.bfloat16)
|
||||||
|
assert_dtype(layer, mx.bfloat16)
|
||||||
|
|
||||||
|
layer.set_dtype(mx.float32, lambda x: False)
|
||||||
|
assert_dtype(layer, mx.bfloat16)
|
||||||
|
|
||||||
|
layer.set_dtype(mx.int32, lambda x: True)
|
||||||
|
assert_dtype(layer, mx.int32)
|
||||||
|
|
||||||
|
layer.set_dtype(mx.int64, predicate=None)
|
||||||
|
assert_dtype(layer, mx.int64)
|
||||||
|
|
||||||
|
layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer))
|
||||||
|
assert_dtype(layer, mx.int16)
|
||||||
|
|
||||||
def test_rnn(self):
|
def test_rnn(self):
|
||||||
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
|
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
|
||||||
inp = mx.random.normal((2, 25, 5))
|
inp = mx.random.normal((2, 25, 5))
|
||||||
|
@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||||
|
|
||||||
|
def test_issubdtype(self):
|
||||||
|
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||||
|
|
||||||
|
cats = [
|
||||||
|
"complexfloating",
|
||||||
|
"floating",
|
||||||
|
"inexact",
|
||||||
|
"signedinteger",
|
||||||
|
"unsignedinteger",
|
||||||
|
"integer",
|
||||||
|
"number",
|
||||||
|
"generic",
|
||||||
|
"bool_",
|
||||||
|
"uint8",
|
||||||
|
"uint16",
|
||||||
|
"uint32",
|
||||||
|
"uint64",
|
||||||
|
"int8",
|
||||||
|
"int16",
|
||||||
|
"int32",
|
||||||
|
"int64",
|
||||||
|
"float16",
|
||||||
|
"float32",
|
||||||
|
"complex64",
|
||||||
|
]
|
||||||
|
|
||||||
|
for a in cats:
|
||||||
|
for b in cats:
|
||||||
|
self.assertEqual(
|
||||||
|
mx.issubdtype(getattr(mx, a), getattr(mx, b)),
|
||||||
|
np.issubdtype(getattr(np, a), getattr(np, b)),
|
||||||
|
f"mx and np don't aggree on {a}, {b}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2848,6 +2848,192 @@ TEST_CASE("test diag") {
|
|||||||
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test issubdtype") {
|
||||||
|
const auto cats = {
|
||||||
|
complexfloating,
|
||||||
|
floating,
|
||||||
|
inexact,
|
||||||
|
signedinteger,
|
||||||
|
unsignedinteger,
|
||||||
|
integer,
|
||||||
|
number,
|
||||||
|
generic};
|
||||||
|
const auto types = {
|
||||||
|
bool_,
|
||||||
|
uint8,
|
||||||
|
uint16,
|
||||||
|
uint32,
|
||||||
|
uint64,
|
||||||
|
int8,
|
||||||
|
int16,
|
||||||
|
int32,
|
||||||
|
int64,
|
||||||
|
float16,
|
||||||
|
float32,
|
||||||
|
bfloat16,
|
||||||
|
complex64};
|
||||||
|
for (const auto& type : types) {
|
||||||
|
CHECK(issubdtype(type, type));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
switch (kindof(type)) {
|
||||||
|
case Dtype::Kind::b:
|
||||||
|
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(type, floating));
|
||||||
|
CHECK_FALSE(issubdtype(type, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, integer));
|
||||||
|
CHECK_FALSE(issubdtype(type, number));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Kind::u:
|
||||||
|
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(type, floating));
|
||||||
|
CHECK_FALSE(issubdtype(type, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||||
|
CHECK(issubdtype(type, unsignedinteger));
|
||||||
|
CHECK(issubdtype(type, integer));
|
||||||
|
CHECK(issubdtype(type, number));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Kind::i:
|
||||||
|
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(type, floating));
|
||||||
|
CHECK_FALSE(issubdtype(type, inexact));
|
||||||
|
CHECK(issubdtype(type, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||||
|
CHECK(issubdtype(type, integer));
|
||||||
|
CHECK(issubdtype(type, number));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Kind::f:
|
||||||
|
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||||
|
CHECK(issubdtype(type, floating));
|
||||||
|
CHECK(issubdtype(type, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, integer));
|
||||||
|
CHECK(issubdtype(type, number));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Kind::c:
|
||||||
|
CHECK(issubdtype(type, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(type, floating));
|
||||||
|
CHECK(issubdtype(type, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, integer));
|
||||||
|
CHECK(issubdtype(type, number));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Kind::V:
|
||||||
|
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||||
|
CHECK(issubdtype(type, floating));
|
||||||
|
CHECK(issubdtype(type, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(type, integer));
|
||||||
|
CHECK(issubdtype(type, number));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& type : types) {
|
||||||
|
CHECK(issubdtype(type, type));
|
||||||
|
CHECK(issubdtype(type, generic));
|
||||||
|
for (auto type1 : types) {
|
||||||
|
CHECK_EQ(issubdtype(type, type1), type == type1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& cat : cats) {
|
||||||
|
CHECK(issubdtype(cat, cat));
|
||||||
|
switch (cat) {
|
||||||
|
case Dtype::Category::complexfloating:
|
||||||
|
CHECK(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::floating:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK(issubdtype(cat, floating));
|
||||||
|
CHECK(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::inexact:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::signedinteger:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, inexact));
|
||||||
|
CHECK(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::unsignedinteger:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::integer:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::number:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, integer));
|
||||||
|
CHECK(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
case Dtype::Category::generic:
|
||||||
|
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, floating));
|
||||||
|
CHECK_FALSE(issubdtype(cat, inexact));
|
||||||
|
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||||
|
CHECK_FALSE(issubdtype(cat, integer));
|
||||||
|
CHECK_FALSE(issubdtype(cat, number));
|
||||||
|
CHECK(issubdtype(cat, generic));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test atleast_1d") {
|
TEST_CASE("test atleast_1d") {
|
||||||
auto x = array(1);
|
auto x = array(1);
|
||||||
auto out = atleast_1d(x);
|
auto out = atleast_1d(x);
|
||||||
|
Loading…
Reference in New Issue
Block a user