mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Some fixes to typing (#1371)
* some fixes to typing * fix module reference * comment
This commit is contained in:
@@ -234,7 +234,7 @@ def glorot_uniform(
|
||||
|
||||
def he_normal(
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> Callable[[mx.array, str, float], mx.array]:
|
||||
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
|
||||
r"""Build a He normal initializer.
|
||||
|
||||
This initializer samples from a normal distribution with a standard
|
||||
@@ -292,7 +292,7 @@ def he_normal(
|
||||
|
||||
def he_uniform(
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> Callable[[mx.array, str, float], mx.array]:
|
||||
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
|
||||
r"""A He uniform (Kaiming uniform) initializer.
|
||||
|
||||
This initializer samples from a uniform distribution with a range
|
||||
|
||||
Reference in New Issue
Block a user