diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index e40332dd..bd11dcf0 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -152,7 +152,7 @@ def setup_arg_parser(): "--num-draft-tokens", type=int, help="Number of tokens to draft when using speculative decoding.", - default=2, + default=3, ) return parser diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index ad7a4a65..8b40effb 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -33,13 +33,13 @@ def create_causal_mask( linds = mx.arange(offset, offset + N) if offset else rinds linds = linds[:, None] rinds = rinds[None] - mask = linds < rinds + mask = linds >= rinds if window_size is not None: - mask = mask | (linds > rinds + window_size) + mask = mask & (linds <= rinds + window_size) if lengths is not None: lengths = lengths[:, None, None, None] - mask = mask | (rinds >= lengths) - return mask * -1e9 + mask = mask & (rinds < lengths) + return mask def create_attention_mask(h: mx.array, cache: Optional[Any] = None): @@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: offset = c.offset mask = create_causal_mask(T, offset, window_size=window_size) - mask = mask.astype(h.dtype) else: mask = None return mask