Review notes (free text ahead of the first "diff --git" header; patch tools
skip it when applying):

* t5/hf_t5.py: pass output_hidden_states=True to the reference encoder call.
* t5/t5.py: LayerNorm loses the optional affine flag and mean-centering.
  The normalizer is the mean of squares over the last axis (RMS-style, as in
  T5), not the variance, because the variance subtracts the mean internally.
* t5/t5.py: the prompt is printed on its own line in --encode-only mode.

diff --git a/t5/hf_t5.py b/t5/hf_t5.py
index a1910afb..cb75c0f6 100644
--- a/t5/hf_t5.py
+++ b/t5/hf_t5.py
@@ -12,7 +12,7 @@ def run(t5_model: str):
     tokenizer = AutoTokenizer.from_pretrained(t5_model)
     torch_model = T5EncoderModel.from_pretrained(t5_model)
     torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
-    torch_forward = torch_model(**torch_tokens)
+    torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
     torch_output = torch_forward.last_hidden_state.detach().numpy()
 
     print("\n TF BERT:")
diff --git a/t5/t5.py b/t5/t5.py
index 63222eeb..e955e372 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -157,21 +157,18 @@ class MultiHeadAttention(nn.Module):
 
 
 class LayerNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+    def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
-        if affine:
-            self.weight = mx.ones((dims,))
+        self.weight = mx.ones((dims,))
         self.eps = eps
         self.dims = dims
 
-    def _extra_repr(self):
-        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
-
     def __call__(self, x):
-        means = mx.mean(x, axis=-1, keepdims=True)
-        var = mx.var(x, axis=-1, keepdims=True)
-        x = (x - means) * mx.rsqrt(var + self.eps)
-        return (self.weight * x) if "weight" in self else x
+        # T5 normalizes by the root mean square (no mean subtraction),
+        # so use the mean of squares rather than the variance.
+        mean_sq = mx.mean(x * x, axis=-1, keepdims=True)
+        x = x * mx.rsqrt(mean_sq + self.eps)
+        return x * self.weight
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -392,7 +389,7 @@ if __name__ == "__main__":
 
     if args.encode_only:
         print("[INFO] Encoding with T5...", flush=True)
-        print(args.prompt, end="", flush=True)
+        print(args.prompt, flush=True)
         embeddings = model.wte(prompt)
         encoder_output = model.encoder(embeddings, mask=None)
         print(encoder_output, flush=True)