Fixed model returning contacts and updated README.md

This commit is contained in:
Vincent Amato
2025-08-16 01:04:18 -04:00
parent 73a89f84b4
commit 97204566ad
2 changed files with 7 additions and 7 deletions

View File

@@ -78,10 +78,9 @@ final_layer = representations[33] # Shape: (batch, length, embed_dim)
# Predict residue-residue contacts
contacts = model.predict_contacts(tokens) # Shape: (batch, length, length)
# Or get contacts along with other outputs
result = model(tokens, return_contacts=True)
contacts = result["contacts"]
attentions = result["attentions"] # Shape: (batch, layers, heads, length, length)
# Or compute contacts together with logits, representations, etc.
outputs = model(tokens, return_contacts=True)
contacts = outputs["contacts"]
```
### Examples

View File

@@ -205,9 +205,10 @@ class ESM2(nn.Module):
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
# Compute contacts if requested
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
return result