mlx-examples/llms/mlx_lm/models
Kevin Wang c0019c4908
Pad mask with zeros for non-square attention matrices (#715)
* Pad mask with zeros for non-square attention matrices

The current mask implementation assumes the attention matrix is square, which holds only when there is no cache. Producing multiple tokens at a time with a cache, as in speculative decoding implementations, requires a rectangular mask.

This change pads the bottom of the mask with zeros so multi-token decoding with a cache works correctly.

* Directly create mask instead of padding

* Update llama.py
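
The rectangular mask described in this commit message can be illustrated with a minimal sketch in plain Python. This is not the MLX code from llama.py; the function name `rectangular_causal_mask`, the `offset` parameter (taken here to mean the number of cached tokens), and the `NEG_INF` constant are hypothetical names chosen for illustration.

```python
from typing import List

NEG_INF = -1e9  # assumed large negative value, added to scores before softmax


def rectangular_causal_mask(num_new: int, offset: int = 0) -> List[List[float]]:
    """Build an additive causal mask of shape (num_new, offset + num_new).

    Row q is the q-th new token, at absolute position offset + q.
    Column k is key position k (cached positions first, then new tokens).
    """
    total = offset + num_new
    return [
        [0.0 if k <= offset + q else NEG_INF for k in range(total)]
        for q in range(num_new)
    ]


# No cache: the usual square lower-triangular mask.
square = rectangular_causal_mask(3)

# Three cached tokens, two new tokens: a 2 x 5 rectangular mask.
rect = rectangular_causal_mask(2, offset=3)
```

In the rectangular case every new token can attend to all cached positions, and the causal constraint only applies among the new tokens themselves.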
2024-05-04 16:32:25 -07:00
..
__init__.py    Mlx llm package (#301)                                        2024-01-12 10:25:56 -08:00
base.py        Mlx llm package (#301)                                        2024-01-12 10:25:56 -08:00
cohere.py      Quantize embedding / Update quantize API (#680)               2024-04-18 18:16:10 -07:00
dbrx.py        - Removed unused Python imports (#683)                        2024-04-16 07:50:32 -07:00
gemma.py       Quantize embedding / Update quantize API (#680)               2024-04-18 18:16:10 -07:00
llama.py       Pad mask with zeros for non-square attention matrices (#715)  2024-05-04 16:32:25 -07:00
minicpm.py     MiniCPM implementation (#685)                                 2024-04-25 15:29:28 -07:00
mixtral.py     Fix argpartition call in Mixtral and other MOES (#676)        2024-04-12 11:00:56 -07:00
olmo.py        Quantize embedding / Update quantize API (#680)               2024-04-18 18:16:10 -07:00
openelm.py     Add support for OpenELM (#719)                                2024-04-25 16:49:28 -07:00
phi3.py        Add support for phi-3 (#712)                                  2024-04-23 09:20:00 -07:00
phi.py         Switch to fast RMS/LN Norm (#603)                             2024-03-23 07:13:51 -07:00
phixtral.py    Fix argpartition call in Mixtral and other MOES (#676)        2024-04-12 11:00:56 -07:00
plamo.py       Configurable LR schedulers (#604)                             2024-03-29 13:41:10 -07:00
qwen2_moe.py   Fix lora for qwen moe (#743)                                  2024-05-02 21:55:09 -07:00
qwen2.py       Quantize embedding / Update quantize API (#680)               2024-04-18 18:16:10 -07:00
qwen.py        Switch to fast RMS/LN Norm (#603)                             2024-03-23 07:13:51 -07:00
stablelm.py    Stable lm 2 (#666)                                            2024-04-08 14:18:55 -07:00
starcoder2.py  Fixes Typo in Starcoder2 (#740)                               2024-04-29 13:14:45 -07:00