add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)

* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Daniel Strobusch
2024-03-25 20:32:59 +01:00
committed by GitHub
parent bfb5bad4f0
commit 479051ce1c
26 changed files with 538 additions and 97 deletions

View File

@@ -301,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
@@ -355,7 +355,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(

View File

@@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"

View File

@@ -22,7 +22,7 @@ namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (is_unsigned(in.dtype())) {
if (issubdtype(in.dtype(), unsignedinteger)) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos());
} else {
throw std::invalid_argument(
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh());
} else {
throw std::invalid_argument(
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin());
} else {
throw std::invalid_argument(
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh());
} else {
throw std::invalid_argument(
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan());
} else {
throw std::invalid_argument(
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh());
} else {
throw std::invalid_argument(
@@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
@@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos());
} else {
throw std::invalid_argument(
@@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh());
} else {
throw std::invalid_argument(
@@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp());
} else {
throw std::invalid_argument(
@@ -362,7 +362,7 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
@@ -388,7 +388,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
@@ -410,7 +410,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p());
} else {
throw std::invalid_argument(
@@ -597,7 +597,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
@@ -608,7 +608,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid());
} else {
throw std::invalid_argument(
@@ -630,7 +630,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin());
} else {
throw std::invalid_argument(
@@ -642,7 +642,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh());
} else {
throw std::invalid_argument(
@@ -850,7 +850,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tan());
} else {
throw std::invalid_argument(
@@ -862,7 +862,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tanh());
} else {
throw std::invalid_argument(

View File

@@ -222,7 +222,7 @@ void scan_dispatch(
}
case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto init = (is_floating_point(input.dtype()))
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
@@ -232,7 +232,7 @@ void scan_dispatch(
}
case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (is_floating_point(input.dtype()))
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);

View File

@@ -488,7 +488,7 @@ void steel_matmul(
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
@@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}

View File

@@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (not is_integral(in.dtype())) {
if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round");
} else {
// No-op integer types

View File

@@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() >= 3);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
}

View File

@@ -12,7 +12,7 @@ namespace mlx::core {
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!is_floating_point(out.dtype())) {
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[softmax] Does not support non-floating point types.");
}

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <sstream>
@@ -12,6 +12,7 @@ namespace mlx::core {
namespace {
constexpr int num_types = 13;
constexpr int num_cats = 8;
constexpr Dtype::Kind type_kinds[num_types] = {
Dtype::Kind::b, // bool_,
@@ -49,6 +50,35 @@ constexpr Dtype type_rules[num_types][num_types] = {
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
};
constexpr bool subcategory_to_category[num_cats][num_cats] = {
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
{true, false, true, false, false, false, true, true}, // complexfloating
{false, true, true, false, false, false, true, true}, // floating
{false, false, true, false, false, false, true, true}, // inexact
{false, false, false, true, false, true, true, true}, // signedinteger
{false, false, false, false, true, true, true, true}, // unsignedinteger
{false, false, false, false, false, true, true, true}, // integer
{false, false, false, false, false, false, true, true}, // number
{false, false, false, false, false, false, false, true}, // generic
};
constexpr Dtype::Category type_to_category[num_types] = {
Dtype::Category::generic, // bool_,
Dtype::Category::unsignedinteger, // uint8,
Dtype::Category::unsignedinteger, // uint16,
Dtype::Category::unsignedinteger, // uint32,
Dtype::Category::unsignedinteger, // uint64,
Dtype::Category::signedinteger, // int8,
Dtype::Category::signedinteger, // int16,
Dtype::Category::signedinteger, // int32,
Dtype::Category::signedinteger, // int64,
Dtype::Category::floating, // float16,
Dtype::Category::floating, // float32,
Dtype::Category::floating, // bfloat16,
Dtype::Category::complexfloating, // complex64,
};
// clang-format on
inline bool is_big_endian() {
@@ -141,6 +171,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
return complex64;
}
bool issubdtype(const Dtype& a, const Dtype& b) {
return a == b;
}
bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
return false;
}
bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
}
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
return subcategory_to_category[static_cast<uint32_t>(a)]
[static_cast<uint32_t>(b)];
}
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
std::ostringstream r;

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -38,6 +38,17 @@ struct Dtype {
V, /* void - used for brain float */
};
enum class Category {
complexfloating,
floating,
inexact,
signedinteger,
unsignedinteger,
integer,
number,
generic
};
Val val;
const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
@@ -63,6 +74,22 @@ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
inline constexpr Dtype::Category complexfloating =
Dtype::Category::complexfloating;
inline constexpr Dtype::Category floating = Dtype::Category::floating;
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
inline constexpr Dtype::Category unsignedinteger =
Dtype::Category::unsignedinteger;
inline constexpr Dtype::Category integer = Dtype::Category::integer;
inline constexpr Dtype::Category number = Dtype::Category::number;
inline constexpr Dtype::Category generic = Dtype::Category::generic;
bool issubdtype(const Dtype& a, const Dtype& b);
bool issubdtype(const Dtype::Category& a, const Dtype& b);
bool issubdtype(const Dtype& a, const Dtype::Category& b);
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
Dtype promote_types(const Dtype& t1, const Dtype& t2);
inline uint8_t size_of(const Dtype& t) {
@@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {
Dtype::Kind kindof(const Dtype& t);
inline bool is_unsigned(const Dtype& t) {
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
}
inline bool is_floating_point(const Dtype& t) {
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
kindof(t) == Dtype::Kind::c;
}
inline bool is_complex(const Dtype& t) {
return kindof(t) == Dtype::Kind::c;
}
inline bool is_integral(const Dtype& t) {
return !(is_floating_point(t));
}
template <typename T>
struct TypeToDtype {
operator Dtype();

View File

@@ -64,7 +64,7 @@ array rms_norm(
throw std::invalid_argument(msg.str());
}
auto out_type = result_type(x, weight);
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[rms_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
@@ -128,7 +128,7 @@ array layer_norm(
? ((bias.has_value()) ? result_type(x, *weight, *bias)
: result_type(x, *weight))
: x.dtype();
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[layer_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
@@ -319,7 +319,7 @@ array scaled_dot_product_attention(
}
auto final_type = result_type(queries, keys, values);
if (!is_floating_point(final_type) || is_complex(final_type)) {
if (!issubdtype(final_type, floating)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received unsupported type "
<< final_type << ".";

View File

@@ -11,7 +11,7 @@
namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
inline array l2_norm(
@@ -19,7 +19,7 @@ inline array l2_norm(
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (is_complex(a.dtype())) {
if (issubdtype(a.dtype(), complexfloating)) {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
} else {
return sqrt(sum(square(a, s), axis, keepdims, s), s);

View File

@@ -47,7 +47,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
}
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
} // namespace
@@ -1140,7 +1140,7 @@ array array_equal(
return array(false);
} else {
auto dtype = promote_types(a.dtype(), b.dtype());
equal_nan &= is_floating_point(dtype);
equal_nan &= issubdtype(dtype, inexact);
return all(
array(
a.shape(),
@@ -1153,7 +1153,7 @@ array array_equal(
}
array isnan(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return not_equal(a, a, s);
@@ -1164,14 +1164,14 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
@@ -1929,7 +1929,7 @@ array floor_divide(
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_floating_point(dtype)) {
if (issubdtype(dtype, inexact)) {
return floor(divide(a, b, s), s);
}
@@ -1957,7 +1957,7 @@ array operator%(const array& a, const array& b) {
std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
if (is_complex(dtype)) {
if (issubdtype(dtype, complexfloating)) {
throw std::invalid_argument("[divmod] Complex type not supported.");
}
auto inputs =
@@ -2220,7 +2220,7 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results"
@@ -2330,7 +2330,7 @@ array gather(
// Promote indices to the same type
auto dtype = result_type(indices);
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[gather] Got indices with invalid dtype. Indices must be integral.");
}
@@ -2521,7 +2521,7 @@ array scatter(
// Promote indices to the same type
auto dtype = result_type(indices);
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[scatter] Got indices with invalid dtype. Indices must be integral.");
}
@@ -2834,7 +2834,7 @@ inline std::vector<int> conv_out_shape(
}
inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
if (!is_floating_point(in.dtype()) && kindof(in.dtype()) != Dtype::Kind::c) {
if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "."
<< " Convolution currently only supports floating point types";
@@ -3062,7 +3062,7 @@ array quantized_matmul(
}
auto dtype = result_type(x, scales, biases);
if (!is_floating_point(dtype) || is_complex(dtype)) {
if (!issubdtype(dtype, floating)) {
std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but "
<< "the passed types where x.dtype() == " << x.dtype()
@@ -3364,7 +3364,7 @@ array addmm(
// Type promotion
auto out_type = result_type(a, b, c);
if (!is_floating_point(out_type) || is_complex(out_type)) {
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()

View File

@@ -97,7 +97,7 @@ array uniform(
Dtype dtype /* = float32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (!is_floating_point(dtype) && !is_complex(dtype)) {
if (!issubdtype(dtype, floating)) {
throw std::invalid_argument(
"Can only generate uniform numbers with real floating point type.");
}
@@ -179,7 +179,7 @@ array randint(
Dtype dtype /* = int32 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (!is_integral(dtype)) {
if (issubdtype(dtype, inexact)) {
throw std::invalid_argument(
"[randint] randint only accepts integer dtypes and bool.");
}
@@ -192,7 +192,7 @@ array bernoulli(
const std::vector<int>& shape,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
if (!is_floating_point(p.dtype())) {
if (!issubdtype(p.dtype(), floating)) {
throw std::invalid_argument(
"[bernoulli] bernoulli probability `p` must be a float type.");
}
@@ -228,7 +228,7 @@ array truncated_normal(
// Same as
// https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
if (!is_floating_point(dtype)) {
if (!issubdtype(dtype, floating)) {
throw std::invalid_argument(
"[trunc_normal] trunc_normal only accepts floating point dtypes.");
}