mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
bump pre commit and fix format (#373)
This commit is contained in:
parent
c82a8cc526
commit
b9e415d19c
@ -5,7 +5,7 @@ repos:
|
|||||||
- id: clang-format
|
- id: clang-format
|
||||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||||
rev: 22.10.0
|
rev: 23.12.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
|
@ -122,7 +122,6 @@ class TestBF16(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
for op in ("min", "max"):
|
for op in ("min", "max"):
|
||||||
with self.subTest(op=op):
|
with self.subTest(op=op):
|
||||||
|
|
||||||
for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
|
for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
|
||||||
with self.subTest(axes=axes):
|
with self.subTest(axes=axes):
|
||||||
np_args = (x.astype(np.float32),)
|
np_args = (x.astype(np.float32),)
|
||||||
|
@ -80,7 +80,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
for B, M, N, K in shapes:
|
for B, M, N, K in shapes:
|
||||||
|
|
||||||
with self.subTest(transpose="nn"):
|
with self.subTest(transpose="nn"):
|
||||||
shape_a = (B, M, K)
|
shape_a = (B, M, K)
|
||||||
shape_b = (B, K, N)
|
shape_b = (B, K, N)
|
||||||
@ -151,7 +150,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
|
||||||
|
|
||||||
def test_matmul_dtypes(self):
|
def test_matmul_dtypes(self):
|
||||||
|
|
||||||
for dt in self.dtypes:
|
for dt in self.dtypes:
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
|
||||||
getattr(np, dt)
|
getattr(np, dt)
|
||||||
|
@ -34,7 +34,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
cls.test_dir_fid.cleanup()
|
cls.test_dir_fid.cleanup()
|
||||||
|
|
||||||
def test_save_and_load(self):
|
def test_save_and_load(self):
|
||||||
|
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
@ -92,7 +91,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_save_and_load_fs(self):
|
def test_save_and_load_fs(self):
|
||||||
|
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
@ -165,7 +163,6 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
(save_file_npy_uncomp, save_file_mlx_uncomp),
|
(save_file_npy_uncomp, save_file_mlx_uncomp),
|
||||||
(save_file_npy_comp, save_file_mlx_comp),
|
(save_file_npy_comp, save_file_mlx_comp),
|
||||||
):
|
):
|
||||||
|
|
||||||
# Load array saved by mlx as mlx array
|
# Load array saved by mlx as mlx array
|
||||||
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
load_arr_mlx_mlx = mx.load(save_file_mlx)
|
||||||
for k, v in load_arr_mlx_mlx.items():
|
for k, v in load_arr_mlx_mlx.items():
|
||||||
|
@ -65,7 +65,6 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
for op in ("sum", "prod", "min", "max"):
|
for op in ("sum", "prod", "min", "max"):
|
||||||
with self.subTest(op=op):
|
with self.subTest(op=op):
|
||||||
|
|
||||||
np_op = getattr(np, op)
|
np_op = getattr(np, op)
|
||||||
mlx_op = getattr(mx, op)
|
mlx_op = getattr(mx, op)
|
||||||
|
|
||||||
@ -96,7 +95,6 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
]
|
]
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
with self.subTest(dtype=dtype):
|
with self.subTest(dtype=dtype):
|
||||||
|
|
||||||
data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
|
data = np.random.rand(10, 12, 13).astype(getattr(np, dtype))
|
||||||
x = mx.array(data)
|
x = mx.array(data)
|
||||||
for op in ["argmin", "argmax"]:
|
for op in ["argmin", "argmax"]:
|
||||||
|
Loading…
Reference in New Issue
Block a user