mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Resolved ambiguity in mlx::core::take_along_axis (#1822)
* Resolved ambiguity in mlx::core::take_along_axis Detected by GCC 10 on riscv64-linux-gnu. * Formatted * Removed superfluous parentheses in random_tests.cpp
This commit is contained in:
parent
1156c84e86
commit
f6c0499b8d
@ -3081,7 +3081,8 @@ array take_along_axis(
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
// Broadcast indices and input ignoring the take axis
|
||||
auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s);
|
||||
auto inputs =
|
||||
broadcast_arrays({a, indices}, std::vector<int>{axis - int(a.ndim())}, s);
|
||||
|
||||
auto out_shape = inputs[1].shape();
|
||||
return array(
|
||||
|
@ -556,7 +556,7 @@ TEST_CASE("test random bernoulli") {
|
||||
|
||||
p = array({0.1, 0.2, 0.3});
|
||||
// Ask for the wrong shape => throws
|
||||
CHECK_THROWS_AS(random::bernoulli(p, Shape({2})), std::invalid_argument);
|
||||
CHECK_THROWS_AS(random::bernoulli(p, Shape{2}), std::invalid_argument);
|
||||
|
||||
// Check wrong key type or shape
|
||||
auto key = array({0, 0}, {1, 2});
|
||||
|
Loading…
Reference in New Issue
Block a user