Move the matmul type check in the op (#384)

This commit is contained in:
Angelos Katharopoulos
2024-01-05 19:10:13 -08:00
committed by GitHub
parent 4c48f6460d
commit 608bd43604

View File

@@ -1948,6 +1948,13 @@ array matmul(
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[matmul] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype() << " were provided which results"
<< " in " << out_type << ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
if (a.dtype() != out_type) {
a = astype(a, out_type, s);
}