From 2b8ace6a039be663cf4aede1c0b7a713e2dd18cc Mon Sep 17 00:00:00 2001 From: LastWhisper Date: Tue, 15 Oct 2024 21:45:46 +0800 Subject: [PATCH] Typing the dropout. (#1479) --- python/mlx/nn/layers/dropout.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py index 657f8c47a..a6a690996 100644 --- a/python/mlx/nn/layers/dropout.py +++ b/python/mlx/nn/layers/dropout.py @@ -23,10 +23,10 @@ class Dropout(Module): self._p_1 = 1 - p - def _extra_repr(self): + def _extra_repr(self) -> str: return f"p={1-self._p_1}" - def __call__(self, x): + def __call__(self, x: mx.array) -> mx.array: if self._p_1 == 1 or not self.training: return x @@ -66,10 +66,10 @@ class Dropout2d(Module): self._p_1 = 1 - p - def _extra_repr(self): + def _extra_repr(self) -> str: return f"p={1-self._p_1}" - def __call__(self, x): + def __call__(self, x: mx.array) -> mx.array: if x.ndim not in (3, 4): raise ValueError( f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions." @@ -115,10 +115,10 @@ class Dropout3d(Module): self._p_1 = 1 - p - def _extra_repr(self): + def _extra_repr(self) -> str: return f"p={1-self._p_1}" - def __call__(self, x): + def __call__(self, x: mx.array) -> mx.array: if x.ndim not in (4, 5): raise ValueError( f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."