mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	fix build w/ flatten (#195)
This commit is contained in:
		
							
								
								
									
										31
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								mlx/ops.cpp
									
									
									
									
									
								
							| @@ -283,21 +283,36 @@ array flatten( | ||||
|     int end_axis /* = -1 */, | ||||
|     StreamOrDevice s /* = {} */) { | ||||
|   auto ndim = static_cast<int>(a.ndim()); | ||||
|   start_axis += (start_axis < 0 ? ndim : 0); | ||||
|   end_axis += (end_axis < 0 ? ndim + 1 : 0); | ||||
|   start_axis = std::max(0, start_axis); | ||||
|   end_axis = std::min(ndim, end_axis); | ||||
|   if (end_axis < start_axis) { | ||||
|   auto start_ax = start_axis + (start_axis < 0 ? ndim : 0); | ||||
|   auto end_ax = end_axis + (end_axis < 0 ? ndim : 0); | ||||
|   start_ax = std::max(0, start_ax); | ||||
|   end_ax = std::min(ndim - 1, end_ax); | ||||
|   if (a.ndim() == 0) { | ||||
|     return reshape(a, {1}, s); | ||||
|   } | ||||
|   if (end_ax < start_ax) { | ||||
|     throw std::invalid_argument( | ||||
|         "[flatten] start_axis must be less than or equal to end_axis"); | ||||
|   } | ||||
|   if (start_axis == end_axis and a.ndim() != 0) { | ||||
|   if (start_ax >= ndim) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[flatten] Invalid start_axis " << start_axis << " for array with " | ||||
|         << ndim << " dimensions."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   if (end_ax < 0) { | ||||
|     std::ostringstream msg; | ||||
|     msg << "[flatten] Invalid end_axis " << end_axis << " for array with " | ||||
|         << ndim << " dimensions."; | ||||
|     throw std::invalid_argument(msg.str()); | ||||
|   } | ||||
|   if (start_ax == end_ax) { | ||||
|     return a; | ||||
|   } | ||||
|   std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis); | ||||
|   std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax); | ||||
|   new_shape.push_back(-1); | ||||
|   new_shape.insert( | ||||
|       new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end()); | ||||
|       new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end()); | ||||
|   return reshape(a, new_shape, s); | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -906,7 +906,7 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|         # Check slice assign with negative indices works | ||||
|         a = mx.zeros((5, 5), mx.int32) | ||||
|         a[2:-2, 2:-2] = 4 | ||||
|         self.assertEquals(a[2, 2].item(), 4) | ||||
|         self.assertEqual(a[2, 2].item(), 4) | ||||
|  | ||||
|     def test_slice_negative_step(self): | ||||
|         a_np = np.arange(20) | ||||
|   | ||||
| @@ -73,6 +73,12 @@ TEST_CASE("test flatten") { | ||||
|   // Check start > end throws | ||||
|   CHECK_THROWS(flatten(x, 2, 1)); | ||||
|  | ||||
|   // Check start >= ndim throws | ||||
|   CHECK_THROWS(flatten(x, 5, 6)); | ||||
|  | ||||
|   // Check end < 0 throws | ||||
|   CHECK_THROWS(flatten(x, -5, -4)); | ||||
|  | ||||
|   // Check scalar flattens to 1D | ||||
|   x = array(1); | ||||
|   CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1})); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun