diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index e33a24d91..d4fd755c4 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -348,3 +348,50 @@ def he_uniform( return mx.random.uniform(-limit, limit, a.shape, dtype=dtype) return initializer + + +def sparse( + sparsity: float, + mean: float = 0.0, + std: float = 1.0, + dtype: mx.Dtype = mx.float32, +) -> Callable[[mx.array], mx.array]: + r"""An initializer that returns a sparse matrix. + + Args: + sparsity (float): The fraction of elements in each column to be set to + zero. + mean (float, optional): Mean of the normal distribution. Default: + ``0.0``. + std (float, optional): Standard deviation of the normal distribution. + Default: ``1.0``. + dtype (Dtype, optional): The data type of the array. Default: + ``float32``. + + Returns: + Callable[[array], array]: An initializer that returns an array with the + same shape as the input, filled with samples from a normal distribution. + + Example: + + >>> init_fn = nn.init.sparse(sparsity=0.5) + >>> init_fn(mx.zeros((2, 2))) + array([[-1.91187, -0.117483], + [0, 0]], dtype=float32) + """ + + def initializer(a: mx.array) -> mx.array: + if a.ndim != 2: + raise ValueError("Only tensors with 2 dimensions are supported") + + rows, cols = a.shape + num_zeros = int(mx.ceil(sparsity * cols)) + + order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1) + a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) + + a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0 + + return a + + return initializer diff --git a/python/tests/test_init.py b/python/tests/test_init.py index 3cc63e03d..f2fa179fd 100644 --- a/python/tests/test_init.py +++ b/python/tests/test_init.py @@ -89,6 +89,23 @@ class TestInit(mlx_tests.MLXTestCase): self.assertEqual(result.shape, shape) self.assertEqual(result.dtype, dtype) + def test_sparse(self): + mean = 0.0 + std = 1.0 + sparsity = 0.5 + for dtype in [mx.float32, mx.float16]: + initializer = init.sparse(sparsity, mean, std, dtype=dtype) + for shape in [(3, 2), (2, 2), (4, 3)]: + result = initializer(mx.array(np.empty(shape))) + with self.subTest(shape=shape): + self.assertEqual(result.shape, shape) + self.assertEqual(result.dtype, dtype) + self.assertEqual( + (mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True + ) + with self.assertRaises(ValueError): + result = initializer(mx.zeros((1,))) + if __name__ == "__main__": unittest.main()