mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Fixed model returning contacts and updated README.md
This commit is contained in:
@@ -78,10 +78,9 @@ final_layer = representations[33] # Shape: (batch, length, embed_dim)
|
||||
# Predict residue-residue contacts
|
||||
contacts = model.predict_contacts(tokens) # Shape: (batch, length, length)
|
||||
|
||||
# Or get contacts along with other outputs
|
||||
result = model(tokens, return_contacts=True)
|
||||
contacts = result["contacts"]
|
||||
attentions = result["attentions"] # Shape: (batch, layers, heads, length, length)
|
||||
# Or compute contacts together with logits, representations, etc.
|
||||
outputs = model(tokens, return_contacts=True)
|
||||
contacts = outputs["contacts"]
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
@@ -205,9 +205,10 @@ class ESM2(nn.Module):
|
||||
|
||||
result["attentions"] = attentions
|
||||
|
||||
if return_contacts:
|
||||
contacts = self.contact_head(tokens, attentions)
|
||||
result["contacts"] = contacts
|
||||
# Compute contacts if requested
|
||||
if return_contacts:
|
||||
contacts = self.contact_head(tokens, attentions)
|
||||
result["contacts"] = contacts
|
||||
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user