diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 82c419264..c7c572a08 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -578,7 +578,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::string& mask_mode /* = "" */,
-    const std::vector<array>& mask_arrs /* = {} */,
+    std::optional<array> mask_arr /* = {} */,
     const std::optional<array>& sinks /* = {} */,
     StreamOrDevice s /* = {}*/) {
   for (const auto& tensor : {queries, keys, values}) {
@@ -606,32 +606,22 @@ array scaled_dot_product_attention(
     has_mask = true;
     do_causal = true;
-    if (!mask_arrs.empty()) {
+    if (mask_arr) {
       std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
-          << "'casusal'. No array masks supported.";
+      msg << "[scaled_dot_product_attention] Invalid mask_arr for mask_mode "
+          << "'causal'. No array mask should be passed.";
       throw std::invalid_argument(msg.str());
     }
-  }
-
-  if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
-    if (mask_arrs.size() != 1) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
-          << "'" << mask_mode << "'. Only 1 mask array is supported, got "
-          << mask_arrs.size() << "arrays.";
-      throw std::invalid_argument(msg.str());
-    }
-
+  } else if (mask_arr) {
     has_mask = true;
     has_arr_mask = true;
-    has_bool_mask = mask_arrs[0].dtype() == bool_;
+    has_bool_mask = mask_arr->dtype() == bool_;
   }
 
-  if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
+  if (has_arr_mask && mask_arr->ndim() > 4) {
     std::ostringstream msg;
     msg << "[scaled_dot_product_attention] the mask with shape "
-        << mask_arrs[0].shape() << " expected to have at most rank 4.";
+        << mask_arr->shape() << " expected to have at most rank 4.";
     throw std::invalid_argument(msg.str());
   }
 
@@ -764,20 +754,19 @@ array scaled_dot_product_attention(
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
-    auto mask_arr = mask_arrs[0];
-    has_bool_mask = mask_arr.dtype() == bool_;
-    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
+    has_bool_mask = mask_arr->dtype() == bool_;
+    if (promote_types(mask_arr->dtype(), final_type) != final_type) {
       std::ostringstream msg;
       msg << "[scaled_dot_product_attention] Mask type must promote to output type "
           << final_type << ".";
       throw std::invalid_argument(msg.str());
     } else if (!has_bool_mask) {
-      mask_arr = astype(mask_arr, final_type, stream);
+      mask_arr = astype(*mask_arr, final_type, stream);
     }
     // Broadcast mask
     auto mask_shape = queries.shape();
     mask_shape.back() = keys.shape(-2);
-    inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
+    inputs.push_back(broadcast_to(*mask_arr, mask_shape, stream));
   }
   if (has_sinks) {
     if (promote_types(sinks->dtype(), final_type) != final_type) {
diff --git a/mlx/fast.h b/mlx/fast.h
index 3cbff60e4..0884bac7f 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -49,7 +49,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::string& mask_mode = "",
-    const std::vector<array>& mask_arrs = {},
+    std::optional<array> mask_arr = {},
     const std::optional<array>& sinks = {},
     StreamOrDevice s = {});
 
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 0ed1aa698..97dd632c5 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -213,11 +213,11 @@ void init_fast(nb::module_& parent_module) {
           throw std::invalid_argument(msg.str());
         }
         return mx::fast::scaled_dot_product_attention(
-            queries, keys, values, scale, mask_str, {}, sinks, s);
+            queries, keys, values, scale, mask_str, std::nullopt, sinks, s);
       } else {
         auto mask_arr = std::get<mx::array>(mask);
         return mx::fast::scaled_dot_product_attention(
-            queries, keys, values, scale, "", {mask_arr}, sinks, s);
+            queries, keys, values, scale, "", mask_arr, sinks, s);
       }
     } else {
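For reviewers, a minimal usage sketch of the new C++ signature after this patch. The shapes, the zero additive mask, and the `main` scaffolding are illustrative assumptions, not part of the change; only the `fast::scaled_dot_product_attention` call shape comes from the patched header.

```cpp
// Sketch: the single std::optional<array> mask replacing the old vector argument.
#include <cmath>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Illustrative B=1, H=8, L=32, D=64 query/key/value tensors.
  auto q = random::normal({1, 8, 32, 64});
  auto k = random::normal({1, 8, 32, 64});
  auto v = random::normal({1, 8, 32, 64});
  float scale = 1.0f / std::sqrt(64.0f);

  // Causal mode: passing an array mask alongside "causal" now throws.
  auto out_causal = fast::scaled_dot_product_attention(q, k, v, scale, "causal");

  // Array mask: pass one array (implicitly converted to std::optional<array>);
  // it must promote to the output dtype and broadcast against the scores.
  auto mask = zeros({32, 32});
  auto out_masked = fast::scaled_dot_product_attention(q, k, v, scale, "", mask);

  eval(out_causal, out_masked);
  return 0;
}
```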