Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Use std::optional for mask_arr arg (#2763)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.8) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.8) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.9) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.8) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.8) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.9) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
This commit is contained in:
mlx/fast.cpp (35 changes)
@@ -578,7 +578,7 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::string& mask_mode /* = "" */,
|
||||
const std::vector<array>& mask_arrs /* = {} */,
|
||||
std::optional<array> mask_arr /* = {} */,
|
||||
const std::optional<array>& sinks /* = {} */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
for (const auto& tensor : {queries, keys, values}) {
|
||||
@@ -606,32 +606,22 @@ array scaled_dot_product_attention(
|
||||
has_mask = true;
|
||||
do_causal = true;
|
||||
|
||||
if (!mask_arrs.empty()) {
|
||||
if (mask_arr) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
||||
<< "'casusal'. No array masks supported.";
|
||||
msg << "[scaled_dot_product_attention] Invalid mask_arr for mask_mode "
|
||||
<< "'casusal'. No array mask should be passed.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_mode == "array" || (mask_mode == "" && !mask_arrs.empty())) {
|
||||
if (mask_arrs.size() != 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Invalid mask_arrs for mask_mode "
|
||||
<< "'" << mask_mode << "'. Only 1 mask array is supported, got "
|
||||
<< mask_arrs.size() << "arrays.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
} else if (mask_arr) {
|
||||
has_mask = true;
|
||||
has_arr_mask = true;
|
||||
has_bool_mask = mask_arrs[0].dtype() == bool_;
|
||||
has_bool_mask = mask_arr->dtype() == bool_;
|
||||
}
|
||||
|
||||
if (has_arr_mask && (mask_arrs[0]).ndim() > 4) {
|
||||
if (has_arr_mask && mask_arr->ndim() > 4) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] the mask with shape "
|
||||
<< mask_arrs[0].shape() << " expected to have at most rank 4.";
|
||||
<< mask_arr->shape() << " expected to have at most rank 4.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -764,20 +754,19 @@ array scaled_dot_product_attention(
|
||||
std::vector<array> inputs = {q, k, v};
|
||||
if (has_arr_mask) {
|
||||
// Check type
|
||||
auto mask_arr = mask_arrs[0];
|
||||
has_bool_mask = mask_arr.dtype() == bool_;
|
||||
if (promote_types(mask_arr.dtype(), final_type) != final_type) {
|
||||
has_bool_mask = mask_arr->dtype() == bool_;
|
||||
if (promote_types(mask_arr->dtype(), final_type) != final_type) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scaled_dot_product_attention] Mask type must promote to output type "
|
||||
<< final_type << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
} else if (!has_bool_mask) {
|
||||
mask_arr = astype(mask_arr, final_type, stream);
|
||||
mask_arr = astype(*mask_arr, final_type, stream);
|
||||
}
|
||||
// Broadcast mask
|
||||
auto mask_shape = queries.shape();
|
||||
mask_shape.back() = keys.shape(-2);
|
||||
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
|
||||
inputs.push_back(broadcast_to(*mask_arr, mask_shape, stream));
|
||||
}
|
||||
if (has_sinks) {
|
||||
if (promote_types(sinks->dtype(), final_type) != final_type) {
|
||||
|
||||
@@ -49,7 +49,7 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::string& mask_mode = "",
|
||||
const std::vector<array>& mask_arrs = {},
|
||||
std::optional<array> mask_arr = {},
|
||||
const std::optional<array>& sinks = {},
|
||||
StreamOrDevice s = {});
|
||||
|
||||
|
||||
@@ -213,11 +213,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, mask_str, {}, sinks, s);
|
||||
queries, keys, values, scale, mask_str, std::nullopt, sinks, s);
|
||||
} else {
|
||||
auto mask_arr = std::get<mx::array>(mask);
|
||||
return mx::fast::scaled_dot_product_attention(
|
||||
queries, keys, values, scale, "", {mask_arr}, sinks, s);
|
||||
queries, keys, values, scale, "", mask_arr, sinks, s);
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user