mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
feat: Update pre-commit-config.yaml (#667)
This commit is contained in:
@@ -84,9 +84,11 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.test_dir, f"mlx_{dt}_{i}_fs.safetensors"
|
||||
)
|
||||
save_dict = {
|
||||
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
"test": (
|
||||
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
)
|
||||
}
|
||||
|
||||
with open(save_file_mlx, "wb") as f:
|
||||
@@ -113,9 +115,11 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
|
||||
)
|
||||
save_dict = {
|
||||
"test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
"test": (
|
||||
mx.random.normal(shape=shape, dtype=getattr(mx, dt))
|
||||
if dt in ["float32", "float16", "bfloat16"]
|
||||
else mx.ones(shape, dtype=getattr(mx, dt))
|
||||
)
|
||||
}
|
||||
|
||||
mx.save_gguf(save_file_mlx, save_dict)
|
||||
|
@@ -1333,9 +1333,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for d in dims:
|
||||
anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d)
|
||||
for n_bsx in range(d):
|
||||
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape(
|
||||
[size] * n_bsx
|
||||
)
|
||||
bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape([size] * n_bsx)
|
||||
for _ in range(trial_mul * d):
|
||||
amlx = mx.array(anp)
|
||||
bmlx = mx.array(bnp)
|
||||
|
Reference in New Issue
Block a user