mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
[Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:

committed by
GitHub

parent
3d17077187
commit
430ffef58a
@@ -89,6 +89,23 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(result.shape, shape)
|
||||
self.assertEqual(result.dtype, dtype)
|
||||
|
||||
def test_sparse(self):
|
||||
mean = 0.0
|
||||
std = 1.0
|
||||
sparsity = 0.5
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.sparse(sparsity, mean, std, dtype=dtype)
|
||||
for shape in [(3, 2), (2, 2), (4, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
self.assertEqual(result.dtype, dtype)
|
||||
self.assertEqual(
|
||||
(mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
result = initializer(mx.zeros((1,)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user