Fix layer norm

Juarez Bochi 2023-12-17 07:47:52 -05:00
parent 4ec2b6eec3
commit f26e81ccc9
2 changed files with 7 additions and 12 deletions


@@ -12,7 +12,7 @@ def run(t5_model: str):
     tokenizer = AutoTokenizer.from_pretrained(t5_model)
     torch_model = T5EncoderModel.from_pretrained(t5_model)
     torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
-    torch_forward = torch_model(**torch_tokens)
+    torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
     torch_output = torch_forward.last_hidden_state.detach().numpy()
     print("\n TF BERT:")


@@ -157,21 +157,16 @@ class MultiHeadAttention(nn.Module):
 class LayerNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+    def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
-        if affine:
-            self.weight = mx.ones((dims,))
+        self.weight = mx.ones((dims,))
         self.eps = eps
         self.dims = dims
 
-    def _extra_repr(self):
-        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
-
     def __call__(self, x):
-        means = mx.mean(x, axis=-1, keepdims=True)
-        var = mx.var(x, axis=-1, keepdims=True)
-        x = (x - means) * mx.rsqrt(var + self.eps)
-        return (self.weight * x) if "weight" in self else x
+        var = x.var(axis=-1, keepdims=True)
+        x = x * mx.rsqrt(var + self.eps)
+        return x * self.weight
 
 
 class TransformerEncoderLayer(nn.Module):
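The rewritten __call__ is the substance of the fix: T5's layer norm skips the mean subtraction, rescaling by the inverse root of the variance and applying a learned weight, with no centering and no bias. A minimal NumPy sketch of the before/after behaviour, outside the diff (function names are illustrative):

import numpy as np

def old_norm(x, weight, eps=1e-5):
    # Previous behaviour: full layer norm (center, then scale).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return weight * (x - mean) / np.sqrt(var + eps)

def new_norm(x, weight, eps=1e-5):
    # Fixed behaviour, mirroring the diff: scale only, no centering.
    var = x.var(axis=-1, keepdims=True)
    return weight * x / np.sqrt(var + eps)

x = np.random.randn(2, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
print(np.abs(old_norm(x, w) - new_norm(x, w)).max())  # nonzero whenever x has nonzero mean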
@@ -392,7 +387,7 @@ if __name__ == "__main__":
     if args.encode_only:
         print("[INFO] Encoding with T5...", flush=True)
-        print(args.prompt, end="", flush=True)
+        print(args.prompt, flush=True)
         embeddings = model.wte(prompt)
         encoder_output = model.encoder(embeddings, mask=None)
         print(encoder_output, flush=True)