diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index 49417a913..1f2ffd3da 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -219,7 +219,7 @@ class Upsample(Module): def __init__( self, scale_factor: Union[float, Tuple], - mode: Literal["nearest", "linear"] = "nearest", + mode: Literal["nearest", "linear", "cubic"] = "nearest", align_corners: bool = False, ): super().__init__()