mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	bump pre commit and fix format (#373)
This commit is contained in:
		| @@ -122,7 +122,6 @@ class TestBF16(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         for op in ("min", "max"): | ||||
|             with self.subTest(op=op): | ||||
|  | ||||
|                 for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): | ||||
|                     with self.subTest(axes=axes): | ||||
|                         np_args = (x.astype(np.float32),) | ||||
|   | ||||
| @@ -80,7 +80,6 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|             np_dtype = getattr(np, dtype) | ||||
|  | ||||
|             for B, M, N, K in shapes: | ||||
|  | ||||
|                 with self.subTest(transpose="nn"): | ||||
|                     shape_a = (B, M, K) | ||||
|                     shape_b = (B, K, N) | ||||
| @@ -151,7 +150,6 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6)) | ||||
|  | ||||
|     def test_matmul_dtypes(self): | ||||
|  | ||||
|         for dt in self.dtypes: | ||||
|             a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype( | ||||
|                 getattr(np, dt) | ||||
|   | ||||
| @@ -34,7 +34,6 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|         cls.test_dir_fid.cleanup() | ||||
|  | ||||
|     def test_save_and_load(self): | ||||
|  | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
| @@ -92,7 +91,6 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|                         ) | ||||
|  | ||||
|     def test_save_and_load_fs(self): | ||||
|  | ||||
|         if not os.path.isdir(self.test_dir): | ||||
|             os.mkdir(self.test_dir) | ||||
|  | ||||
| @@ -165,7 +163,6 @@ class TestLoad(mlx_tests.MLXTestCase): | ||||
|                     (save_file_npy_uncomp, save_file_mlx_uncomp), | ||||
|                     (save_file_npy_comp, save_file_mlx_comp), | ||||
|                 ): | ||||
|  | ||||
|                     # Load array saved by mlx as mlx array | ||||
|                     load_arr_mlx_mlx = mx.load(save_file_mlx) | ||||
|                     for k, v in load_arr_mlx_mlx.items(): | ||||
|   | ||||
| @@ -65,7 +65,6 @@ class TestReduce(mlx_tests.MLXTestCase): | ||||
|  | ||||
|                 for op in ("sum", "prod", "min", "max"): | ||||
|                     with self.subTest(op=op): | ||||
|  | ||||
|                         np_op = getattr(np, op) | ||||
|                         mlx_op = getattr(mx, op) | ||||
|  | ||||
| @@ -96,7 +95,6 @@ class TestReduce(mlx_tests.MLXTestCase): | ||||
|         ] | ||||
|         for dtype in dtypes: | ||||
|             with self.subTest(dtype=dtype): | ||||
|  | ||||
|                 data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) | ||||
|                 x = mx.array(data) | ||||
|                 for op in ["argmin", "argmax"]: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun