Some fixes to typing (#1371)

* some fixes to typing

* fix module reference

* comment
This commit is contained in:
Awni Hannun
2024-08-28 11:16:19 -07:00
committed by GitHub
parent bd47e1f066
commit 291cf40aca
15 changed files with 152 additions and 145 deletions

View File

@@ -234,7 +234,7 @@ def glorot_uniform(
def he_normal(
dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
r"""Build a He normal initializer.
This initializer samples from a normal distribution with a standard
@@ -292,7 +292,7 @@ def he_normal(
def he_uniform(
dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
r"""A He uniform (Kaiming uniform) initializer.
This initializer samples from a uniform distribution with a range