mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
catch stream errors earlier to avoid aborts (#1801)
This commit is contained in:
@@ -385,7 +385,7 @@ def sparse(
|
||||
raise ValueError("Only tensors with 2 dimensions are supported")
|
||||
|
||||
rows, cols = a.shape
|
||||
num_zeros = int(mx.ceil(sparsity * cols))
|
||||
num_zeros = int(math.ceil(sparsity * cols))
|
||||
|
||||
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
|
||||
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
|
||||
|
Reference in New Issue
Block a user