mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Quantized KV Cache (#1075)
* add QuantizedKVCache * simplify * add tests * single sdpa function * fix sed * in place * fix tests * support different k and v head dims
This commit is contained in:
@@ -7,7 +7,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -61,8 +61,8 @@ class Attention(nn.Module):
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
Reference in New Issue
Block a user