mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 18:39:45 +08:00
updated calls to use loc & scale (#643)
This commit is contained in:
parent
1b97b2958b
commit
7dccd42133
@ -60,7 +60,7 @@ def normal(
|
||||
"""
|
||||
|
||||
def initializer(a: mx.array) -> mx.array:
|
||||
return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean
|
||||
return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
|
||||
|
||||
return initializer
|
||||
|
||||
@ -184,7 +184,7 @@ def glorot_normal(
|
||||
def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
|
||||
fan_in, fan_out = _calculate_fan_in_fan_out(a)
|
||||
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
|
||||
return mx.random.normal(shape=a.shape, dtype=dtype) * std
|
||||
return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)
|
||||
|
||||
return initializer
|
||||
|
||||
@ -285,7 +285,7 @@ def he_normal(
|
||||
raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")
|
||||
|
||||
std = gain / math.sqrt(fan)
|
||||
return mx.random.normal(shape=a.shape, dtype=dtype) * std
|
||||
return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)
|
||||
|
||||
return initializer
|
||||
|
||||
|
@ -21,7 +21,7 @@ class Embedding(Module):
|
||||
def __init__(self, num_embeddings: int, dims: int):
|
||||
super().__init__()
|
||||
scale = math.sqrt(1 / dims)
|
||||
self.weight = mx.random.normal((num_embeddings, dims)) * scale
|
||||
self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
|
||||
|
Loading…
Reference in New Issue
Block a user