mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
bump pre commit and fix format (#373)
This commit is contained in:
parent
c82a8cc526
commit
b9e415d19c
@ -5,7 +5,7 @@ repos:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 22.10.0
|
||||
rev: 23.12.1
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@ -139,7 +139,7 @@ Device::~Device() {
|
||||
|
||||
void Device::new_queue(int index) {
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
|
||||
|
||||
// Multiple threads can ask the device for queues
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
|
@ -122,7 +122,6 @@ class TestBF16(mlx_tests.MLXTestCase):
|
||||
|
||||
for op in ("min", "max"):
|
||||
with self.subTest(op=op):
|
||||
|
||||
for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
|
||||
with self.subTest(axes=axes):
|
||||
np_args = (x.astype(np.float32),)
|
||||
|
@ -80,7 +80,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
for B, M, N, K in shapes:
|
||||
|
||||
with self.subTest(transpose="nn"):
|
||||
shape_a = (B, M, K)
|
||||
shape_b = (B, K, N)
|
||||
@ -151,7 +150,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
|
||||
|
||||
def test_matmul_dtypes(self):
|
||||
|
||||
for dt in self.dtypes:
|
||||
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||
getattr(np, dt)
|
||||
|
@ -34,7 +34,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_and_load(self):
|
||||
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
@ -92,7 +91,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
def test_save_and_load_fs(self):
|
||||
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
@ -165,7 +163,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
(save_file_npy_uncomp, save_file_mlx_uncomp),
|
||||
(save_file_npy_comp, save_file_mlx_comp),
|
||||
):
|
||||
|
||||
# Load array saved by mlx as mlx array
|
||||
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
||||
for k, v in load_arr_mlx_mlx.items():
|
||||
|
@ -65,7 +65,6 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
|
||||
for op in ("sum", "prod", "min", "max"):
|
||||
with self.subTest(op=op):
|
||||
|
||||
np_op = getattr(np, op)
|
||||
mlx_op = getattr(mx, op)
|
||||
|
||||
@ -96,7 +95,6 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
]
|
||||
for dtype in dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
|
||||
data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
|
||||
x = mx.array(data)
|
||||
for op in ["argmin", "argmax"]:
|
||||
|
Loading…
Reference in New Issue
Block a user