Move the matmul type check in the op (#384)

This commit is contained in:
Angelos Katharopoulos 2024-01-05 19:10:13 -08:00 committed by GitHub
parent 4c48f6460d
commit 608bd43604
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1948,6 +1948,13 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results"
<< " in " << out_type << ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
if (a.dtype() != out_type) {
a = astype(a, out_type, s);
}