catch stream errors earlier to avoid aborts (#1801)

This commit is contained in:
Awni Hannun
2025-01-27 14:05:43 -08:00
committed by GitHub
parent 28091aa1ff
commit 2235dee906
3 changed files with 29 additions and 7 deletions

View File

@@ -385,7 +385,7 @@ def sparse(
raise ValueError("Only tensors with 2 dimensions are supported")
rows, cols = a.shape
num_zeros = int(mx.ceil(sparsity * cols))
num_zeros = int(math.ceil(sparsity * cols))
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)