mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
[Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
This commit is contained in:

committed by
GitHub

parent
3d17077187
commit
430ffef58a
@@ -348,3 +348,50 @@ def he_uniform(
|
||||
return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def sparse(
|
||||
sparsity: float,
|
||||
mean: float = 0.0,
|
||||
std: float = 1.0,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> Callable[[mx.array], mx.array]:
|
||||
r"""An initializer that returns a sparse matrix.
|
||||
|
||||
Args:
|
||||
sparsity (float): The fraction of elements in each column to be set to
|
||||
zero.
|
||||
mean (float, optional): Mean of the normal distribution. Default:
|
||||
``0.0``.
|
||||
std (float, optional): Standard deviation of the normal distribution.
|
||||
Default: ``1.0``.
|
||||
dtype (Dtype, optional): The data type of the array. Default:
|
||||
``float32``.
|
||||
|
||||
Returns:
|
||||
Callable[[array], array]: An initializer that returns an array with the
|
||||
same shape as the input, filled with samples from a normal distribution.
|
||||
|
||||
Example:
|
||||
|
||||
>>> init_fn = nn.init.sparse(sparsity=0.5)
|
||||
>>> init_fn(mx.zeros((2, 2)))
|
||||
array([[-1.91187, -0.117483],
|
||||
[0, 0]], dtype=float32)
|
||||
"""
|
||||
|
||||
def initializer(a: mx.array) -> mx.array:
|
||||
if a.ndim != 2:
|
||||
raise ValueError("Only tensors with 2 dimensions are supported")
|
||||
|
||||
rows, cols = a.shape
|
||||
num_zeros = int(mx.ceil(sparsity * cols))
|
||||
|
||||
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
|
||||
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
|
||||
|
||||
a[mx.arange(rows).reshape(rows, 1), order[:, :num_zeros]] = 0
|
||||
|
||||
return a
|
||||
|
||||
return initializer
|
||||
|
Reference in New Issue
Block a user