mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	updated calls to use loc &scale (#643)
This commit is contained in:
		| @@ -60,7 +60,7 @@ def normal( | ||||
|     """ | ||||
|  | ||||
|     def initializer(a: mx.array) -> mx.array: | ||||
|         return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean | ||||
|         return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) | ||||
|  | ||||
|     return initializer | ||||
|  | ||||
| @@ -184,7 +184,7 @@ def glorot_normal( | ||||
|     def initializer(a: mx.array, gain: float = 1.0) -> mx.array: | ||||
|         fan_in, fan_out = _calculate_fan_in_fan_out(a) | ||||
|         std = gain * math.sqrt(2.0 / (fan_in + fan_out)) | ||||
|         return mx.random.normal(shape=a.shape, dtype=dtype) * std | ||||
|         return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) | ||||
|  | ||||
|     return initializer | ||||
|  | ||||
| @@ -285,7 +285,7 @@ def he_normal( | ||||
|             raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out") | ||||
|  | ||||
|         std = gain / math.sqrt(fan) | ||||
|         return mx.random.normal(shape=a.shape, dtype=dtype) * std | ||||
|         return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) | ||||
|  | ||||
|     return initializer | ||||
|  | ||||
|   | ||||
| @@ -21,7 +21,7 @@ class Embedding(Module): | ||||
|     def __init__(self, num_embeddings: int, dims: int): | ||||
|         super().__init__() | ||||
|         scale = math.sqrt(1 / dims) | ||||
|         self.weight = mx.random.normal((num_embeddings, dims)) * scale | ||||
|         self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return f"{self.weight.shape[0]}, {self.weight.shape[1]}" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 LeonEricsson
					LeonEricsson