Resolved ambiguity in mlx::core::take_along_axis (#1822)

* Resolved ambiguity in mlx::core::take_along_axis

Detected by GCC 10 on riscv64-linux-gnu.

* Formatted

* Removed superfluous parentheses in random_tests.cpp
This commit is contained in:
Jesper Stemann Andersen 2025-02-04 15:06:17 +01:00 committed by GitHub
parent 1156c84e86
commit f6c0499b8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 2 deletions

View File

@ -3081,7 +3081,8 @@ array take_along_axis(
axis = axis < 0 ? a.ndim() + axis : axis;
// Broadcast indices and input ignoring the take axis
auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s);
auto inputs =
broadcast_arrays({a, indices}, std::vector<int>{axis - int(a.ndim())}, s);
auto out_shape = inputs[1].shape();
return array(

View File

@ -556,7 +556,7 @@ TEST_CASE("test random bernoulli") {
p = array({0.1, 0.2, 0.3});
// Ask for the wrong shape => throws
CHECK_THROWS_AS(random::bernoulli(p, Shape({2})), std::invalid_argument);
CHECK_THROWS_AS(random::bernoulli(p, Shape{2}), std::invalid_argument);
// Check wrong key type or shape
auto key = array({0, 0}, {1, 2});