From 0dbc4c754714085af01c3a9bfdcbb6db61b0eda1 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Sun, 11 Feb 2024 18:08:20 +0400 Subject: [PATCH] feat: Update pre-commit-config.yaml (#667) --- .pre-commit-config.yaml | 2 +- benchmarks/python/comparative/compare.py | 6 ++---- python/tests/test_load.py | 16 ++++++++++------ python/tests/test_ops.py | 4 +--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0ebc6d48..279ab5c91 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: clang-format # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.12.1 + rev: 24.1.1 hooks: - id: black - repo: https://github.com/pycqa/isort diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index a9d3df22d..5b71cf583 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -80,10 +80,8 @@ if __name__ == "__main__": _filter = make_predicate(args.filter, args.negative_filter) if args.mlx_dtypes: - compare_filtered = ( - lambda x: compare_mlx_dtypes( - x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1] - ) + compare_filtered = lambda x: ( + compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]) if _filter(x) else None ) diff --git a/python/tests/test_load.py b/python/tests/test_load.py index ab2645bcf..fdf06041a 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -84,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase): self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } with open(save_file_mlx, "wb") as f: @@ -113,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase): self.test_dir, f"mlx_{dt}_{i}_fs.gguf" ) save_dict = { - "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) - if dt in ["float32", "float16", "bfloat16"] - else mx.ones(shape, dtype=getattr(mx, dt)) + "test": ( + mx.random.normal(shape=shape, dtype=getattr(mx, dt)) + if dt in ["float32", "float16", "bfloat16"] + else mx.ones(shape, dtype=getattr(mx, dt)) + ) } mx.save_gguf(save_file_mlx, save_dict) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 433890237..edb98032b 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1333,9 +1333,7 @@ class TestOps(mlx_tests.MLXTestCase): for d in dims: anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) for n_bsx in range(d): - bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape( - [size] * n_bsx - ) + bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx) for _ in range(trial_mul * d): amlx = mx.array(anp) bmlx = mx.array(bnp)