diff --git a/docs/src/index.rst b/docs/src/index.rst index aec2ea0b8..a9ec3899f 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -58,6 +58,7 @@ are the CPU and GPU. :maxdepth: 1 python/array + python/data_types python/devices_and_streams python/ops python/random diff --git a/docs/src/python/array.rst b/docs/src/python/array.rst index 00f97c68f..9946f3529 100644 --- a/docs/src/python/array.rst +++ b/docs/src/python/array.rst @@ -19,7 +19,6 @@ Array array.ndim array.shape array.size - Dtype array.abs array.all array.any @@ -32,7 +31,6 @@ Array array.cumsum array.diag array.diagonal - array.dtype array.exp array.flatten array.log diff --git a/docs/src/python/data_types.rst b/docs/src/python/data_types.rst index 83991261e..549446447 100644 --- a/docs/src/python/data_types.rst +++ b/docs/src/python/data_types.rst @@ -1,7 +1,5 @@ .. _data_types: -:orphan: - Data Types ========== @@ -56,3 +54,15 @@ The default floating point type is ``float32`` and the default integer type is * - ``complex64`` - 8 - 64-bit complex float + + +Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object +documentation for more information. Use :func:`issubdtype` to determine if one +``dtype`` (or category) is a subtype of another category. + +.. autosummary:: + :toctree: _autosummary + + Dtype + DtypeCategory + issubdtype diff --git a/docs/src/python/nn/module.rst b/docs/src/python/nn/module.rst index c3a4dfa62..c17f63ece 100644 --- a/docs/src/python/nn/module.rst +++ b/docs/src/python/nn/module.rst @@ -30,6 +30,7 @@ Module Module.named_modules Module.parameters Module.save_weights + Module.set_dtype Module.train Module.trainable_parameters Module.unfreeze diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 462e92a59..a10b126af 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -62,10 +62,10 @@ Operations identity inner isclose - isnan - isposinf - isneginf isinf + isnan + isneginf + isposinf less less_equal linspace diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 5567c0785..4ca3da1b8 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector& inputs, array& out) { set_unary_output_data(in, out); auto size = in.data_size(); vvexpf(out.data(), in.data(), reinterpret_cast(&size)); - } else if (is_floating_point(out.dtype())) { + } else if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, [](auto x) { return std::exp(x); }); } else { throw std::invalid_argument( @@ -355,7 +355,7 @@ void Log1p::eval_cpu(const std::vector& inputs, array& out) { auto size = in.data_size(); vvlog1pf( out.data(), in.data(), reinterpret_cast(&size)); - } else if (is_floating_point(out.dtype())) { + } else if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, [](auto x) { return std::log1p(x); }); } else { throw std::invalid_argument( diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index ec7097797..810062dfd 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - if (is_floating_point(out.dtype())) { - if (out.dtype() == float32) { - binary_op(a, b, out, detail::LogAddExp()); - } else if (out.dtype() == float16) { - binary_op(a, b, out, detail::LogAddExp()); - } else if (out.dtype() == bfloat16) { - binary_op(a, b, out, detail::LogAddExp()); - } else { - std::ostringstream err; - err << "[logaddexp] Does not support " << out.dtype(); - throw std::invalid_argument(err.str()); - } + if (out.dtype() == float32) { + binary_op(a, b, out, detail::LogAddExp()); + } else if (out.dtype() == float16) { + binary_op(a, b, out, detail::LogAddExp()); + } else if (out.dtype() == bfloat16) { + binary_op(a, b, out, detail::LogAddExp()); + } else if (issubdtype(out.dtype(), inexact)) { + std::ostringstream err; + err << "[logaddexp] Does not support " << out.dtype(); + throw std::invalid_argument(err.str()); } else { throw std::invalid_argument( "[logaddexp] Cannot compute logaddexp for arrays with" diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index f8a5e7936..df14ca33b 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -22,7 +22,7 @@ namespace mlx::core { void Abs::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (is_unsigned(in.dtype())) { + if (issubdtype(in.dtype(), unsignedinteger)) { // No-op for unsigned types out.copy_shared_buffer(in); } else { @@ -37,7 +37,7 @@ void Arange::eval(const std::vector& inputs, array& out) { void ArcCos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::ArcCos()); } else { throw std::invalid_argument( @@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector& inputs, array& out) { void ArcCosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::ArcCosh()); } else { throw std::invalid_argument( @@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector& inputs, array& out) { void ArcSin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::ArcSin()); } else { throw std::invalid_argument( @@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector& inputs, array& out) { void ArcSinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::ArcSinh()); } else { throw std::invalid_argument( @@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector& inputs, array& out) { void ArcTan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::ArcTan()); } else { throw std::invalid_argument( @@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector& inputs, array& out) { void ArcTanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::ArcTanh()); } else { throw std::invalid_argument( @@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector& inputs, array& out) { void Ceil::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (not is_integral(in.dtype())) { + if (issubdtype(in.dtype(), inexact)) { unary_fp(in, out, detail::Ceil()); } else { // No-op integer types @@ -211,7 +211,7 @@ void Copy::eval(const std::vector& inputs, array& out) { void Cos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Cos()); } else { throw std::invalid_argument( @@ -223,7 +223,7 @@ void Cos::eval(const std::vector& inputs, array& out) { void Cosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Cosh()); } else { throw std::invalid_argument( @@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector& inputs, array& out) { void Exp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Exp()); } else { throw std::invalid_argument( @@ -362,7 +362,7 @@ void Exp::eval(const std::vector& inputs, array& out) { void Floor::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (not is_integral(in.dtype())) { + if (issubdtype(in.dtype(), inexact)) { unary_fp(in, out, detail::Floor()); } else { // No-op integer types @@ -388,7 +388,7 @@ void Full::eval(const std::vector& inputs, array& out) { void Log::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { switch (base_) { case Base::e: unary_fp(in, out, detail::Log()); @@ -410,7 +410,7 @@ void Log::eval(const std::vector& inputs, array& out) { void Log1p::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Log1p()); } else { throw std::invalid_argument( @@ -597,7 +597,7 @@ void Reshape::eval(const std::vector& inputs, array& out) { void Round::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (not is_integral(in.dtype())) { + if (issubdtype(in.dtype(), inexact)) { unary_fp(in, out, detail::Round()); } else { // No-op integer types @@ -608,7 +608,7 @@ void Round::eval(const std::vector& inputs, array& out) { void Sigmoid::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Sigmoid()); } else { throw std::invalid_argument( @@ -630,7 +630,7 @@ void Sign::eval(const std::vector& inputs, array& out) { void Sin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Sin()); } else { throw std::invalid_argument( @@ -642,7 +642,7 @@ void Sin::eval(const std::vector& inputs, array& out) { void Sinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Sinh()); } else { throw std::invalid_argument( @@ -850,7 +850,7 @@ void StopGradient::eval(const std::vector& inputs, array& out) { void Tan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Tan()); } else { throw std::invalid_argument( @@ -862,7 +862,7 @@ void Tan::eval(const std::vector& inputs, array& out) { void Tanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { + if (issubdtype(out.dtype(), inexact)) { unary_fp(in, out, detail::Tanh()); } else { throw std::invalid_argument( diff --git a/mlx/backend/common/scan.cpp b/mlx/backend/common/scan.cpp index 9f994b902..221475902 100644 --- a/mlx/backend/common/scan.cpp +++ b/mlx/backend/common/scan.cpp @@ -222,7 +222,7 @@ void scan_dispatch( } case Scan::Min: { auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; }; - auto init = (is_floating_point(input.dtype())) + auto init = (issubdtype(input.dtype(), floating)) ? static_cast(std::numeric_limits::infinity()) : std::numeric_limits::max(); auto opcs = DefaultContiguousScan(op, init); @@ -232,7 +232,7 @@ void scan_dispatch( } case Scan::Max: { auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; - auto init = (is_floating_point(input.dtype())) + auto init = (issubdtype(input.dtype(), floating)) ? static_cast(-std::numeric_limits::infinity()) : std::numeric_limits::max(); auto opcs = DefaultContiguousScan(op, init); diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 51033997f..d69ad5998 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -488,7 +488,7 @@ void steel_matmul( void Matmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - if (!is_floating_point(out.dtype())) { + if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } @@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { void AddMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3); - if (!is_floating_point(out.dtype())) { + if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index b7eeed0e1..13a991f43 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector& inputs, array& out) { void Round::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (not is_integral(in.dtype())) { + if (issubdtype(in.dtype(), inexact)) { unary_op(inputs, out, "round"); } else { // No-op integer types diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 28f2f162d..55292f092 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, array& out) { assert(inputs.size() >= 3); - if (!is_floating_point(out.dtype())) { + if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[ScaledDotProductAttention] Does not yet support non-floating point types."); } diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 3a1405e53..a934a9d2d 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -12,7 +12,7 @@ namespace mlx::core { void Softmax::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - if (!is_floating_point(out.dtype())) { + if (!issubdtype(out.dtype(), floating)) { throw std::runtime_error( "[softmax] Does not support non-floating point types."); } diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index b4ac9159d..10efb88f9 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -12,6 +12,7 @@ namespace mlx::core { namespace { constexpr int num_types = 13; +constexpr int num_cats = 8; constexpr Dtype::Kind type_kinds[num_types] = { Dtype::Kind::b, // bool_, @@ -49,6 +50,35 @@ constexpr Dtype type_rules[num_types][num_types] = { {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64 }; + +constexpr bool subcategory_to_category[num_cats][num_cats] = { +// complexfloating floating inexact signedinteger unsignedinteger integer number generic + {true, false, true, false, false, false, true, true}, // complexfloating + {false, true, true, false, false, false, true, true}, // floating + {false, false, true, false, false, false, true, true}, // inexact + {false, false, false, true, false, true, true, true}, // signedinteger + {false, false, false, false, true, true, true, true}, // unsignedinteger + {false, false, false, false, false, true, true, true}, // integer + {false, false, false, false, false, false, true, true}, // number + {false, false, false, false, false, false, false, true}, // generic +}; + +constexpr Dtype::Category type_to_category[num_types] = { + Dtype::Category::generic, // bool_, + Dtype::Category::unsignedinteger, // uint8, + Dtype::Category::unsignedinteger, // uint16, + Dtype::Category::unsignedinteger, // uint32, + Dtype::Category::unsignedinteger, // uint64, + Dtype::Category::signedinteger, // int8, + Dtype::Category::signedinteger, // int16, + Dtype::Category::signedinteger, // int32, + Dtype::Category::signedinteger, // int64, + Dtype::Category::floating, // float16, + Dtype::Category::floating, // float32, + Dtype::Category::floating, // bfloat16, + Dtype::Category::complexfloating, // complex64, +}; + // clang-format on inline bool is_big_endian() { @@ -141,6 +171,23 @@ TypeToDtype::operator Dtype() { return complex64; } +bool issubdtype(const Dtype& a, const Dtype& b) { + return a == b; +} + +bool issubdtype(const Dtype::Category& cat, const Dtype& type) { + return false; +} + +bool issubdtype(const Dtype& type, const Dtype::Category& cat) { + return issubdtype(type_to_category[static_cast(type.val)], cat); +} + +bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) { + return subcategory_to_category[static_cast(a)] + [static_cast(b)]; +} + // Array protocol typestring for Dtype std::string dtype_to_array_protocol(const Dtype& t) { std::ostringstream r; diff --git a/mlx/dtype.h b/mlx/dtype.h index cb5f79849..410b70fb1 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -38,6 +38,17 @@ struct Dtype { V, /* void - used for brain float */ }; + enum class Category { + complexfloating, + floating, + inexact, + signedinteger, + unsignedinteger, + integer, + number, + generic + }; + Val val; const uint8_t size; constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){}; @@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; +inline constexpr Dtype::Category complexfloating = + Dtype::Category::complexfloating; +inline constexpr Dtype::Category floating = Dtype::Category::floating; +inline constexpr Dtype::Category inexact = Dtype::Category::inexact; +inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger; +inline constexpr Dtype::Category unsignedinteger = + Dtype::Category::unsignedinteger; +inline constexpr Dtype::Category integer = Dtype::Category::integer; +inline constexpr Dtype::Category number = Dtype::Category::number; +inline constexpr Dtype::Category generic = Dtype::Category::generic; + +bool issubdtype(const Dtype& a, const Dtype& b); +bool issubdtype(const Dtype::Category& a, const Dtype& b); +bool issubdtype(const Dtype& a, const Dtype::Category& b); +bool issubdtype(const Dtype::Category& a, const Dtype::Category& b); + Dtype promote_types(const Dtype& t1, const Dtype& t2); inline uint8_t size_of(const Dtype& t) { @@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) { Dtype::Kind kindof(const Dtype& t); -inline bool is_unsigned(const Dtype& t) { - return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b; -} - -inline bool is_floating_point(const Dtype& t) { - return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V || - kindof(t) == Dtype::Kind::c; -} - -inline bool is_complex(const Dtype& t) { - return kindof(t) == Dtype::Kind::c; -} - -inline bool is_integral(const Dtype& t) { - return !(is_floating_point(t)); -} - template struct TypeToDtype { operator Dtype(); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 589dbe5aa..96e72b641 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -64,7 +64,7 @@ array rms_norm( throw std::invalid_argument(msg.str()); } auto out_type = result_type(x, weight); - if (!is_floating_point(out_type) || is_complex(out_type)) { + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[rms_norm] Received unsupported type " << out_type << "."; throw std::invalid_argument(msg.str()); @@ -128,7 +128,7 @@ array layer_norm( ? ((bias.has_value()) ? result_type(x, *weight, *bias) : result_type(x, *weight)) : x.dtype(); - if (!is_floating_point(out_type) || is_complex(out_type)) { + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[layer_norm] Received unsupported type " << out_type << "."; throw std::invalid_argument(msg.str()); @@ -319,7 +319,7 @@ array scaled_dot_product_attention( } auto final_type = result_type(queries, keys, values); - if (!is_floating_point(final_type) || is_complex(final_type)) { + if (!issubdtype(final_type, floating)) { std::ostringstream msg; msg << "[scaled_dot_product_attention] Received unsupported type " << final_type << "."; diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index c000d2591..d772c0e14 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -11,7 +11,7 @@ namespace mlx::core::linalg { Dtype at_least_float(const Dtype& d) { - return is_floating_point(d) ? d : promote_types(d, float32); + return issubdtype(d, inexact) ? d : promote_types(d, float32); } inline array l2_norm( @@ -19,7 +19,7 @@ inline array l2_norm( const std::vector& axis, bool keepdims, StreamOrDevice s) { - if (is_complex(a.dtype())) { + if (issubdtype(a.dtype(), complexfloating)) { return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s); } else { return sqrt(sum(square(a, s), axis, keepdims, s), s); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c7c858550..ca3d60dd7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -47,7 +47,7 @@ std::pair, std::vector> compute_reduce_shape( } Dtype at_least_float(const Dtype& d) { - return is_floating_point(d) ? d : promote_types(d, float32); + return issubdtype(d, inexact) ? d : promote_types(d, float32); } } // namespace @@ -1140,7 +1140,7 @@ array array_equal( return array(false); } else { auto dtype = promote_types(a.dtype(), b.dtype()); - equal_nan &= is_floating_point(dtype); + equal_nan &= issubdtype(dtype, inexact); return all( array( a.shape(), @@ -1153,7 +1153,7 @@ array array_equal( } array isnan(const array& a, StreamOrDevice s /* = {} */) { - if (is_integral(a.dtype())) { + if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); } return not_equal(a, a, s); @@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) { } array isposinf(const array& a, StreamOrDevice s /* = {} */) { - if (is_integral(a.dtype())) { + if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); } return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); } array isneginf(const array& a, StreamOrDevice s /* = {} */) { - if (is_integral(a.dtype())) { + if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); } return equal(a, array(-std::numeric_limits::infinity(), a.dtype()), s); @@ -1929,7 +1929,7 @@ array floor_divide( const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - if (is_floating_point(dtype)) { + if (issubdtype(dtype, inexact)) { return floor(divide(a, b, s), s); } @@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) { std::vector divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - if (is_complex(dtype)) { + if (issubdtype(dtype, complexfloating)) { throw std::invalid_argument("[divmod] Complex type not supported."); } auto inputs = @@ -2220,7 +2220,7 @@ array matmul( } // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); - if (!is_floating_point(out_type) || is_complex(out_type)) { + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[matmul] Only real floating point types are supported but " << a.dtype() << " and " << b.dtype() << " were provided which results" @@ -2330,7 +2330,7 @@ array gather( // Promote indices to the same type auto dtype = result_type(indices); - if (!is_integral(dtype)) { + if (issubdtype(dtype, inexact)) { throw std::invalid_argument( "[gather] Got indices with invalid dtype. Indices must be integral."); } @@ -2521,7 +2521,7 @@ array scatter( // Promote indices to the same type auto dtype = result_type(indices); - if (!is_integral(dtype)) { + if (issubdtype(dtype, inexact)) { throw std::invalid_argument( "[scatter] Got indices with invalid dtype. Indices must be integral."); } @@ -2834,7 +2834,7 @@ inline std::vector conv_out_shape( } inline void run_conv_checks(const array& in, const array& wt, int n_dim) { - if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) { + if (!issubdtype(in.dtype(), floating)) { std::ostringstream msg; msg << "[conv] Invalid input array with type " << in.dtype() << "." << " Convolution currently only supports floating point types"; @@ -3062,7 +3062,7 @@ array quantized_matmul( } auto dtype = result_type(x, scales, biases); - if (!is_floating_point(dtype) || is_complex(dtype)) { + if (!issubdtype(dtype, floating)) { std::ostringstream msg; msg << "[quantized_matmul] Only real floating types are supported but " << "the passed types where x.dtype() == " << x.dtype() @@ -3364,7 +3364,7 @@ array addmm( // Type promotion auto out_type = result_type(a, b, c); - if (!is_floating_point(out_type) || is_complex(out_type)) { + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " << c.dtype() << ", " << a.dtype() << " and " << b.dtype() diff --git a/mlx/random.cpp b/mlx/random.cpp index 6e823de39..1490be159 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -97,7 +97,7 @@ array uniform( Dtype dtype /* = float32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { - if (!is_floating_point(dtype) && !is_complex(dtype)) { + if (!issubdtype(dtype, floating)) { throw std::invalid_argument( "Can only generate uniform numbers with real floating point type."); } @@ -179,7 +179,7 @@ array randint( Dtype dtype /* = int32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { - if (!is_integral(dtype)) { + if (issubdtype(dtype, inexact)) { throw std::invalid_argument( "[randint] randint only accepts integer dtypes and bool."); } @@ -192,7 +192,7 @@ array bernoulli( const std::vector& shape, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { - if (!is_floating_point(p.dtype())) { + if (!issubdtype(p.dtype(), floating)) { throw std::invalid_argument( "[bernoulli] bernoulli probability `p` must be a float type."); } @@ -228,7 +228,7 @@ array truncated_normal( // Same as // https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal - if (!is_floating_point(dtype)) { + if (!issubdtype(dtype, floating)) { throw std::invalid_argument( "[trunc_normal] trunc_normal only accepts floating point dtypes."); } diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 2efdf5e33..51c1f7af9 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -578,3 +578,26 @@ class Module(dict): See :func:`train`. """ self.train(False) + + def set_dtype( + self, + dtype: mx.Dtype, + predicate: Optional[Callable[[mx.Dtype], bool]] = lambda x: mx.issubdtype( + x, mx.floating + ), + ): + """Set the dtype of the module's parameters. + + Args: + dtype (Dtype): The new dtype. + predicate (typing.Callable, optional): A predicate to select + parameters to cast. By default, only parameters of type + :attr:`floating` will be updated to avoid casting integer + parameters to the new dtype. + """ + if predicate is None: + + def predicate(_): + return True + + self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x) diff --git a/python/src/array.cpp b/python/src/array.cpp index e16db9568..d8a04d175 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -254,7 +254,7 @@ array array_from_list( std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); - } else if (is_floating_point(dtype)) { + } else if (issubdtype(dtype, inexact)) { std::vector vals; fill_vector(pl, vals); return array(vals.begin(), shape, dtype); @@ -439,6 +439,54 @@ void init_array(nb::module_& m) { m.attr("float32") = nb::cast(float32); m.attr("bfloat16") = nb::cast(bfloat16); m.attr("complex64") = nb::cast(complex64); + nb::class_( + m, + "DtypeCategory", + R"pbdoc( + Type to hold categories of :class:`dtypes `. + + * :attr:`~mlx.core.generic` + + * :ref:`bool_ ` + * :attr:`~mlx.core.number` + + * :attr:`~mlx.core.integer` + + * :attr:`~mlx.core.unsignedinteger` + + * :ref:`uint8 ` + * :ref:`uint16 ` + * :ref:`uint32 ` + * :ref:`uint64 ` + + * :attr:`~mlx.core.signedinteger` + + * :ref:`int8 ` + * :ref:`int32 ` + * :ref:`int64 ` + + * :attr:`~mlx.core.inexact` + + * :attr:`~mlx.core.floating` + + * :ref:`float16 ` + * :ref:`bfloat16 ` + * :ref:`float32 ` + + * :attr:`~mlx.core.complexfloating` + + * :ref:`complex128 ` + + See also :func:`~mlx.core.issubdtype`. + )pbdoc"); + m.attr("complexfloating") = nb::cast(complexfloating); + m.attr("floating") = nb::cast(floating); + m.attr("inexact") = nb::cast(inexact); + m.attr("signedinteger") = nb::cast(signedinteger); + m.attr("unsignedinteger") = nb::cast(unsignedinteger); + m.attr("integer") = nb::cast(integer); + m.attr("number") = nb::cast(number); + m.attr("generic") = nb::cast(generic); nb::class_( m, @@ -700,7 +748,7 @@ void init_array(nb::module_& m) { .def( "__itruediv__", [](array& a, const ScalarOrArray v) -> array& { - if (!is_floating_point(a.dtype())) { + if (!issubdtype(a.dtype(), inexact)) { throw std::invalid_argument( "In place division cannot cast to non-floating point type."); } @@ -852,7 +900,7 @@ void init_array(nb::module_& m) { .def( "__invert__", [](const array& a) { - if (is_floating_point(a.dtype())) { + if (issubdtype(a.dtype(), inexact)) { throw std::invalid_argument( "Floating point types not allowed with or bitwise inversion."); } @@ -866,7 +914,8 @@ void init_array(nb::module_& m) { "__and__", [](const array& a, const ScalarOrArray v) { auto b = to_array(v, a.dtype()); - if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + if (issubdtype(a.dtype(), inexact) || + issubdtype(b.dtype(), inexact)) { throw std::invalid_argument( "Floating point types not allowed with bitwise and."); } @@ -881,7 +930,8 @@ void init_array(nb::module_& m) { "__iand__", [](array& a, const ScalarOrArray v) -> array& { auto b = to_array(v, a.dtype()); - if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + if (issubdtype(a.dtype(), inexact) || + issubdtype(b.dtype(), inexact)) { throw std::invalid_argument( "Floating point types not allowed with bitwise and."); } @@ -898,7 +948,8 @@ void init_array(nb::module_& m) { "__or__", [](const array& a, const ScalarOrArray v) { auto b = to_array(v, a.dtype()); - if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + if (issubdtype(a.dtype(), inexact) || + issubdtype(b.dtype(), inexact)) { throw std::invalid_argument( "Floating point types not allowed with or bitwise or."); } @@ -913,7 +964,8 @@ void init_array(nb::module_& m) { "__ior__", [](array& a, const ScalarOrArray v) -> array& { auto b = to_array(v, a.dtype()); - if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) { + if (issubdtype(a.dtype(), inexact) || + issubdtype(b.dtype(), inexact)) { throw std::invalid_argument( "Floating point types not allowed with or bitwise or."); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a68eeb9a7..87d483c50 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3684,4 +3684,62 @@ void init_ops(nb::module_& m) { Returns: array or list(array): An array or list of arrays with at least three dimensions. )pbdoc"); + m.def( + "issubdtype", + nb::overload_cast(&issubdtype), + ""_a, + ""_a, + R"pbdoc( + Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype + of another. + + >>> ints = mx.array([1, 2, 3], dtype=mx.int32) + >>> mx.issubdtype(ints.dtype, mx.integer) + True + >>> mx.issubdtype(ints.dtype, mx.floating) + False + + >>> floats = mx.array([1, 2, 3], dtype=mx.float32) + >>> mx.issubdtype(floats.dtype, mx.integer) + False + >>> mx.issubdtype(floats.dtype, mx.floating) + True + + Similar types of different sizes are not subdtypes of each other: + + >>> mx.issubdtype(mx.float64, mx.float32) + False + >>> mx.issubdtype(mx.float32, mx.float64) + False + + but both are subtypes of `floating`: + + >>> mx.issubdtype(mx.float64, mx.floating) + True + >>> mx.issubdtype(mx.float32, mx.floating) + True + + For convenience, dtype-like objects are allowed too: + + >>> mx.issubdtype(mx.float32, mx.inexact) + True + >>> mx.issubdtype(mx.signedinteger, mx.floating) + False + )pbdoc"); + m.def( + "issubdtype", + nb::overload_cast(&issubdtype), + ""_a, + ""_a); + m.def( + "issubdtype", + nb::overload_cast(&issubdtype), + ""_a, + ""_a); + m.def( + "issubdtype", + nb::overload_cast( + &issubdtype), + ""_a, + ""_a); } diff --git a/python/src/utils.h b/python/src/utils.h index 8b52cba12..a320412bd 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -56,7 +56,7 @@ inline array to_array( } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(float32); return array( - nb::cast(*pv), is_floating_point(out_t) ? out_t : float32); + nb::cast(*pv), issubdtype(out_t, floating) ? out_t : float32); } else if (auto pv = std::get_if>(&v); pv) { return array(static_cast(*pv), complex64); } else { diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index e3e676c6e..e8abb2227 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1492,6 +1492,29 @@ class TestLayers(mlx_tests.MLXTestCase): "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", ) + def test_set_dtype(self): + def assert_dtype(layer, dtype): + for k, v in tree_flatten(layer.parameters()): + self.assertEqual(v.dtype, dtype, f"dtype mismatch for {k}") + + layer = nn.Linear(input_dims=4, output_dims=8, bias=True) + assert_dtype(layer, mx.float32) + + layer.set_dtype(mx.bfloat16) + assert_dtype(layer, mx.bfloat16) + + layer.set_dtype(mx.float32, lambda x: False) + assert_dtype(layer, mx.bfloat16) + + layer.set_dtype(mx.int32, lambda x: True) + assert_dtype(layer, mx.int32) + + layer.set_dtype(mx.int64, predicate=None) + assert_dtype(layer, mx.int64) + + layer.set_dtype(mx.int16, lambda x: mx.issubdtype(x, mx.integer)) + assert_dtype(layer, mx.int16) + def test_rnn(self): layer = nn.RNN(input_size=5, hidden_size=12, bias=True) inp = mx.random.normal((2, 25, 5)) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5e599c01d..5fef45c64 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2026,6 +2026,40 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx_res.ndim, np_res.ndim) self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) + def test_issubdtype(self): + self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact)) + + cats = [ + "complexfloating", + "floating", + "inexact", + "signedinteger", + "unsignedinteger", + "integer", + "number", + "generic", + "bool_", + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "complex64", + ] + + for a in cats: + for b in cats: + self.assertEqual( + mx.issubdtype(getattr(mx, a), getattr(mx, b)), + np.issubdtype(getattr(np, a), getattr(np, b)), + f"mx and np don't aggree on {a}, {b}", + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 204f1ffd3..fd00d339d 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2848,6 +2848,192 @@ TEST_CASE("test diag") { CHECK(array_equal(out, array({3, 7}, {2})).item()); } +TEST_CASE("test issubdtype") { + const auto cats = { + complexfloating, + floating, + inexact, + signedinteger, + unsignedinteger, + integer, + number, + generic}; + const auto types = { + bool_, + uint8, + uint16, + uint32, + uint64, + int8, + int16, + int32, + int64, + float16, + float32, + bfloat16, + complex64}; + for (const auto& type : types) { + CHECK(issubdtype(type, type)); + CHECK(issubdtype(type, generic)); + switch (kindof(type)) { + case Dtype::Kind::b: + CHECK_FALSE(issubdtype(type, complexfloating)); + CHECK_FALSE(issubdtype(type, floating)); + CHECK_FALSE(issubdtype(type, inexact)); + CHECK_FALSE(issubdtype(type, signedinteger)); + CHECK_FALSE(issubdtype(type, unsignedinteger)); + CHECK_FALSE(issubdtype(type, integer)); + CHECK_FALSE(issubdtype(type, number)); + CHECK(issubdtype(type, generic)); + break; + case Dtype::Kind::u: + CHECK_FALSE(issubdtype(type, complexfloating)); + CHECK_FALSE(issubdtype(type, floating)); + CHECK_FALSE(issubdtype(type, inexact)); + CHECK_FALSE(issubdtype(type, signedinteger)); + CHECK(issubdtype(type, unsignedinteger)); + CHECK(issubdtype(type, integer)); + CHECK(issubdtype(type, number)); + CHECK(issubdtype(type, generic)); + break; + case Dtype::Kind::i: + CHECK_FALSE(issubdtype(type, complexfloating)); + CHECK_FALSE(issubdtype(type, floating)); + CHECK_FALSE(issubdtype(type, inexact)); + CHECK(issubdtype(type, signedinteger)); + CHECK_FALSE(issubdtype(type, unsignedinteger)); + CHECK(issubdtype(type, integer)); + CHECK(issubdtype(type, number)); + CHECK(issubdtype(type, generic)); + break; + case Dtype::Kind::f: + CHECK_FALSE(issubdtype(type, complexfloating)); + CHECK(issubdtype(type, floating)); + CHECK(issubdtype(type, inexact)); + CHECK_FALSE(issubdtype(type, signedinteger)); + CHECK_FALSE(issubdtype(type, unsignedinteger)); + CHECK_FALSE(issubdtype(type, integer)); + CHECK(issubdtype(type, number)); + CHECK(issubdtype(type, generic)); + break; + case Dtype::Kind::c: + CHECK(issubdtype(type, complexfloating)); + CHECK_FALSE(issubdtype(type, floating)); + CHECK(issubdtype(type, inexact)); + CHECK_FALSE(issubdtype(type, signedinteger)); + CHECK_FALSE(issubdtype(type, unsignedinteger)); + CHECK_FALSE(issubdtype(type, integer)); + CHECK(issubdtype(type, number)); + CHECK(issubdtype(type, generic)); + break; + case Dtype::Kind::V: + CHECK_FALSE(issubdtype(type, complexfloating)); + CHECK(issubdtype(type, floating)); + CHECK(issubdtype(type, inexact)); + CHECK_FALSE(issubdtype(type, signedinteger)); + CHECK_FALSE(issubdtype(type, unsignedinteger)); + CHECK_FALSE(issubdtype(type, integer)); + CHECK(issubdtype(type, number)); + CHECK(issubdtype(type, generic)); + break; + } + } + + for (const auto& type : types) { + CHECK(issubdtype(type, type)); + CHECK(issubdtype(type, generic)); + for (auto type1 : types) { + CHECK_EQ(issubdtype(type, type1), type == type1); + } + } + + for (const auto& cat : cats) { + CHECK(issubdtype(cat, cat)); + switch (cat) { + case Dtype::Category::complexfloating: + CHECK(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK_FALSE(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::floating: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK(issubdtype(cat, floating)); + CHECK(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK_FALSE(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::inexact: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK_FALSE(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::signedinteger: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK_FALSE(issubdtype(cat, inexact)); + CHECK(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::unsignedinteger: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK_FALSE(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK(issubdtype(cat, unsignedinteger)); + CHECK(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::integer: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK_FALSE(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::number: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK_FALSE(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK_FALSE(issubdtype(cat, integer)); + CHECK(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + case Dtype::Category::generic: + CHECK_FALSE(issubdtype(cat, complexfloating)); + CHECK_FALSE(issubdtype(cat, floating)); + CHECK_FALSE(issubdtype(cat, inexact)); + CHECK_FALSE(issubdtype(cat, signedinteger)); + CHECK_FALSE(issubdtype(cat, unsignedinteger)); + CHECK_FALSE(issubdtype(cat, integer)); + CHECK_FALSE(issubdtype(cat, number)); + CHECK(issubdtype(cat, generic)); + break; + } + } +} + TEST_CASE("test atleast_1d") { auto x = array(1); auto out = atleast_1d(x);