mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add load_safe to the general conv loaders (#2258)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							095163b8d1
						
					
				
				
					commit
					8590c0941e
				
			@@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
 | 
			
		||||
 | 
			
		||||
    def test_conv2d_unaligned_channels(self):
 | 
			
		||||
        x = mx.random.uniform(shape=(2, 16, 16, 21))
 | 
			
		||||
        w = mx.random.uniform(shape=(32, 3, 3, 21))
 | 
			
		||||
        y = mx.conv2d(x, w, stream=mx.cpu)
 | 
			
		||||
        y_hat = mx.conv2d(x, w)
 | 
			
		||||
        self.assertTrue(mx.allclose(y, y_hat))
 | 
			
		||||
 | 
			
		||||
        x = mx.random.uniform(shape=(2, 16, 16, 21))
 | 
			
		||||
        w = mx.random.uniform(shape=(21, 3, 3, 21))
 | 
			
		||||
        y = mx.conv2d(x, w, stream=mx.cpu)
 | 
			
		||||
        y_hat = mx.conv2d(x, w)
 | 
			
		||||
        self.assertTrue(mx.allclose(y, y_hat))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user