mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
[Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:
parent
3d17077187
commit
430ffef58a
@ -348,3 +348,50 @@ def he_uniform(
|
|||||||
return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
|
return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
|
||||||
|
|
||||||
return initializer
|
return initializer
|
||||||
|
|
||||||
|
|
||||||
|
def sparse(
|
||||||
|
sparsity: float,
|
||||||
|
mean: float = 0.0,
|
||||||
|
std: float = 1.0,
|
||||||
|
dtype: mx.Dtype = mx.float32,
|
||||||
|
) -> Callable[[mx.array], mx.array]:
|
||||||
|
r"""An initializer that returns a sparse matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparsity (float): The fraction of elements in each column to be set to
|
||||||
|
zero.
|
||||||
|
mean (float, optional): Mean of the normal distribution. Default:
|
||||||
|
``0.0``.
|
||||||
|
std (float, optional): Standard deviation of the normal distribution.
|
||||||
|
Default: ``1.0``.
|
||||||
|
dtype (Dtype, optional): The data type of the array. Default:
|
||||||
|
``float32``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[array], array]: An initializer that returns an array with the
|
||||||
|
same shape as the input, filled with samples from a normal distribution.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> init_fn = nn.init.sparse(sparsity=0.5)
|
||||||
|
>>> init_fn(mx.zeros((2, 2)))
|
||||||
|
array([[-1.91187, -0.117483],
|
||||||
|
[0, 0]], dtype=float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initializer(a: mx.array) -> mx.array:
|
||||||
|
if a.ndim != 2:
|
||||||
|
raise ValueError("Only tensors with 2 dimensions are supported")
|
||||||
|
|
||||||
|
rows, cols = a.shape
|
||||||
|
num_zeros = int(mx.ceil(sparsity * cols))
|
||||||
|
|
||||||
|
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
|
||||||
|
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
|
||||||
|
|
||||||
|
a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0
|
||||||
|
|
||||||
|
return a
|
||||||
|
|
||||||
|
return initializer
|
||||||
|
@ -89,6 +89,23 @@ class TestInit(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(result.shape, shape)
|
self.assertEqual(result.shape, shape)
|
||||||
self.assertEqual(result.dtype, dtype)
|
self.assertEqual(result.dtype, dtype)
|
||||||
|
|
||||||
|
def test_sparse(self):
|
||||||
|
mean = 0.0
|
||||||
|
std = 1.0
|
||||||
|
sparsity = 0.5
|
||||||
|
for dtype in [mx.float32, mx.float16]:
|
||||||
|
initializer = init.sparse(sparsity, mean, std, dtype=dtype)
|
||||||
|
for shape in [(3, 2), (2, 2), (4, 3)]:
|
||||||
|
result = initializer(mx.array(np.empty(shape)))
|
||||||
|
with self.subTest(shape=shape):
|
||||||
|
self.assertEqual(result.shape, shape)
|
||||||
|
self.assertEqual(result.dtype, dtype)
|
||||||
|
self.assertEqual(
|
||||||
|
(mx.sum(result == 0) >= 0.5 * shape[0] * shape[1]), True
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
result = initializer(mx.zeros((1,)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user