mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	bump pre commit and fix format (#373)
This commit is contained in:
		| @@ -5,7 +5,7 @@ repos: | |||||||
|     -   id: clang-format |     -   id: clang-format | ||||||
| # Using this mirror lets us use mypyc-compiled black, which is about 2x faster | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster | ||||||
| -   repo: https://github.com/psf/black-pre-commit-mirror | -   repo: https://github.com/psf/black-pre-commit-mirror | ||||||
|     rev: 22.10.0 |     rev: 23.12.1 | ||||||
|     hooks: |     hooks: | ||||||
|     -   id: black |     -   id: black | ||||||
| -   repo: https://github.com/pycqa/isort | -   repo: https://github.com/pycqa/isort | ||||||
|   | |||||||
| @@ -139,7 +139,7 @@ Device::~Device() { | |||||||
|  |  | ||||||
| void Device::new_queue(int index) { | void Device::new_queue(int index) { | ||||||
|   auto thread_pool = metal::new_scoped_memory_pool(); |   auto thread_pool = metal::new_scoped_memory_pool(); | ||||||
|    |  | ||||||
|   // Multiple threads can ask the device for queues |   // Multiple threads can ask the device for queues | ||||||
|   // We lock this as a critical section for safety |   // We lock this as a critical section for safety | ||||||
|   const std::lock_guard<std::mutex> lock(mtx_); |   const std::lock_guard<std::mutex> lock(mtx_); | ||||||
|   | |||||||
| @@ -122,7 +122,6 @@ class TestBF16(mlx_tests.MLXTestCase): | |||||||
|  |  | ||||||
|         for op in ("min", "max"): |         for op in ("min", "max"): | ||||||
|             with self.subTest(op=op): |             with self.subTest(op=op): | ||||||
|  |  | ||||||
|                 for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): |                 for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): | ||||||
|                     with self.subTest(axes=axes): |                     with self.subTest(axes=axes): | ||||||
|                         np_args = (x.astype(np.float32),) |                         np_args = (x.astype(np.float32),) | ||||||
|   | |||||||
| @@ -80,7 +80,6 @@ class TestBlas(mlx_tests.MLXTestCase): | |||||||
|             np_dtype = getattr(np, dtype) |             np_dtype = getattr(np, dtype) | ||||||
|  |  | ||||||
|             for B, M, N, K in shapes: |             for B, M, N, K in shapes: | ||||||
|  |  | ||||||
|                 with self.subTest(transpose="nn"): |                 with self.subTest(transpose="nn"): | ||||||
|                     shape_a = (B, M, K) |                     shape_a = (B, M, K) | ||||||
|                     shape_b = (B, K, N) |                     shape_b = (B, K, N) | ||||||
| @@ -151,7 +150,6 @@ class TestBlas(mlx_tests.MLXTestCase): | |||||||
|         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6)) |         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6)) | ||||||
|  |  | ||||||
|     def test_matmul_dtypes(self): |     def test_matmul_dtypes(self): | ||||||
|  |  | ||||||
|         for dt in self.dtypes: |         for dt in self.dtypes: | ||||||
|             a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype( |             a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype( | ||||||
|                 getattr(np, dt) |                 getattr(np, dt) | ||||||
|   | |||||||
| @@ -34,7 +34,6 @@ class TestLoad(mlx_tests.MLXTestCase): | |||||||
|         cls.test_dir_fid.cleanup() |         cls.test_dir_fid.cleanup() | ||||||
|  |  | ||||||
|     def test_save_and_load(self): |     def test_save_and_load(self): | ||||||
|  |  | ||||||
|         if not os.path.isdir(self.test_dir): |         if not os.path.isdir(self.test_dir): | ||||||
|             os.mkdir(self.test_dir) |             os.mkdir(self.test_dir) | ||||||
|  |  | ||||||
| @@ -92,7 +91,6 @@ class TestLoad(mlx_tests.MLXTestCase): | |||||||
|                         ) |                         ) | ||||||
|  |  | ||||||
|     def test_save_and_load_fs(self): |     def test_save_and_load_fs(self): | ||||||
|  |  | ||||||
|         if not os.path.isdir(self.test_dir): |         if not os.path.isdir(self.test_dir): | ||||||
|             os.mkdir(self.test_dir) |             os.mkdir(self.test_dir) | ||||||
|  |  | ||||||
| @@ -165,7 +163,6 @@ class TestLoad(mlx_tests.MLXTestCase): | |||||||
|                     (save_file_npy_uncomp, save_file_mlx_uncomp), |                     (save_file_npy_uncomp, save_file_mlx_uncomp), | ||||||
|                     (save_file_npy_comp, save_file_mlx_comp), |                     (save_file_npy_comp, save_file_mlx_comp), | ||||||
|                 ): |                 ): | ||||||
|  |  | ||||||
|                     # Load array saved by mlx as mlx array |                     # Load array saved by mlx as mlx array | ||||||
|                     load_arr_mlx_mlx = mx.load(save_file_mlx) |                     load_arr_mlx_mlx = mx.load(save_file_mlx) | ||||||
|                     for k, v in load_arr_mlx_mlx.items(): |                     for k, v in load_arr_mlx_mlx.items(): | ||||||
|   | |||||||
| @@ -65,7 +65,6 @@ class TestReduce(mlx_tests.MLXTestCase): | |||||||
|  |  | ||||||
|                 for op in ("sum", "prod", "min", "max"): |                 for op in ("sum", "prod", "min", "max"): | ||||||
|                     with self.subTest(op=op): |                     with self.subTest(op=op): | ||||||
|  |  | ||||||
|                         np_op = getattr(np, op) |                         np_op = getattr(np, op) | ||||||
|                         mlx_op = getattr(mx, op) |                         mlx_op = getattr(mx, op) | ||||||
|  |  | ||||||
| @@ -96,7 +95,6 @@ class TestReduce(mlx_tests.MLXTestCase): | |||||||
|         ] |         ] | ||||||
|         for dtype in dtypes: |         for dtype in dtypes: | ||||||
|             with self.subTest(dtype=dtype): |             with self.subTest(dtype=dtype): | ||||||
|  |  | ||||||
|                 data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) |                 data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) | ||||||
|                 x = mx.array(data) |                 x = mx.array(data) | ||||||
|                 for op in ["argmin", "argmax"]: |                 for op in ["argmin", "argmax"]: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun