mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix attention for 3b model
This commit is contained in:
parent
dbb4d6aea6
commit
72581e5c1a
9
t5/t5.py
9
t5/t5.py
@@ -93,11 +93,12 @@ class RelativePositionBias(nn.Module):
 class MultiHeadAttention(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
+        inner_dim = config.d_kv * config.num_heads
         self.num_heads = config.num_heads
-        self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
-        self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
-        self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
-        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)

     def __call__(
         self,
Loading…
Reference in New Issue
Block a user