mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Move the matmul type check in the op (#384)
This commit is contained in:
parent
4c48f6460d
commit
608bd43604
@ -1948,6 +1948,13 @@ array matmul(
|
||||
}
|
||||
// Type promotion
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype() << " were provided which results"
|
||||
<< " in " << out_type << ", which is not a real floating point type.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (a.dtype() != out_type) {
|
||||
a = astype(a, out_type, s);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user