mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285.
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
bfb5bad4f0
commit
479051ce1c
@ -58,6 +58,7 @@ are the CPU and GPU.
|
||||
:maxdepth: 1
|
||||
|
||||
python/array
|
||||
python/data_types
|
||||
python/devices_and_streams
|
||||
python/ops
|
||||
python/random
|
||||
|
@ -19,7 +19,6 @@ Array
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
Dtype
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
@ -32,7 +31,6 @@ Array
|
||||
array.cumsum
|
||||
array.diag
|
||||
array.diagonal
|
||||
array.dtype
|
||||
array.exp
|
||||
array.flatten
|
||||
array.log
|
||||
|
@ -1,7 +1,5 @@
|
||||
.. _data_types:
|
||||
|
||||
:orphan:
|
||||
|
||||
Data Types
|
||||
==========
|
||||
|
||||
@ -56,3 +54,15 @@ The default floating point type is ``float32`` and the default integer type is
|
||||
* - ``complex64``
|
||||
- 8
|
||||
- 64-bit complex float
|
||||
|
||||
|
||||
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||
``dtype`` (or category) is a subtype of another category.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Dtype
|
||||
DtypeCategory
|
||||
issubdtype
|
||||
|
@ -30,6 +30,7 @@ Module
|
||||
Module.named_modules
|
||||
Module.parameters
|
||||
Module.save_weights
|
||||
Module.set_dtype
|
||||
Module.train
|
||||
Module.trainable_parameters
|
||||
Module.unfreeze
|
||||
|
@ -62,10 +62,10 @@ Operations
|
||||
identity
|
||||
inner
|
||||
isclose
|
||||
isnan
|
||||
isposinf
|
||||
isneginf
|
||||
isinf
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
|
@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -355,7 +355,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
}
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||
|
@ -22,7 +22,7 @@ namespace mlx::core {
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (is_unsigned(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), unsignedinteger)) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Exp());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -362,7 +362,7 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@ -388,7 +388,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
@ -410,7 +410,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -597,7 +597,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@ -608,7 +608,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -630,7 +630,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -642,7 +642,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -850,7 +850,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -862,7 +862,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@ -222,7 +222,7 @@ void scan_dispatch(
|
||||
}
|
||||
case Scan::Min: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
@ -232,7 +232,7 @@ void scan_dispatch(
|
||||
}
|
||||
case Scan::Max: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
|
@ -488,7 +488,7 @@ void steel_matmul(
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
|
@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() >= 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ namespace mlx::core {
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[softmax] Does not support non-floating point types.");
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
@ -12,6 +12,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
constexpr int num_types = 13;
|
||||
constexpr int num_cats = 8;
|
||||
|
||||
constexpr Dtype::Kind type_kinds[num_types] = {
|
||||
Dtype::Kind::b, // bool_,
|
||||
@ -49,6 +50,35 @@ constexpr Dtype type_rules[num_types][num_types] = {
|
||||
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
|
||||
};
|
||||
|
||||
|
||||
constexpr bool subcategory_to_category[num_cats][num_cats] = {
|
||||
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
|
||||
{true, false, true, false, false, false, true, true}, // complexfloating
|
||||
{false, true, true, false, false, false, true, true}, // floating
|
||||
{false, false, true, false, false, false, true, true}, // inexact
|
||||
{false, false, false, true, false, true, true, true}, // signedinteger
|
||||
{false, false, false, false, true, true, true, true}, // unsignedinteger
|
||||
{false, false, false, false, false, true, true, true}, // integer
|
||||
{false, false, false, false, false, false, true, true}, // number
|
||||
{false, false, false, false, false, false, false, true}, // generic
|
||||
};
|
||||
|
||||
constexpr Dtype::Category type_to_category[num_types] = {
|
||||
Dtype::Category::generic, // bool_,
|
||||
Dtype::Category::unsignedinteger, // uint8,
|
||||
Dtype::Category::unsignedinteger, // uint16,
|
||||
Dtype::Category::unsignedinteger, // uint32,
|
||||
Dtype::Category::unsignedinteger, // uint64,
|
||||
Dtype::Category::signedinteger, // int8,
|
||||
Dtype::Category::signedinteger, // int16,
|
||||
Dtype::Category::signedinteger, // int32,
|
||||
Dtype::Category::signedinteger, // int64,
|
||||
Dtype::Category::floating, // float16,
|
||||
Dtype::Category::floating, // float32,
|
||||
Dtype::Category::floating, // bfloat16,
|
||||
Dtype::Category::complexfloating, // complex64,
|
||||
};
|
||||
|
||||
// clang-format on
|
||||
|
||||
inline bool is_big_endian() {
|
||||
@ -141,6 +171,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
|
||||
return complex64;
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype& a, const Dtype& b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
|
||||
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
|
||||
}
|
||||
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
|
||||
return subcategory_to_category[static_cast<uint32_t>(a)]
|
||||
[static_cast<uint32_t>(b)];
|
||||
}
|
||||
|
||||
// Array protocol typestring for Dtype
|
||||
std::string dtype_to_array_protocol(const Dtype& t) {
|
||||
std::ostringstream r;
|
||||
|
46
mlx/dtype.h
46
mlx/dtype.h
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@ -38,6 +38,17 @@ struct Dtype {
|
||||
V, /* void - used for brain float */
|
||||
};
|
||||
|
||||
enum class Category {
|
||||
complexfloating,
|
||||
floating,
|
||||
inexact,
|
||||
signedinteger,
|
||||
unsignedinteger,
|
||||
integer,
|
||||
number,
|
||||
generic
|
||||
};
|
||||
|
||||
Val val;
|
||||
const uint8_t size;
|
||||
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
||||
@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
|
||||
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
|
||||
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
|
||||
|
||||
inline constexpr Dtype::Category complexfloating =
|
||||
Dtype::Category::complexfloating;
|
||||
inline constexpr Dtype::Category floating = Dtype::Category::floating;
|
||||
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
|
||||
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
|
||||
inline constexpr Dtype::Category unsignedinteger =
|
||||
Dtype::Category::unsignedinteger;
|
||||
inline constexpr Dtype::Category integer = Dtype::Category::integer;
|
||||
inline constexpr Dtype::Category number = Dtype::Category::number;
|
||||
inline constexpr Dtype::Category generic = Dtype::Category::generic;
|
||||
|
||||
bool issubdtype(const Dtype& a, const Dtype& b);
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype& b);
|
||||
bool issubdtype(const Dtype& a, const Dtype::Category& b);
|
||||
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
|
||||
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
||||
|
||||
inline uint8_t size_of(const Dtype& t) {
|
||||
@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {
|
||||
|
||||
Dtype::Kind kindof(const Dtype& t);
|
||||
|
||||
inline bool is_unsigned(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
|
||||
}
|
||||
|
||||
inline bool is_floating_point(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
|
||||
kindof(t) == Dtype::Kind::c;
|
||||
}
|
||||
|
||||
inline bool is_complex(const Dtype& t) {
|
||||
return kindof(t) == Dtype::Kind::c;
|
||||
}
|
||||
|
||||
inline bool is_integral(const Dtype& t) {
|
||||
return !(is_floating_point(t));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TypeToDtype {
|
||||
operator Dtype();
|
||||
|
@ -64,7 +64,7 @@ array rms_norm(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto out_type = result_type(x, weight);
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@ -128,7 +128,7 @@ array layer_norm(
|
||||
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
||||
: result_type(x, *weight))
|
||||
: x.dtype();
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[layer_norm] Received unsupported type " << out_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@ -319,7 +319,7 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
|
||||
auto final_type = result_type(queries, keys, values);
|
||||
if (!is_floating_point(final_type) || is_complex(final_type)) {
|
||||
if (!issubdtype(final_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Received unsupported type "
|
||||
<< final_type << ".";
|
||||
|
@ -11,7 +11,7 @@
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
inline array l2_norm(
|
||||
@ -19,7 +19,7 @@ inline array l2_norm(
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (is_complex(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), complexfloating)) {
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||
} else {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
|
26
mlx/ops.cpp
26
mlx/ops.cpp
@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -1140,7 +1140,7 @@ array array_equal(
|
||||
return array(false);
|
||||
} else {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
equal_nan &= is_floating_point(dtype);
|
||||
equal_nan &= issubdtype(dtype, inexact);
|
||||
return all(
|
||||
array(
|
||||
a.shape(),
|
||||
@ -1153,7 +1153,7 @@ array array_equal(
|
||||
}
|
||||
|
||||
array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return not_equal(a, a, s);
|
||||
@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
@ -1929,7 +1929,7 @@ array floor_divide(
|
||||
const array& b,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_floating_point(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
return floor(divide(a, b, s), s);
|
||||
}
|
||||
|
||||
@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) {
|
||||
std::vector<array>
|
||||
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||
if (is_complex(dtype)) {
|
||||
if (issubdtype(dtype, complexfloating)) {
|
||||
throw std::invalid_argument("[divmod] Complex type not supported.");
|
||||
}
|
||||
auto inputs =
|
||||
@ -2220,7 +2220,7 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
||||
@ -2330,7 +2330,7 @@ array gather(
|
||||
|
||||
// Promote indices to the same type
|
||||
auto dtype = result_type(indices);
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather] Got indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
@ -2521,7 +2521,7 @@ array scatter(
|
||||
|
||||
// Promote indices to the same type
|
||||
auto dtype = result_type(indices);
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[scatter] Got indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
@ -2834,7 +2834,7 @@ inline std::vector<int> conv_out_shape(
|
||||
}
|
||||
|
||||
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
|
||||
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
|
||||
if (!issubdtype(in.dtype(), floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[conv] Invalid input array with type " << in.dtype() << "."
|
||||
<< " Convolution currently only supports floating point types";
|
||||
@ -3062,7 +3062,7 @@ array quantized_matmul(
|
||||
}
|
||||
|
||||
auto dtype = result_type(x, scales, biases);
|
||||
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||
<< "the passed types where x.dtype() == " << x.dtype()
|
||||
@ -3364,7 +3364,7 @@ array addmm(
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b, c);
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Only real floating point types are supported but "
|
||||
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()
|
||||
|
@ -97,7 +97,7 @@ array uniform(
|
||||
Dtype dtype /* = float32 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!is_floating_point(dtype) && !is_complex(dtype)) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
throw std::invalid_argument(
|
||||
"Can only generate uniform numbers with real floating point type.");
|
||||
}
|
||||
@ -179,7 +179,7 @@ array randint(
|
||||
Dtype dtype /* = int32 */,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!is_integral(dtype)) {
|
||||
if (issubdtype(dtype, inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"[randint] randint only accepts integer dtypes and bool.");
|
||||
}
|
||||
@ -192,7 +192,7 @@ array bernoulli(
|
||||
const std::vector<int>& shape,
|
||||
const std::optional<array>& key /*= nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!is_floating_point(p.dtype())) {
|
||||
if (!issubdtype(p.dtype(), floating)) {
|
||||
throw std::invalid_argument(
|
||||
"[bernoulli] bernoulli probability `p` must be a float type.");
|
||||
}
|
||||
@ -228,7 +228,7 @@ array truncated_normal(
|
||||
// Same as
|
||||
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
|
||||
|
||||
if (!is_floating_point(dtype)) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
throw std::invalid_argument(
|
||||
"[trunc_normal] trunc_normal only accepts floating point dtypes.");
|
||||
}
|
||||
|
@ -578,3 +578,26 @@ class Module(dict):
|
||||
See :func:`train`.
|
||||
"""
|
||||
self.train(False)
|
||||
|
||||
def set_dtype(
|
||||
self,
|
||||
dtype: mx.Dtype,
|
||||
predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype(
|
||||
x, mx.floating
|
||||
),
|
||||
):
|
||||
"""Set the dtype of the module's parameters.
|
||||
|
||||
Args:
|
||||
dtype (Dtype): The new dtype.
|
||||
predicate (typing.Callable, optional): A predicate to select
|
||||
parameters to cast. By default, only parameters of type
|
||||
:attr:`floating` will be updated to avoid casting integer
|
||||
parameters to the new dtype.
|
||||
"""
|
||||
if predicate is None:
|
||||
|
||||
def predicate(_):
|
||||
return True
|
||||
|
||||
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
|
||||
|
@ -254,7 +254,7 @@ array array_from_list(
|
||||
std::vector<uint32_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (is_floating_point(dtype)) {
|
||||
} else if (issubdtype(dtype, inexact)) {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
@ -439,6 +439,54 @@ void init_array(nb::module_& m) {
|
||||
m.attr("float32") = nb::cast(float32);
|
||||
m.attr("bfloat16") = nb::cast(bfloat16);
|
||||
m.attr("complex64") = nb::cast(complex64);
|
||||
nb::class_<Dtype::Category>(
|
||||
m,
|
||||
"DtypeCategory",
|
||||
R"pbdoc(
|
||||
Type to hold categories of :class:`dtypes <Dtype>`.
|
||||
|
||||
* :attr:`~mlx.core.generic`
|
||||
|
||||
* :ref:`bool_ <data_types>`
|
||||
* :attr:`~mlx.core.number`
|
||||
|
||||
* :attr:`~mlx.core.integer`
|
||||
|
||||
* :attr:`~mlx.core.unsignedinteger`
|
||||
|
||||
* :ref:`uint8 <data_types>`
|
||||
* :ref:`uint16 <data_types>`
|
||||
* :ref:`uint32 <data_types>`
|
||||
* :ref:`uint64 <data_types>`
|
||||
|
||||
* :attr:`~mlx.core.signedinteger`
|
||||
|
||||
* :ref:`int8 <data_types>`
|
||||
* :ref:`int32 <data_types>`
|
||||
* :ref:`int64 <data_types>`
|
||||
|
||||
* :attr:`~mlx.core.inexact`
|
||||
|
||||
* :attr:`~mlx.core.floating`
|
||||
|
||||
* :ref:`float16 <data_types>`
|
||||
* :ref:`bfloat16 <data_types>`
|
||||
* :ref:`float32 <data_types>`
|
||||
|
||||
* :attr:`~mlx.core.complexfloating`
|
||||
|
||||
* :ref:`complex128 <data_types>`
|
||||
|
||||
See also :func:`~mlx.core.issubdtype`.
|
||||
)pbdoc");
|
||||
m.attr("complexfloating") = nb::cast(complexfloating);
|
||||
m.attr("floating") = nb::cast(floating);
|
||||
m.attr("inexact") = nb::cast(inexact);
|
||||
m.attr("signedinteger") = nb::cast(signedinteger);
|
||||
m.attr("unsignedinteger") = nb::cast(unsignedinteger);
|
||||
m.attr("integer") = nb::cast(integer);
|
||||
m.attr("number") = nb::cast(number);
|
||||
m.attr("generic") = nb::cast(generic);
|
||||
|
||||
nb::class_<ArrayAt>(
|
||||
m,
|
||||
@ -700,7 +748,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__itruediv__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_floating_point(a.dtype())) {
|
||||
if (!issubdtype(a.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"In place division cannot cast to non-floating point type.");
|
||||
}
|
||||
@ -852,7 +900,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__invert__",
|
||||
[](const array& a) {
|
||||
if (is_floating_point(a.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise inversion.");
|
||||
}
|
||||
@ -866,7 +914,8 @@ void init_array(nb::module_& m) {
|
||||
"__and__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
@ -881,7 +930,8 @@ void init_array(nb::module_& m) {
|
||||
"__iand__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
@ -898,7 +948,8 @@ void init_array(nb::module_& m) {
|
||||
"__or__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
@ -913,7 +964,8 @@ void init_array(nb::module_& m) {
|
||||
"__ior__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
|
@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array or list(array): An array or list of arrays with at least three dimensions.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a,
|
||||
R"pbdoc(
|
||||
Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
|
||||
of another.
|
||||
|
||||
>>> ints = mx.array([1, 2, 3], dtype=mx.int32)
|
||||
>>> mx.issubdtype(ints.dtype, mx.integer)
|
||||
True
|
||||
>>> mx.issubdtype(ints.dtype, mx.floating)
|
||||
False
|
||||
|
||||
>>> floats = mx.array([1, 2, 3], dtype=mx.float32)
|
||||
>>> mx.issubdtype(floats.dtype, mx.integer)
|
||||
False
|
||||
>>> mx.issubdtype(floats.dtype, mx.floating)
|
||||
True
|
||||
|
||||
Similar types of different sizes are not subdtypes of each other:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.float32)
|
||||
False
|
||||
>>> mx.issubdtype(mx.float32, mx.float64)
|
||||
False
|
||||
|
||||
but both are subtypes of `floating`:
|
||||
|
||||
>>> mx.issubdtype(mx.float64, mx.floating)
|
||||
True
|
||||
>>> mx.issubdtype(mx.float32, mx.floating)
|
||||
True
|
||||
|
||||
For convenience, dtype-like objects are allowed too:
|
||||
|
||||
>>> mx.issubdtype(mx.float32, mx.inexact)
|
||||
True
|
||||
>>> mx.issubdtype(mx.signedinteger, mx.floating)
|
||||
False
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"issubdtype",
|
||||
nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ inline array to_array(
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(float32);
|
||||
return array(
|
||||
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
|
||||
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), complex64);
|
||||
} else {
|
||||
|
@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
|
||||
)
|
||||
|
||||
def test_set_dtype(self):
|
||||
def assert_dtype(layer, dtype):
|
||||
for k, v in tree_flatten(layer.parameters()):
|
||||
self.assertEqual(v.dtype, dtype, f"dtype mismatch for {k}")
|
||||
|
||||
layer = nn.Linear(input_dims=4, output_dims=8, bias=True)
|
||||
assert_dtype(layer, mx.float32)
|
||||
|
||||
layer.set_dtype(mx.bfloat16)
|
||||
assert_dtype(layer, mx.bfloat16)
|
||||
|
||||
layer.set_dtype(mx.float32, lambda x: False)
|
||||
assert_dtype(layer, mx.bfloat16)
|
||||
|
||||
layer.set_dtype(mx.int32, lambda x: True)
|
||||
assert_dtype(layer, mx.int32)
|
||||
|
||||
layer.set_dtype(mx.int64, predicate=None)
|
||||
assert_dtype(layer, mx.int64)
|
||||
|
||||
layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer))
|
||||
assert_dtype(layer, mx.int16)
|
||||
|
||||
def test_rnn(self):
|
||||
layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
|
||||
inp = mx.random.normal((2, 25, 5))
|
||||
|
@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
|
||||
cats = [
|
||||
"complexfloating",
|
||||
"floating",
|
||||
"inexact",
|
||||
"signedinteger",
|
||||
"unsignedinteger",
|
||||
"integer",
|
||||
"number",
|
||||
"generic",
|
||||
"bool_",
|
||||
"uint8",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"float16",
|
||||
"float32",
|
||||
"complex64",
|
||||
]
|
||||
|
||||
for a in cats:
|
||||
for b in cats:
|
||||
self.assertEqual(
|
||||
mx.issubdtype(getattr(mx, a), getattr(mx, b)),
|
||||
np.issubdtype(getattr(np, a), getattr(np, b)),
|
||||
f"mx and np don't aggree on {a}, {b}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -2848,6 +2848,192 @@ TEST_CASE("test diag") {
|
||||
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test issubdtype") {
|
||||
const auto cats = {
|
||||
complexfloating,
|
||||
floating,
|
||||
inexact,
|
||||
signedinteger,
|
||||
unsignedinteger,
|
||||
integer,
|
||||
number,
|
||||
generic};
|
||||
const auto types = {
|
||||
bool_,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
float16,
|
||||
float32,
|
||||
bfloat16,
|
||||
complex64};
|
||||
for (const auto& type : types) {
|
||||
CHECK(issubdtype(type, type));
|
||||
CHECK(issubdtype(type, generic));
|
||||
switch (kindof(type)) {
|
||||
case Dtype::Kind::b:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK_FALSE(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK_FALSE(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::u:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK_FALSE(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK(issubdtype(type, unsignedinteger));
|
||||
CHECK(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::i:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK_FALSE(issubdtype(type, inexact));
|
||||
CHECK(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::f:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK(issubdtype(type, floating));
|
||||
CHECK(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::c:
|
||||
CHECK(issubdtype(type, complexfloating));
|
||||
CHECK_FALSE(issubdtype(type, floating));
|
||||
CHECK(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
case Dtype::Kind::V:
|
||||
CHECK_FALSE(issubdtype(type, complexfloating));
|
||||
CHECK(issubdtype(type, floating));
|
||||
CHECK(issubdtype(type, inexact));
|
||||
CHECK_FALSE(issubdtype(type, signedinteger));
|
||||
CHECK_FALSE(issubdtype(type, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(type, integer));
|
||||
CHECK(issubdtype(type, number));
|
||||
CHECK(issubdtype(type, generic));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& type : types) {
|
||||
CHECK(issubdtype(type, type));
|
||||
CHECK(issubdtype(type, generic));
|
||||
for (auto type1 : types) {
|
||||
CHECK_EQ(issubdtype(type, type1), type == type1);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& cat : cats) {
|
||||
CHECK(issubdtype(cat, cat));
|
||||
switch (cat) {
|
||||
case Dtype::Category::complexfloating:
|
||||
CHECK(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::floating:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK(issubdtype(cat, floating));
|
||||
CHECK(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::inexact:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::signedinteger:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::unsignedinteger:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK(issubdtype(cat, unsignedinteger));
|
||||
CHECK(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::integer:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::number:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
case Dtype::Category::generic:
|
||||
CHECK_FALSE(issubdtype(cat, complexfloating));
|
||||
CHECK_FALSE(issubdtype(cat, floating));
|
||||
CHECK_FALSE(issubdtype(cat, inexact));
|
||||
CHECK_FALSE(issubdtype(cat, signedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, unsignedinteger));
|
||||
CHECK_FALSE(issubdtype(cat, integer));
|
||||
CHECK_FALSE(issubdtype(cat, number));
|
||||
CHECK(issubdtype(cat, generic));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_1d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_1d(x);
|
||||
|
Loading…
Reference in New Issue
Block a user