mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix layer norm
This commit is contained in:
parent
4ec2b6eec3
commit
f26e81ccc9
@ -12,7 +12,7 @@ def run(t5_model: str):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
||||||
torch_model = T5EncoderModel.from_pretrained(t5_model)
|
torch_model = T5EncoderModel.from_pretrained(t5_model)
|
||||||
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
||||||
torch_forward = torch_model(**torch_tokens)
|
torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
|
||||||
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
||||||
|
|
||||||
print("\n TF BERT:")
|
print("\n TF BERT:")
|
||||||
|
17
t5/t5.py
17
t5/t5.py
@ -157,21 +157,16 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if affine:
|
self.weight = mx.ones((dims,))
|
||||||
self.weight = mx.ones((dims,))
|
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
|
|
||||||
def _extra_repr(self):
|
|
||||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
means = mx.mean(x, axis=-1, keepdims=True)
|
var = x.var(axis=-1, keepdims=True)
|
||||||
var = mx.var(x, axis=-1, keepdims=True)
|
x = x * mx.rsqrt(var + self.eps)
|
||||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
return x * self.weight
|
||||||
return (self.weight * x) if "weight" in self else x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
@ -392,7 +387,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.encode_only:
|
if args.encode_only:
|
||||||
print("[INFO] Encoding with T5...", flush=True)
|
print("[INFO] Encoding with T5...", flush=True)
|
||||||
print(args.prompt, end="", flush=True)
|
print(args.prompt, flush=True)
|
||||||
embeddings = model.wte(prompt)
|
embeddings = model.wte(prompt)
|
||||||
encoder_output = model.encoder(embeddings, mask=None)
|
encoder_output = model.encoder(embeddings, mask=None)
|
||||||
print(encoder_output, flush=True)
|
print(encoder_output, flush=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user