mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
single sdpa function
This commit is contained in:
@@ -6,8 +6,7 @@ from typing import Any, Dict, Optional, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .cache import QuantizedKVCache, quantized_scaled_dot_product_attention
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -191,20 +190,9 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
if isinstance(cache, QuantizedKVCache):
|
||||
output = quantized_scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
scale=self.scale,
|
||||
mask=mask,
|
||||
group_size=cache.group_size,
|
||||
bits=cache.bits,
|
||||
)
|
||||
else:
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
Reference in New Issue
Block a user