fix build w/ flatten (#195)

This commit is contained in:
Awni Hannun 2023-12-17 11:58:45 -08:00 committed by GitHub
parent 52e1589a52
commit 90d04072b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 9 deletions

View File

@ -283,21 +283,36 @@ array flatten(
int end_axis /* = -1 */, int end_axis /* = -1 */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto ndim = static_cast<int>(a.ndim()); auto ndim = static_cast<int>(a.ndim());
start_axis += (start_axis < 0 ? ndim : 0); auto start_ax = start_axis + (start_axis < 0 ? ndim : 0);
end_axis += (end_axis < 0 ? ndim + 1 : 0); auto end_ax = end_axis + (end_axis < 0 ? ndim : 0);
start_axis = std::max(0, start_axis); start_ax = std::max(0, start_ax);
end_axis = std::min(ndim, end_axis); end_ax = std::min(ndim - 1, end_ax);
if (end_axis < start_axis) { if (a.ndim() == 0) {
return reshape(a, {1}, s);
}
if (end_ax < start_ax) {
throw std::invalid_argument( throw std::invalid_argument(
"[flatten] start_axis must be less than or equal to end_axis"); "[flatten] start_axis must be less than or equal to end_axis");
} }
if (start_axis == end_axis and a.ndim() != 0) { if (start_ax >= ndim) {
std::ostringstream msg;
msg << "[flatten] Invalid start_axis " << start_axis << " for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (end_ax < 0) {
std::ostringstream msg;
msg << "[flatten] Invalid end_axis " << end_axis << " for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (start_ax == end_ax) {
return a; return a;
} }
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis); std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax);
new_shape.push_back(-1); new_shape.push_back(-1);
new_shape.insert( new_shape.insert(
new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end()); new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
return reshape(a, new_shape, s); return reshape(a, new_shape, s);
} }

View File

@ -906,7 +906,7 @@ class TestArray(mlx_tests.MLXTestCase):
# Check slice assign with negative indices works # Check slice assign with negative indices works
a = mx.zeros((5, 5), mx.int32) a = mx.zeros((5, 5), mx.int32)
a[2:-2, 2:-2] = 4 a[2:-2, 2:-2] = 4
self.assertEquals(a[2, 2].item(), 4) self.assertEqual(a[2, 2].item(), 4)
def test_slice_negative_step(self): def test_slice_negative_step(self):
a_np = np.arange(20) a_np = np.arange(20)

View File

@ -73,6 +73,12 @@ TEST_CASE("test flatten") {
// Check start > end throws // Check start > end throws
CHECK_THROWS(flatten(x, 2, 1)); CHECK_THROWS(flatten(x, 2, 1));
// Check start >= ndim throws
CHECK_THROWS(flatten(x, 5, 6));
// Check end < 0 throws
CHECK_THROWS(flatten(x, -5, -4));
// Check scalar flattens to 1D // Check scalar flattens to 1D
x = array(1); x = array(1);
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1})); CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));