Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 02:53:41 +08:00)
Fix layer norm
This commit is contained in:
parent 4ec2b6eec3
commit f26e81ccc9
@ -12,7 +12,7 @@ def run(t5_model: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
||||
torch_model = T5EncoderModel.from_pretrained(t5_model)
|
||||
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
||||
torch_forward = torch_model(**torch_tokens)
|
||||
torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
|
||||
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
||||
|
||||
print("\n TF BERT:")
|
||||
|
17
t5/t5.py
17
t5/t5.py
@ -157,21 +157,16 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
if affine:
|
||||
self.weight = mx.ones((dims,))
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
self.dims = dims
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
||||
|
||||
def __call__(self, x):
|
||||
means = mx.mean(x, axis=-1, keepdims=True)
|
||||
var = mx.var(x, axis=-1, keepdims=True)
|
||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||
return (self.weight * x) if "weight" in self else x
|
||||
var = x.var(axis=-1, keepdims=True)
|
||||
x = x * mx.rsqrt(var + self.eps)
|
||||
return x * self.weight
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
@ -392,7 +387,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
print(args.prompt, flush=True)
|
||||
embeddings = model.wte(prompt)
|
||||
encoder_output = model.encoder(embeddings, mask=None)
|
||||
print(encoder_output, flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user