mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Einsum error msg improvement (#2690)
* Improved error message for Einsum * Modifications via pre-commit * format * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
committed by
GitHub
parent
8f8af61a37
commit
0cfeeb60ca
@@ -671,7 +671,8 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
|
|||||||
}
|
}
|
||||||
int max_ellipsis_length = 0;
|
int max_ellipsis_length = 0;
|
||||||
auto check_letters_and_expand_ellipsis = [&](auto& subscript,
|
auto check_letters_and_expand_ellipsis = [&](auto& subscript,
|
||||||
const array* operand) {
|
const array* operand,
|
||||||
|
int operand_idx) {
|
||||||
bool have_ellipsis = false;
|
bool have_ellipsis = false;
|
||||||
int cnt_before = 0, cnt_after = 0;
|
int cnt_before = 0, cnt_after = 0;
|
||||||
for (int i = 0; i < subscript.size(); i++) {
|
for (int i = 0; i < subscript.size(); i++) {
|
||||||
@@ -708,10 +709,21 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
|
|||||||
int ellipsis_length;
|
int ellipsis_length;
|
||||||
if (operand != nullptr) {
|
if (operand != nullptr) {
|
||||||
ellipsis_length = operand->ndim() - cnt_before - cnt_after;
|
ellipsis_length = operand->ndim() - cnt_before - cnt_after;
|
||||||
|
if (ellipsis_length < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[" << fn_name << "] Operand " << operand_idx << " with shape "
|
||||||
|
<< operand->shape()
|
||||||
|
<< " has insufficient dimensions for subscript '" << subscript
|
||||||
|
<< "'. The ellipsis requires at least "
|
||||||
|
<< (cnt_before + cnt_after) << " dimensions but the operand has "
|
||||||
|
<< operand->ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
|
max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
|
||||||
} else {
|
} else {
|
||||||
ellipsis_length = max_ellipsis_length;
|
ellipsis_length = max_ellipsis_length;
|
||||||
}
|
}
|
||||||
|
|
||||||
subscript.replace(
|
subscript.replace(
|
||||||
subscript.begin() + cnt_before,
|
subscript.begin() + cnt_before,
|
||||||
subscript.begin() + cnt_before + 3,
|
subscript.begin() + cnt_before + 3,
|
||||||
@@ -721,9 +733,9 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
|
|||||||
};
|
};
|
||||||
|
|
||||||
for (int i = 0; i < operands.size(); i++) {
|
for (int i = 0; i < operands.size(); i++) {
|
||||||
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i]);
|
check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i], i);
|
||||||
}
|
}
|
||||||
check_letters_and_expand_ellipsis(out_subscript, nullptr);
|
check_letters_and_expand_ellipsis(out_subscript, nullptr, -1);
|
||||||
|
|
||||||
CharSet out_set(out_subscript.begin(), out_subscript.end());
|
CharSet out_set(out_subscript.begin(), out_subscript.end());
|
||||||
if (out_set.size() != out_subscript.size()) {
|
if (out_set.size() != out_subscript.size()) {
|
||||||
|
|||||||
Reference in New Issue
Block a user