mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix attention projection dimensions for the 3B model (use d_kv * num_heads instead of d_model)
This commit is contained in:
parent
dbb4d6aea6
commit
72581e5c1a
9
t5/t5.py
9
t5/t5.py
@ -93,11 +93,12 @@ class RelativePositionBias(nn.Module):
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
inner_dim = config.d_kv * config.num_heads
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user