mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
simplify
This commit is contained in:
@@ -7,7 +7,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .cache import QuantizedKVCache
|
||||
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -192,10 +192,10 @@ class Attention(nn.Module):
|
||||
keys = self.rope(keys)
|
||||
|
||||
if isinstance(cache, QuantizedKVCache):
|
||||
output = mx.fast.quantized_scaled_dot_product_attention(
|
||||
output = quantized_scaled_dot_product_attention(
|
||||
queries,
|
||||
*keys,
|
||||
*values,
|
||||
keys,
|
||||
values,
|
||||
scale=self.scale,
|
||||
mask=mask,
|
||||
group_size=cache.group_size,
|
||||
|
Reference in New Issue
Block a user