diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8ec7787f9..2d9410d94 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1948,6 +1948,13 @@ array matmul( } // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); + if (!is_floating_point(out_type) || is_complex(out_type)) { + std::ostringstream msg; + msg << "[matmul] Only real floating point types are supported but " + << a.dtype() << " and " << b.dtype() << " were provided which results" + << " in " << out_type << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } if (a.dtype() != out_type) { a = astype(a, out_type, s); }