Fixes for large arrays with a few ops (#1299)

* fixes for large arrays with a few ops

* fix bug

* fix all of copy
This commit is contained in:
Awni Hannun
2024-07-30 17:18:39 -07:00
committed by GitHub
parent c52d1600f0
commit 40b6d67333
21 changed files with 273 additions and 202 deletions

View File

@@ -273,7 +273,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
// Check for the number of indices passed
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -585,7 +585,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@@ -840,7 +840,7 @@ auto mlx_slice_update(
// Dimension check
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}