diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 35bafe380..4b170e88f 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -314,7 +314,7 @@ class Transformer(Module):
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
             operations, otherwise after. Default: ``True``.
-        chekpoint (bool, optional): if ``True`` perform gradient checkpointing
+        checkpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
     """