mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
feat: Update pre-commit-config.yaml (#667)
This commit is contained in:
parent
06072601ce
commit
0dbc4c7547
@ -5,7 +5,7 @@ repos:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 23.12.1
|
||||
rev: 24.1.1
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@ -80,10 +80,8 @@ if __name__ == "__main__":
|
||||
_filter = make_predicate(args.filter, args.negative_filter)
|
||||
|
||||
if args.mlx_dtypes:
|
||||
compare_filtered = (
|
||||
lambda x: compare_mlx_dtypes(
|
||||
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
|
||||
)
|
||||
compare_filtered = lambda x: (
|
||||
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
|
||||
if _filter(x)
|
||||
else None
|
||||
)
|
||||
|
@ -84,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
|
||||
)
|
||||
save_dict = {
|
||||
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
"test": (
|
||||
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
)
|
||||
}
|
||||
|
||||
with open(save_file_mlx, "wb") as f:
|
||||
@ -113,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
|
||||
)
|
||||
save_dict = {
|
||||
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
"test": (
|
||||
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
)
|
||||
}
|
||||
|
||||
mx.save_gguf(save_file_mlx, save_dict)
|
||||
|
@ -1333,9 +1333,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for d in dims:
|
||||
anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)
|
||||
for n_bsx in range(d):
|
||||
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape(
|
||||
[size] * n_bsx
|
||||
)
|
||||
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)
|
||||
for _ in range(trial_mul * d):
|
||||
amlx = mx.array(anp)
|
||||
bmlx = mx.array(bnp)
|
||||
|
Loading…
Reference in New Issue
Block a user