diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 51cfd8783..8853d8351 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3081,7 +3081,8 @@ array take_along_axis( axis = axis < 0 ? a.ndim() + axis : axis; // Broadcast indices and input ignoring the take axis - auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s); + auto inputs = + broadcast_arrays({a, indices}, std::vector{axis - int(a.ndim())}, s); auto out_shape = inputs[1].shape(); return array( diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index fe5f6b317..49f1f300b 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -556,7 +556,7 @@ TEST_CASE("test random bernoulli") { p = array({0.1, 0.2, 0.3}); // Ask for the wrong shape => throws - CHECK_THROWS_AS(random::bernoulli(p, Shape({2})), std::invalid_argument); + CHECK_THROWS_AS(random::bernoulli(p, Shape{2}), std::invalid_argument); // Check wrong key type or shape auto key = array({0, 0}, {1, 2});