mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
fix build w/ flatten (#195)
This commit is contained in:
parent
52e1589a52
commit
90d04072b7
31
mlx/ops.cpp
31
mlx/ops.cpp
@ -283,21 +283,36 @@ array flatten(
|
|||||||
int end_axis /* = -1 */,
|
int end_axis /* = -1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto ndim = static_cast<int>(a.ndim());
|
auto ndim = static_cast<int>(a.ndim());
|
||||||
start_axis += (start_axis < 0 ? ndim : 0);
|
auto start_ax = start_axis + (start_axis < 0 ? ndim : 0);
|
||||||
end_axis += (end_axis < 0 ? ndim + 1 : 0);
|
auto end_ax = end_axis + (end_axis < 0 ? ndim : 0);
|
||||||
start_axis = std::max(0, start_axis);
|
start_ax = std::max(0, start_ax);
|
||||||
end_axis = std::min(ndim, end_axis);
|
end_ax = std::min(ndim - 1, end_ax);
|
||||||
if (end_axis < start_axis) {
|
if (a.ndim() == 0) {
|
||||||
|
return reshape(a, {1}, s);
|
||||||
|
}
|
||||||
|
if (end_ax < start_ax) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[flatten] start_axis must be less than or equal to end_axis");
|
"[flatten] start_axis must be less than or equal to end_axis");
|
||||||
}
|
}
|
||||||
if (start_axis == end_axis and a.ndim() != 0) {
|
if (start_ax >= ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[flatten] Invalid start_axis " << start_axis << " for array with "
|
||||||
|
<< ndim << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (end_ax < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[flatten] Invalid end_axis " << end_axis << " for array with "
|
||||||
|
<< ndim << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (start_ax == end_ax) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis);
|
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax);
|
||||||
new_shape.push_back(-1);
|
new_shape.push_back(-1);
|
||||||
new_shape.insert(
|
new_shape.insert(
|
||||||
new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end());
|
new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
|
||||||
return reshape(a, new_shape, s);
|
return reshape(a, new_shape, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -906,7 +906,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
# Check slice assign with negative indices works
|
# Check slice assign with negative indices works
|
||||||
a = mx.zeros((5, 5), mx.int32)
|
a = mx.zeros((5, 5), mx.int32)
|
||||||
a[2:-2, 2:-2] = 4
|
a[2:-2, 2:-2] = 4
|
||||||
self.assertEquals(a[2, 2].item(), 4)
|
self.assertEqual(a[2, 2].item(), 4)
|
||||||
|
|
||||||
def test_slice_negative_step(self):
|
def test_slice_negative_step(self):
|
||||||
a_np = np.arange(20)
|
a_np = np.arange(20)
|
||||||
|
@ -73,6 +73,12 @@ TEST_CASE("test flatten") {
|
|||||||
// Check start > end throws
|
// Check start > end throws
|
||||||
CHECK_THROWS(flatten(x, 2, 1));
|
CHECK_THROWS(flatten(x, 2, 1));
|
||||||
|
|
||||||
|
// Check start >= ndim throws
|
||||||
|
CHECK_THROWS(flatten(x, 5, 6));
|
||||||
|
|
||||||
|
// Check end < 0 throws
|
||||||
|
CHECK_THROWS(flatten(x, -5, -4));
|
||||||
|
|
||||||
// Check scalar flattens to 1D
|
// Check scalar flattens to 1D
|
||||||
x = array(1);
|
x = array(1);
|
||||||
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
|
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
|
||||||
|
Loading…
Reference in New Issue
Block a user