bump pre commit and fix format (#373)

This commit is contained in:
Awni Hannun 2024-01-04 16:28:52 -08:00 committed by GitHub
parent c82a8cc526
commit b9e415d19c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 2 additions and 10 deletions

View File

@ -5,7 +5,7 @@ repos:
- id: clang-format - id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 22.10.0 rev: 23.12.1
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort

View File

@ -122,7 +122,6 @@ class TestBF16(mlx_tests.MLXTestCase):
for op in ("min", "max"): for op in ("min", "max"):
with self.subTest(op=op): with self.subTest(op=op):
for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
with self.subTest(axes=axes): with self.subTest(axes=axes):
np_args = (x.astype(np.float32),) np_args = (x.astype(np.float32),)

View File

@ -80,7 +80,6 @@ class TestBlas(mlx_tests.MLXTestCase):
np_dtype = getattr(np, dtype) np_dtype = getattr(np, dtype)
for B, M, N, K in shapes: for B, M, N, K in shapes:
with self.subTest(transpose="nn"): with self.subTest(transpose="nn"):
shape_a = (B, M, K) shape_a = (B, M, K)
shape_b = (B, K, N) shape_b = (B, K, N)
@ -151,7 +150,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6)) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
def test_matmul_dtypes(self): def test_matmul_dtypes(self):
for dt in self.dtypes: for dt in self.dtypes:
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype( a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
getattr(np, dt) getattr(np, dt)

View File

@ -34,7 +34,6 @@ class TestLoad(mlx_tests.MLXTestCase):
cls.test_dir_fid.cleanup() cls.test_dir_fid.cleanup()
def test_save_and_load(self): def test_save_and_load(self):
if not os.path.isdir(self.test_dir): if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir) os.mkdir(self.test_dir)
@ -92,7 +91,6 @@ class TestLoad(mlx_tests.MLXTestCase):
) )
def test_save_and_load_fs(self): def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir): if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir) os.mkdir(self.test_dir)
@ -165,7 +163,6 @@ class TestLoad(mlx_tests.MLXTestCase):
(save_file_npy_uncomp, save_file_mlx_uncomp), (save_file_npy_uncomp, save_file_mlx_uncomp),
(save_file_npy_comp, save_file_mlx_comp), (save_file_npy_comp, save_file_mlx_comp),
): ):
# Load array saved by mlx as mlx array # Load array saved by mlx as mlx array
load_arr_mlx_mlx = mx.load(save_file_mlx) load_arr_mlx_mlx = mx.load(save_file_mlx)
for k, v in load_arr_mlx_mlx.items(): for k, v in load_arr_mlx_mlx.items():

View File

@ -65,7 +65,6 @@ class TestReduce(mlx_tests.MLXTestCase):
for op in ("sum", "prod", "min", "max"): for op in ("sum", "prod", "min", "max"):
with self.subTest(op=op): with self.subTest(op=op):
np_op = getattr(np, op) np_op = getattr(np, op)
mlx_op = getattr(mx, op) mlx_op = getattr(mx, op)
@ -96,7 +95,6 @@ class TestReduce(mlx_tests.MLXTestCase):
] ]
for dtype in dtypes: for dtype in dtypes:
with self.subTest(dtype=dtype): with self.subTest(dtype=dtype):
data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
x = mx.array(data) x = mx.array(data)
for op in ["argmin", "argmax"]: for op in ["argmin", "argmax"]: