From 49c34c41611ce82a75e31ea7e941a436fdb39a18 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 18 Dec 2024 14:25:18 -0800 Subject: [PATCH] check mask type (#1721) --- mlx/fast.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 58800fb70b..2c93d68618 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -586,6 +586,13 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } + if (mask && (*mask).dtype() != final_type) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Mask should match output type. " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + auto q = astype(queries, final_type, s); auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s);