feat: Update pre-commit-config.yaml (#667)

This commit is contained in:
Nripesh Niketan 2024-02-11 18:08:20 +04:00 committed by GitHub
parent 06072601ce
commit 0dbc4c7547
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 14 deletions

View File

@ -5,7 +5,7 @@ repos:
- id: clang-format - id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.12.1 rev: 24.1.1
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort

View File

@ -80,10 +80,8 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter) _filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes: if args.mlx_dtypes:
compare_filtered = ( compare_filtered = lambda x: (
lambda x: compare_mlx_dtypes( compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
if _filter(x) if _filter(x)
else None else None
) )

View File

@ -84,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase):
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors" self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
) )
save_dict = { save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) "test": (
if dt in ["float32", "float16", "bfloat16"] mx.random.normal(shape=shape, dtype=getattr(mx, dt))
else mx.ones(shape, dtype=getattr(mx, dt)) if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
)
} }
with open(save_file_mlx, "wb") as f: with open(save_file_mlx, "wb") as f:
@ -113,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase):
self.test_dir, f"mlx_{dt}_{i}_fs.gguf" self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
) )
save_dict = { save_dict = {
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt)) "test": (
if dt in ["float32", "float16", "bfloat16"] mx.random.normal(shape=shape, dtype=getattr(mx, dt))
else mx.ones(shape, dtype=getattr(mx, dt)) if dt in ["float32", "float16", "bfloat16"]
else mx.ones(shape, dtype=getattr(mx, dt))
)
} }
mx.save_gguf(save_file_mlx, save_dict) mx.save_gguf(save_file_mlx, save_dict)

View File

@ -1333,9 +1333,7 @@ class TestOps(mlx_tests.MLXTestCase):
for d in dims: for d in dims:
anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)
for n_bsx in range(d): for n_bsx in range(d):
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape( bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)
[size] * n_bsx
)
for _ in range(trial_mul * d): for _ in range(trial_mul * d):
amlx = mx.array(anp) amlx = mx.array(anp)
bmlx = mx.array(bnp) bmlx = mx.array(bnp)