feat: Update pre-commit-config.yaml (#667)

This commit is contained in:
Nripesh Niketan 2024-02-11 18:08:20 +04:00 committed by GitHub
parent 06072601ce
commit 0dbc4c7547
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 14 deletions

View File

@ -5,7 +5,7 @@ repos:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.12.1
rev: 24.1.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@ -80,10 +80,8 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
if _filter(x)
else None
)

View File

@ -84,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase):
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
"test": (
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
)
}
with open(save_file_mlx, "wb") as f:
@ -113,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase):
self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
)
save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
"test": (
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
)
}
mx.save_gguf(save_file_mlx, save_dict)

View File

@ -1333,9 +1333,7 @@ class TestOps(mlx_tests.MLXTestCase):
for d in dims:
anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)
for n_bsx in range(d):
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape(
[size] * n_bsx
)
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)
for _ in range(trial_mul * d):
amlx = mx.array(anp)
bmlx = mx.array(bnp)