add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

The numeric type hierarchy and issubdtype are compatible with the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch 2024-03-25 20:32:59 +01:00 committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 538 additions and 97 deletions

View File

@ -58,6 +58,7 @@ are the CPU and GPU.
:maxdepth: 1 :maxdepth: 1
python/array python/array
python/data_types
python/devices_and_streams python/devices_and_streams
python/ops python/ops
python/random python/random

View File

@ -19,7 +19,6 @@ Array
array.ndim array.ndim
array.shape array.shape
array.size array.size
Dtype
array.abs array.abs
array.all array.all
array.any array.any
@ -32,7 +31,6 @@ Array
array.cumsum array.cumsum
array.diag array.diag
array.diagonal array.diagonal
array.dtype
array.exp array.exp
array.flatten array.flatten
array.log array.log

View File

@ -1,7 +1,5 @@
.. _data_types: .. _data_types:
:orphan:
Data Types Data Types
========== ==========
@ -56,3 +54,15 @@ The default floating point type is ``float32`` and the default integer type is
* - ``complex64`` * - ``complex64``
- 8 - 8
- 64-bit complex float - 64-bit complex float
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype

View File

@ -30,6 +30,7 @@ Module
Module.named_modules Module.named_modules
Module.parameters Module.parameters
Module.save_weights Module.save_weights
Module.set_dtype
Module.train Module.train
Module.trainable_parameters Module.trainable_parameters
Module.unfreeze Module.unfreeze

View File

@ -62,10 +62,10 @@ Operations
identity identity
inner inner
isclose isclose
isnan
isposinf
isneginf
isinf isinf
isnan
isneginf
isposinf
less less
less_equal less_equal
linspace linspace

View File

@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out); set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); }); unary_fp(in, out, [](auto x) { return std::exp(x); });
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -355,7 +355,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size(); auto size = in.data_size();
vvlog1pf( vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); }); unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
if (is_floating_point(out.dtype())) { if (out.dtype() == float32) {
if (out.dtype() == float32) { binary_op<float>(a, b, out, detail::LogAddExp());
binary_op<float>(a, b, out, detail::LogAddExp()); } else if (out.dtype() == float16) {
} else if (out.dtype() == float16) { binary_op<float16_t>(a, b, out, detail::LogAddExp());
binary_op<float16_t>(a, b, out, detail::LogAddExp()); } else if (out.dtype() == bfloat16) {
} else if (out.dtype() == bfloat16) { binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp()); } else if (issubdtype(out.dtype(), inexact)) {
} else { std::ostringstream err;
std::ostringstream err; err << "[logaddexp] Does not support " << out.dtype();
err << "[logaddexp] Does not support " << out.dtype(); throw std::invalid_argument(err.str());
throw std::invalid_argument(err.str());
}
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with" "[logaddexp] Cannot compute logaddexp for arrays with"

View File

@ -22,7 +22,7 @@ namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) { void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (is_unsigned(in.dtype())) { if (issubdtype(in.dtype(), unsignedinteger)) {
// No-op for unsigned types // No-op for unsigned types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
void ArcCos::eval(const std::vector<array>& inputs, array& out) { void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos()); unary_fp(in, out, detail::ArcCos());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
void ArcCosh::eval(const std::vector<array>& inputs, array& out) { void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh()); unary_fp(in, out, detail::ArcCosh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
void ArcSin::eval(const std::vector<array>& inputs, array& out) { void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin()); unary_fp(in, out, detail::ArcSin());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
void ArcSinh::eval(const std::vector<array>& inputs, array& out) { void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh()); unary_fp(in, out, detail::ArcSinh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
void ArcTan::eval(const std::vector<array>& inputs, array& out) { void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan()); unary_fp(in, out, detail::ArcTan());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
void ArcTanh::eval(const std::vector<array>& inputs, array& out) { void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh()); unary_fp(in, out, detail::ArcTanh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
void Ceil::eval(const std::vector<array>& inputs, array& out) { void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil()); unary_fp(in, out, detail::Ceil());
} else { } else {
// No-op integer types // No-op integer types
@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
void Cos::eval(const std::vector<array>& inputs, array& out) { void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos()); unary_fp(in, out, detail::Cos());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
void Cosh::eval(const std::vector<array>& inputs, array& out) { void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh()); unary_fp(in, out, detail::Cosh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) { void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp()); unary_fp(in, out, detail::Exp());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -362,7 +362,7 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
void Floor::eval(const std::vector<array>& inputs, array& out) { void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor()); unary_fp(in, out, detail::Floor());
} else { } else {
// No-op integer types // No-op integer types
@ -388,7 +388,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
void Log::eval(const std::vector<array>& inputs, array& out) { void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
switch (base_) { switch (base_) {
case Base::e: case Base::e:
unary_fp(in, out, detail::Log()); unary_fp(in, out, detail::Log());
@ -410,7 +410,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
void Log1p::eval(const std::vector<array>& inputs, array& out) { void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p()); unary_fp(in, out, detail::Log1p());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -597,7 +597,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
void Round::eval(const std::vector<array>& inputs, array& out) { void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round()); unary_fp(in, out, detail::Round());
} else { } else {
// No-op integer types // No-op integer types
@ -608,7 +608,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval(const std::vector<array>& inputs, array& out) { void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid()); unary_fp(in, out, detail::Sigmoid());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -630,7 +630,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
void Sin::eval(const std::vector<array>& inputs, array& out) { void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin()); unary_fp(in, out, detail::Sin());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -642,7 +642,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
void Sinh::eval(const std::vector<array>& inputs, array& out) { void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh()); unary_fp(in, out, detail::Sinh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -850,7 +850,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Tan::eval(const std::vector<array>& inputs, array& out) { void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tan()); unary_fp(in, out, detail::Tan());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -862,7 +862,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
void Tanh::eval(const std::vector<array>& inputs, array& out) { void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (is_floating_point(out.dtype())) { if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tanh()); unary_fp(in, out, detail::Tanh());
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -222,7 +222,7 @@ void scan_dispatch(
} }
case Scan::Min: { case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; }; auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto init = (is_floating_point(input.dtype())) auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity()) ? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max(); : std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init); auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
@ -232,7 +232,7 @@ void scan_dispatch(
} }
case Scan::Max: { case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (is_floating_point(input.dtype())) auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity()) ? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max(); : std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init); auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);

View File

@ -488,7 +488,7 @@ void steel_matmul(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) { if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error( throw std::runtime_error(
"[matmul] Does not yet support non-floating point types."); "[matmul] Does not yet support non-floating point types.");
} }
@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3); assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) { if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error( throw std::runtime_error(
"[matmul] Does not yet support non-floating point types."); "[matmul] Does not yet support non-floating point types.");
} }

View File

@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
void Round::eval_gpu(const std::vector<array>& inputs, array& out) { void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (not is_integral(in.dtype())) { if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round"); unary_op(inputs, out, "round");
} else { } else {
// No-op integer types // No-op integer types

View File

@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out) { array& out) {
assert(inputs.size() >= 3); assert(inputs.size() >= 3);
if (!is_floating_point(out.dtype())) { if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error( throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types."); "[ScaledDotProductAttention] Does not yet support non-floating point types.");
} }

View File

@ -12,7 +12,7 @@ namespace mlx::core {
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) { void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
if (!is_floating_point(out.dtype())) { if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error( throw std::runtime_error(
"[softmax] Does not support non-floating point types."); "[softmax] Does not support non-floating point types.");
} }

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cstdint> #include <cstdint>
#include <sstream> #include <sstream>
@ -12,6 +12,7 @@ namespace mlx::core {
namespace { namespace {
constexpr int num_types = 13; constexpr int num_types = 13;
constexpr int num_cats = 8;
constexpr Dtype::Kind type_kinds[num_types] = { constexpr Dtype::Kind type_kinds[num_types] = {
Dtype::Kind::b, // bool_, Dtype::Kind::b, // bool_,
@ -49,6 +50,35 @@ constexpr Dtype type_rules[num_types][num_types] = {
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64 {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
}; };
// Category-to-category subtype table used by issubdtype.
// Entry [a][b] is true when category `a` (row) is the same as, or a
// subcategory of, category `b` (column). Rows and columns follow the
// declaration order of Dtype::Category, because lookups index this table with
// the integer value of the enumerators.
constexpr bool subcategory_to_category[num_cats][num_cats] = {
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
{true, false, true, false, false, false, true, true}, // complexfloating
{false, true, true, false, false, false, true, true}, // floating
{false, false, true, false, false, false, true, true}, // inexact
{false, false, false, true, false, true, true, true}, // signedinteger
{false, false, false, false, true, true, true, true}, // unsignedinteger
{false, false, false, false, false, true, true, true}, // integer
{false, false, false, false, false, false, true, true}, // number
{false, false, false, false, false, false, false, true}, // generic
};
// Most specific category of each concrete dtype, indexed by the integer value
// of Dtype::Val (one entry per dtype, in the order named in the trailing
// comments). Note that bool_ maps directly to `generic`, so it is not a
// subtype of `number`, `integer` or `unsignedinteger`.
constexpr Dtype::Category type_to_category[num_types] = {
Dtype::Category::generic, // bool_,
Dtype::Category::unsignedinteger, // uint8,
Dtype::Category::unsignedinteger, // uint16,
Dtype::Category::unsignedinteger, // uint32,
Dtype::Category::unsignedinteger, // uint64,
Dtype::Category::signedinteger, // int8,
Dtype::Category::signedinteger, // int16,
Dtype::Category::signedinteger, // int32,
Dtype::Category::signedinteger, // int64,
Dtype::Category::floating, // float16,
Dtype::Category::floating, // float32,
Dtype::Category::floating, // bfloat16,
Dtype::Category::complexfloating, // complex64,
};
// clang-format on // clang-format on
inline bool is_big_endian() { inline bool is_big_endian() {
@ -141,6 +171,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
return complex64; return complex64;
} }
// Dtype vs. Dtype: concrete dtypes have no subtypes of their own, so the
// result is simply whether `a` and `b` compare equal.
bool issubdtype(const Dtype& a, const Dtype& b) {
return a == b;
}
// Category vs. Dtype: always false; both arguments are intentionally unused.
// A category is never reported as a subtype of a concrete dtype.
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
return false;
}
// Dtype vs. Category: look up the category assigned to `type` in
// type_to_category (indexed by the integer value of type.val) and defer to
// the category-vs-category overload.
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
  const auto type_index = static_cast<uint32_t>(type.val);
  const Dtype::Category own_category = type_to_category[type_index];
  return issubdtype(own_category, cat);
}
// Category vs. Category: a single lookup in subcategory_to_category, where
// the row is the candidate subcategory `a` and the column is the candidate
// parent category `b`.
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
  const auto row = static_cast<uint32_t>(a);
  const auto col = static_cast<uint32_t>(b);
  return subcategory_to_category[row][col];
}
// Array protocol typestring for Dtype // Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) { std::string dtype_to_array_protocol(const Dtype& t) {
std::ostringstream r; std::ostringstream r;

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
@ -38,6 +38,17 @@ struct Dtype {
V, /* void - used for brain float */ V, /* void - used for brain float */
}; };
// Categories of the dtype hierarchy queried through issubdtype.
// NOTE: the enumerator order matters. The implementation casts these values
// to integers and uses them as row/column indices into its category table, so
// reordering or inserting entries requires updating that table as well.
enum class Category {
// Complex floating point types; a subcategory of inexact.
complexfloating,
// Real floating point types; a subcategory of inexact.
floating,
// Floating and complex floating types; a subcategory of number.
inexact,
// Signed integer types; a subcategory of integer.
signedinteger,
// Unsigned integer types; a subcategory of integer.
unsignedinteger,
// Signed and unsigned integers; a subcategory of number.
integer,
// All numeric categories; a subcategory of generic.
number,
// Root of the hierarchy; every category is a subcategory of generic.
generic
};
Val val; Val val;
const uint8_t size; const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){}; constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
// Namespace-level shorthands for the Dtype::Category enumerators, so callers
// can write e.g. issubdtype(t, floating) without qualifying the enum.
inline constexpr Dtype::Category complexfloating =
Dtype::Category::complexfloating;
inline constexpr Dtype::Category floating = Dtype::Category::floating;
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
inline constexpr Dtype::Category unsignedinteger =
Dtype::Category::unsignedinteger;
inline constexpr Dtype::Category integer = Dtype::Category::integer;
inline constexpr Dtype::Category number = Dtype::Category::number;
inline constexpr Dtype::Category generic = Dtype::Category::generic;
// Subtype queries over dtypes and categories (`a` is the candidate subtype,
// `b` the candidate parent):
// - Dtype, Dtype: true when the two dtypes compare equal.
bool issubdtype(const Dtype& a, const Dtype& b);
// - Category, Dtype: always false.
bool issubdtype(const Dtype::Category& a, const Dtype& b);
// - Dtype, Category: true when the category of dtype `a` is `b` or a
//   subcategory of `b`.
bool issubdtype(const Dtype& a, const Dtype::Category& b);
// - Category, Category: true when `a` is `b` or a subcategory of `b`.
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
Dtype promote_types(const Dtype& t1, const Dtype& t2); Dtype promote_types(const Dtype& t1, const Dtype& t2);
inline uint8_t size_of(const Dtype& t) { inline uint8_t size_of(const Dtype& t) {
@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {
Dtype::Kind kindof(const Dtype& t); Dtype::Kind kindof(const Dtype& t);
// True when the kind of `t` is unsigned integer (Kind::u) or boolean
// (Kind::b); every other kind yields false.
inline bool is_unsigned(const Dtype& t) {
  switch (kindof(t)) {
    case Dtype::Kind::u:
    case Dtype::Kind::b:
      return true;
    default:
      return false;
  }
}
// True when the kind of `t` is float (Kind::f), the void kind used for brain
// float (Kind::V), or complex (Kind::c). Note that complex counts as
// "floating point" for this predicate.
inline bool is_floating_point(const Dtype& t) {
  switch (kindof(t)) {
    case Dtype::Kind::f:
    case Dtype::Kind::V:
    case Dtype::Kind::c:
      return true;
    default:
      return false;
  }
}
inline bool is_complex(const Dtype& t) {
return kindof(t) == Dtype::Kind::c;
}
inline bool is_integral(const Dtype& t) {
return !(is_floating_point(t));
}
template <typename T> template <typename T>
struct TypeToDtype { struct TypeToDtype {
operator Dtype(); operator Dtype();

View File

@ -64,7 +64,7 @@ array rms_norm(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
auto out_type = result_type(x, weight); auto out_type = result_type(x, weight);
if (!is_floating_point(out_type) || is_complex(out_type)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[rms_norm] Received unsupported type " << out_type << "."; msg << "[rms_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
@ -128,7 +128,7 @@ array layer_norm(
? ((bias.has_value()) ? result_type(x, *weight, *bias) ? ((bias.has_value()) ? result_type(x, *weight, *bias)
: result_type(x, *weight)) : result_type(x, *weight))
: x.dtype(); : x.dtype();
if (!is_floating_point(out_type) || is_complex(out_type)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[layer_norm] Received unsupported type " << out_type << "."; msg << "[layer_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
@ -319,7 +319,7 @@ array scaled_dot_product_attention(
} }
auto final_type = result_type(queries, keys, values); auto final_type = result_type(queries, keys, values);
if (!is_floating_point(final_type) || is_complex(final_type)) { if (!issubdtype(final_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received unsupported type " msg << "[scaled_dot_product_attention] Received unsupported type "
<< final_type << "."; << final_type << ".";

View File

@ -11,7 +11,7 @@
namespace mlx::core::linalg { namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) { Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32); return issubdtype(d, inexact) ? d : promote_types(d, float32);
} }
inline array l2_norm( inline array l2_norm(
@ -19,7 +19,7 @@ inline array l2_norm(
const std::vector<int>& axis, const std::vector<int>& axis,
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
if (is_complex(a.dtype())) { if (issubdtype(a.dtype(), complexfloating)) {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s); return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
} else { } else {
return sqrt(sum(square(a, s), axis, keepdims, s), s); return sqrt(sum(square(a, s), axis, keepdims, s), s);

View File

@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
} }
Dtype at_least_float(const Dtype& d) { Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32); return issubdtype(d, inexact) ? d : promote_types(d, float32);
} }
} // namespace } // namespace
@ -1140,7 +1140,7 @@ array array_equal(
return array(false); return array(false);
} else { } else {
auto dtype = promote_types(a.dtype(), b.dtype()); auto dtype = promote_types(a.dtype(), b.dtype());
equal_nan &= is_floating_point(dtype); equal_nan &= issubdtype(dtype, inexact);
return all( return all(
array( array(
a.shape(), a.shape(),
@ -1153,7 +1153,7 @@ array array_equal(
} }
array isnan(const array& a, StreamOrDevice s /* = {} */) { array isnan(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s); return full(a.shape(), false, bool_, s);
} }
return not_equal(a, a, s); return not_equal(a, a, s);
@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
} }
array isposinf(const array& a, StreamOrDevice s /* = {} */) { array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s); return full(a.shape(), false, bool_, s);
} }
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s); return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
} }
array isneginf(const array& a, StreamOrDevice s /* = {} */) { array isneginf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s); return full(a.shape(), false, bool_, s);
} }
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s); return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
@ -1929,7 +1929,7 @@ array floor_divide(
const array& b, const array& b,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype()); auto dtype = promote_types(a.dtype(), b.dtype());
if (is_floating_point(dtype)) { if (issubdtype(dtype, inexact)) {
return floor(divide(a, b, s), s); return floor(divide(a, b, s), s);
} }
@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) {
std::vector<array> std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) { divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype()); auto dtype = promote_types(a.dtype(), b.dtype());
if (is_complex(dtype)) { if (issubdtype(dtype, complexfloating)) {
throw std::invalid_argument("[divmod] Complex type not supported."); throw std::invalid_argument("[divmod] Complex type not supported.");
} }
auto inputs = auto inputs =
@ -2220,7 +2220,7 @@ array matmul(
} }
// Type promotion // Type promotion
auto out_type = promote_types(a.dtype(), b.dtype()); auto out_type = promote_types(a.dtype(), b.dtype());
if (!is_floating_point(out_type) || is_complex(out_type)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but " msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results" << a.dtype() << " and " << b.dtype() << " were provided which results"
@ -2330,7 +2330,7 @@ array gather(
// Promote indices to the same type // Promote indices to the same type
auto dtype = result_type(indices); auto dtype = result_type(indices);
if (!is_integral(dtype)) { if (issubdtype(dtype, inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"[gather] Got indices with invalid dtype. Indices must be integral."); "[gather] Got indices with invalid dtype. Indices must be integral.");
} }
@ -2521,7 +2521,7 @@ array scatter(
// Promote indices to the same type // Promote indices to the same type
auto dtype = result_type(indices); auto dtype = result_type(indices);
if (!is_integral(dtype)) { if (issubdtype(dtype, inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"[scatter] Got indices with invalid dtype. Indices must be integral."); "[scatter] Got indices with invalid dtype. Indices must be integral.");
} }
@ -2834,7 +2834,7 @@ inline std::vector<int> conv_out_shape(
} }
inline void run_conv_checks(const array& in, const array& wt, int n_dim) { inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) { if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "." msg << "[conv] Invalid input array with type " << in.dtype() << "."
<< " Convolution currently only supports floating point types"; << " Convolution currently only supports floating point types";
@ -3062,7 +3062,7 @@ array quantized_matmul(
} }
auto dtype = result_type(x, scales, biases); auto dtype = result_type(x, scales, biases);
if (!is_floating_point(dtype) || is_complex(dtype)) { if (!issubdtype(dtype, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but " msg << "[quantized_matmul] Only real floating types are supported but "
<< "the passed types where x.dtype() == " << x.dtype() << "the passed types where x.dtype() == " << x.dtype()
@ -3364,7 +3364,7 @@ array addmm(
// Type promotion // Type promotion
auto out_type = result_type(a, b, c); auto out_type = result_type(a, b, c);
if (!is_floating_point(out_type) || is_complex(out_type)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but " msg << "[addmm] Only real floating point types are supported but "
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype() << c.dtype() << ", " << a.dtype() << " and " << b.dtype()

View File

@ -97,7 +97,7 @@ array uniform(
Dtype dtype /* = float32 */, Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */, const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!is_floating_point(dtype) && !is_complex(dtype)) { if (!issubdtype(dtype, floating)) {
throw std::invalid_argument( throw std::invalid_argument(
"Can only generate uniform numbers with real floating point type."); "Can only generate uniform numbers with real floating point type.");
} }
@ -179,7 +179,7 @@ array randint(
Dtype dtype /* = int32 */, Dtype dtype /* = int32 */,
const std::optional<array>& key /*= nullopt */, const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!is_integral(dtype)) { if (issubdtype(dtype, inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"[randint] randint only accepts integer dtypes and bool."); "[randint] randint only accepts integer dtypes and bool.");
} }
@ -192,7 +192,7 @@ array bernoulli(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::optional<array>& key /*= nullopt */, const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!is_floating_point(p.dtype())) { if (!issubdtype(p.dtype(), floating)) {
throw std::invalid_argument( throw std::invalid_argument(
"[bernoulli] bernoulli probability `p` must be a float type."); "[bernoulli] bernoulli probability `p` must be a float type.");
} }
@ -228,7 +228,7 @@ array truncated_normal(
// Same as // Same as
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal // https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
if (!is_floating_point(dtype)) { if (!issubdtype(dtype, floating)) {
throw std::invalid_argument( throw std::invalid_argument(
"[trunc_normal] trunc_normal only accepts floating point dtypes."); "[trunc_normal] trunc_normal only accepts floating point dtypes.");
} }

View File

@ -578,3 +578,26 @@ class Module(dict):
See :func:`train`. See :func:`train`.
""" """
self.train(False) self.train(False)
def set_dtype(
    self,
    dtype: mx.Dtype,
    predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype(
        x, mx.floating
    ),
):
    """Cast the module's parameters to a new dtype.

    Args:
        dtype (Dtype): The dtype that selected parameters are cast to.
        predicate (typing.Callable, optional): Called with the current
          dtype of each parameter; the parameter is cast only when the
          call returns a true value. The default selects dtypes that are
          a subtype of :attr:`floating`, so integer parameters keep their
          dtype. Pass ``None`` to cast every parameter.
    """
    # ``None`` means "no filtering": fall back to an accept-all selector.
    selector = (lambda _: True) if predicate is None else predicate

    def cast_if_selected(param):
        if selector(param.dtype):
            return param.astype(dtype)
        return param

    self.apply(cast_if_selected)

View File

@ -254,7 +254,7 @@ array array_from_list(
std::vector<uint32_t> vals; std::vector<uint32_t> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return array(vals.begin(), shape, dtype);
} else if (is_floating_point(dtype)) { } else if (issubdtype(dtype, inexact)) {
std::vector<float> vals; std::vector<float> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return array(vals.begin(), shape, dtype);
@ -439,6 +439,54 @@ void init_array(nb::module_& m) {
m.attr("float32") = nb::cast(float32); m.attr("float32") = nb::cast(float32);
m.attr("bfloat16") = nb::cast(bfloat16); m.attr("bfloat16") = nb::cast(bfloat16);
m.attr("complex64") = nb::cast(complex64); m.attr("complex64") = nb::cast(complex64);
// Expose Dtype::Category to Python as ``DtypeCategory``. The docstring lays
// out the category hierarchy as a nested reStructuredText list; nested levels
// are separated by blank lines so that the nesting is preserved when rendered.
nb::class_<Dtype::Category>(
    m,
    "DtypeCategory",
    R"pbdoc(
    Type to hold categories of :class:`dtypes <Dtype>`.

    * :attr:`~mlx.core.generic`

      * :ref:`bool_ <data_types>`
      * :attr:`~mlx.core.number`

        * :attr:`~mlx.core.integer`

          * :attr:`~mlx.core.unsignedinteger`

            * :ref:`uint8 <data_types>`
            * :ref:`uint16 <data_types>`
            * :ref:`uint32 <data_types>`
            * :ref:`uint64 <data_types>`

          * :attr:`~mlx.core.signedinteger`

            * :ref:`int8 <data_types>`
            * :ref:`int16 <data_types>`
            * :ref:`int32 <data_types>`
            * :ref:`int64 <data_types>`

        * :attr:`~mlx.core.inexact`

          * :attr:`~mlx.core.floating`

            * :ref:`float16 <data_types>`
            * :ref:`bfloat16 <data_types>`
            * :ref:`float32 <data_types>`

          * :attr:`~mlx.core.complexfloating`

            * :ref:`complex64 <data_types>`

    See also :func:`~mlx.core.issubdtype`.
    )pbdoc");

// Publish each category as a module-level attribute (e.g. ``mx.floating``).
m.attr("complexfloating") = nb::cast(complexfloating);
m.attr("floating") = nb::cast(floating);
m.attr("inexact") = nb::cast(inexact);
m.attr("signedinteger") = nb::cast(signedinteger);
m.attr("unsignedinteger") = nb::cast(unsignedinteger);
m.attr("integer") = nb::cast(integer);
m.attr("number") = nb::cast(number);
m.attr("generic") = nb::cast(generic);
nb::class_<ArrayAt>( nb::class_<ArrayAt>(
m, m,
@ -700,7 +748,7 @@ void init_array(nb::module_& m) {
.def( .def(
"__itruediv__", "__itruediv__",
[](array& a, const ScalarOrArray v) -> array& { [](array& a, const ScalarOrArray v) -> array& {
if (!is_floating_point(a.dtype())) { if (!issubdtype(a.dtype(), inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"In place division cannot cast to non-floating point type."); "In place division cannot cast to non-floating point type.");
} }
@ -852,7 +900,7 @@ void init_array(nb::module_& m) {
.def( .def(
"__invert__", "__invert__",
[](const array& a) { [](const array& a) {
if (is_floating_point(a.dtype())) { if (issubdtype(a.dtype(), inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or bitwise inversion."); "Floating point types not allowed with or bitwise inversion.");
} }
@ -866,7 +914,8 @@ void init_array(nb::module_& m) {
"__and__", "__and__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with bitwise and."); "Floating point types not allowed with bitwise and.");
} }
@ -881,7 +930,8 @@ void init_array(nb::module_& m) {
"__iand__", "__iand__",
[](array& a, const ScalarOrArray v) -> array& { [](array& a, const ScalarOrArray v) -> array& {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with bitwise and."); "Floating point types not allowed with bitwise and.");
} }
@ -898,7 +948,8 @@ void init_array(nb::module_& m) {
"__or__", "__or__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or bitwise or."); "Floating point types not allowed with or bitwise or.");
} }
@ -913,7 +964,8 @@ void init_array(nb::module_& m) {
"__ior__", "__ior__",
[](array& a, const ScalarOrArray v) -> array& { [](array& a, const ScalarOrArray v) -> array& {
auto b = to_array(v, a.dtype()); auto b = to_array(v, a.dtype());
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument( throw std::invalid_argument(
"Floating point types not allowed with or bitwise or."); "Floating point types not allowed with or bitwise or.");
} }

View File

@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) {
Returns: Returns:
array or list(array): An array or list of arrays with at least three dimensions. array or list(array): An array or list of arrays with at least three dimensions.
)pbdoc"); )pbdoc");
// Python bindings for ``issubdtype``. One overload is registered for every
// combination of Dtype / Dtype::Category arguments; only the first overload
// carries the user-facing documentation.
m.def(
    "issubdtype",
    nb::overload_cast<const Dtype&, const Dtype&>(&issubdtype),
    ""_a,
    ""_a,
    R"pbdoc(
    Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype
    of another.

    >>> ints = mx.array([1, 2, 3], dtype=mx.int32)
    >>> mx.issubdtype(ints.dtype, mx.integer)
    True
    >>> mx.issubdtype(ints.dtype, mx.floating)
    False

    >>> floats = mx.array([1, 2, 3], dtype=mx.float32)
    >>> mx.issubdtype(floats.dtype, mx.integer)
    False
    >>> mx.issubdtype(floats.dtype, mx.floating)
    True

    Similar types of different sizes are not subdtypes of each other:

    >>> mx.issubdtype(mx.float16, mx.float32)
    False
    >>> mx.issubdtype(mx.float32, mx.float16)
    False

    but both are subtypes of `floating`:

    >>> mx.issubdtype(mx.float16, mx.floating)
    True
    >>> mx.issubdtype(mx.float32, mx.floating)
    True

    Either argument may be a :obj:`Dtype` or a :obj:`DtypeCategory`:

    >>> mx.issubdtype(mx.float32, mx.inexact)
    True
    >>> mx.issubdtype(mx.signedinteger, mx.floating)
    False
    )pbdoc");
m.def(
    "issubdtype",
    nb::overload_cast<const Dtype&, const Dtype::Category&>(&issubdtype),
    ""_a,
    ""_a);
m.def(
    "issubdtype",
    nb::overload_cast<const Dtype::Category&, const Dtype&>(&issubdtype),
    ""_a,
    ""_a);
m.def(
    "issubdtype",
    nb::overload_cast<const Dtype::Category&, const Dtype::Category&>(
        &issubdtype),
    ""_a,
    ""_a);
} }

View File

@ -56,7 +56,7 @@ inline array to_array(
} else if (auto pv = std::get_if<nb::float_>(&v); pv) { } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32); auto out_t = dtype.value_or(float32);
return array( return array(
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32); nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64); return array(static_cast<complex64_t>(*pv), complex64);
} else { } else {

View File

@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase):
"AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
) )
def test_set_dtype(self):
    """Check which parameters ``set_dtype`` casts for several predicates."""

    def check_all_params(module, expected):
        for name, param in tree_flatten(module.parameters()):
            self.assertEqual(param.dtype, expected, f"dtype mismatch for {name}")

    linear = nn.Linear(input_dims=4, output_dims=8, bias=True)
    check_all_params(linear, mx.float32)

    # Each step: (positional args, keyword args, dtype expected afterwards).
    # Steps run in order, so each one starts from the previous step's dtype.
    steps = [
        ((mx.bfloat16,), {}, mx.bfloat16),
        ((mx.float32, lambda x: False), {}, mx.bfloat16),
        ((mx.int32, lambda x: True), {}, mx.int32),
        ((mx.int64,), {"predicate": None}, mx.int64),
        ((mx.int16, lambda x: mx.issubdtype(x, mx.integer)), {}, mx.int16),
    ]
    for args, kwargs, expected in steps:
        linear.set_dtype(*args, **kwargs)
        check_all_params(linear, expected)
def test_rnn(self): def test_rnn(self):
layer = nn.RNN(input_size=5, hidden_size=12, bias=True) layer = nn.RNN(input_size=5, hidden_size=12, bias=True)
inp = mx.random.normal((2, 25, 5)) inp = mx.random.normal((2, 25, 5))

View File

@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx_res.ndim, np_res.ndim) self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
def test_issubdtype(self):
    """Compare ``mx.issubdtype`` with ``np.issubdtype`` for all name pairs."""
    # bfloat16 is checked on its own: it is not part of the list of names
    # that is looked up on both ``mx`` and ``np`` below.
    self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))

    # Names that are resolved with getattr on both modules.
    cats = [
        "complexfloating",
        "floating",
        "inexact",
        "signedinteger",
        "unsignedinteger",
        "integer",
        "number",
        "generic",
        "bool_",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "complex64",
    ]

    for a in cats:
        for b in cats:
            self.assertEqual(
                mx.issubdtype(getattr(mx, a), getattr(mx, b)),
                np.issubdtype(getattr(np, a), getattr(np, b)),
                f"mx and np don't agree on {a}, {b}",
            )
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -2848,6 +2848,192 @@ TEST_CASE("test diag") {
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>()); CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
} }
TEST_CASE("test issubdtype") {
  // All categories. The helper below walks this list in order, so the order
  // here is also the order in which memberships are checked.
  const auto cats = {
      complexfloating,
      floating,
      inexact,
      signedinteger,
      unsignedinteger,
      integer,
      number,
      generic};

  const auto types = {
      bool_,
      uint8,
      uint16,
      uint32,
      uint64,
      int8,
      int16,
      int32,
      int64,
      float16,
      float32,
      bfloat16,
      complex64};

  // Checks `sub` against every category: issubdtype must be true exactly for
  // the categories listed in `supers` and false for all the others.
  auto expect_supers = [&cats](
                           const auto& sub,
                           std::initializer_list<Dtype::Category> supers) {
    for (const auto& candidate : cats) {
      bool listed = false;
      for (const auto& s : supers) {
        listed = listed || (s == candidate);
      }
      CHECK_EQ(issubdtype(sub, candidate), listed);
    }
  };

  // Concrete dtype vs. category, keyed by the dtype's kind.
  for (const auto& type : types) {
    CHECK(issubdtype(type, type));
    CHECK(issubdtype(type, generic));
    switch (kindof(type)) {
      case Dtype::Kind::b:
        expect_supers(type, {generic});
        break;
      case Dtype::Kind::u:
        expect_supers(type, {unsignedinteger, integer, number, generic});
        break;
      case Dtype::Kind::i:
        expect_supers(type, {signedinteger, integer, number, generic});
        break;
      case Dtype::Kind::f:
      case Dtype::Kind::V:
        // Kinds f and V carry the same expectations.
        expect_supers(type, {floating, inexact, number, generic});
        break;
      case Dtype::Kind::c:
        expect_supers(type, {complexfloating, inexact, number, generic});
        break;
    }
  }

  // Concrete dtype vs. concrete dtype: only identical dtypes are subtypes.
  for (const auto& lhs : types) {
    CHECK(issubdtype(lhs, lhs));
    CHECK(issubdtype(lhs, generic));
    for (auto rhs : types) {
      CHECK_EQ(issubdtype(lhs, rhs), lhs == rhs);
    }
  }

  // Category vs. category.
  for (const auto& cat : cats) {
    CHECK(issubdtype(cat, cat));
    switch (cat) {
      case Dtype::Category::complexfloating:
        expect_supers(cat, {complexfloating, inexact, number, generic});
        break;
      case Dtype::Category::floating:
        expect_supers(cat, {floating, inexact, number, generic});
        break;
      case Dtype::Category::inexact:
        expect_supers(cat, {inexact, number, generic});
        break;
      case Dtype::Category::signedinteger:
        expect_supers(cat, {signedinteger, integer, number, generic});
        break;
      case Dtype::Category::unsignedinteger:
        expect_supers(cat, {unsignedinteger, integer, number, generic});
        break;
      case Dtype::Category::integer:
        expect_supers(cat, {integer, number, generic});
        break;
      case Dtype::Category::number:
        expect_supers(cat, {number, generic});
        break;
      case Dtype::Category::generic:
        expect_supers(cat, {generic});
        break;
    }
  }
}
TEST_CASE("test atleast_1d") { TEST_CASE("test atleast_1d") {
auto x = array(1); auto x = array(1);
auto out = atleast_1d(x); auto out = atleast_1d(x);