mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix nd ternary on GPU (#1746)
This commit is contained in:
		| @@ -1755,6 +1755,14 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             np.where, | ||||
|         ) | ||||
|  | ||||
|         # Check non-contiguous input with several dimensions | ||||
|         shape = [1, 2, 2, 3, 3, 1] | ||||
|         strides = [16, 4, 1, 4, 1, 1] | ||||
|         x = mx.ones(shape=(1, 4, 4, 1)) | ||||
|         x = mx.as_strided(x, shape, strides) | ||||
|         out = mx.where(mx.isnan(x), mx.nan, x) | ||||
|         self.assertTrue(mx.allclose(out, mx.ones_like(out))) | ||||
|  | ||||
|     def test_nan_to_num(self): | ||||
|         a = mx.array([6, float("inf"), 2, 0]) | ||||
|         out_mx = mx.nan_to_num(a) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun