From f6c0499b8d4a6f01c501edabafe6aa4407bc2f3a Mon Sep 17 00:00:00 2001 From: Jesper Stemann Andersen Date: Tue, 4 Feb 2025 15:06:17 +0100 Subject: [PATCH] Resolved ambiguity in mlx::core::take_along_axis (#1822) * Resolved ambiguity in mlx::core::take_along_axis Detected by GCC 10 on riscv64-linux-gnu. * Formatted * Removed superfluous parentheses in random_tests.cpp --- mlx/ops.cpp | 3 ++- tests/random_tests.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 51cfd8783..8853d8351 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3081,7 +3081,8 @@ array take_along_axis( axis = axis < 0 ? a.ndim() + axis : axis; // Broadcast indices and input ignoring the take axis - auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s); + auto inputs = + broadcast_arrays({a, indices}, std::vector{axis - int(a.ndim())}, s); auto out_shape = inputs[1].shape(); return array( diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index fe5f6b317..49f1f300b 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -556,7 +556,7 @@ TEST_CASE("test random bernoulli") { p = array({0.1, 0.2, 0.3}); // Ask for the wrong shape => throws - CHECK_THROWS_AS(random::bernoulli(p, Shape({2})), std::invalid_argument); + CHECK_THROWS_AS(random::bernoulli(p, Shape{2}), std::invalid_argument); // Check wrong key type or shape auto key = array({0, 0}, {1, 2});