From 608bd43604b05000a4401a9adffa383ca68618f6 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 5 Jan 2024 19:10:13 -0800 Subject: [PATCH] Move the matmul type check in the op (#384) --- mlx/ops.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8ec7787f9..2d9410d94 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1948,6 +1948,13 @@ array matmul( } // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); + if (!is_floating_point(out_type) || is_complex(out_type)) { + std::ostringstream msg; + msg << "[matmul] Only real floating point types are supported but " + << a.dtype() << " and " << b.dtype() << " were provided which results" + << " in " << out_type << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } if (a.dtype() != out_type) { a = astype(a, out_type, s); }