diff --git a/esm/README.md b/esm/README.md index b8148987..8c28bae3 100644 --- a/esm/README.md +++ b/esm/README.md @@ -78,10 +78,9 @@ final_layer = representations[33] # Shape: (batch, length, embed_dim) # Predict residue-residue contacts contacts = model.predict_contacts(tokens) # Shape: (batch, length, length) -# Or get contacts along with other outputs -result = model(tokens, return_contacts=True) -contacts = result["contacts"] -attentions = result["attentions"] # Shape: (batch, layers, heads, length, length) +# Or compute contacts together with logits, representations, etc. +outputs = model(tokens, return_contacts=True) +contacts = outputs["contacts"] ``` ### Examples diff --git a/esm/esm/model.py b/esm/esm/model.py index 69e28daa..ee90a071 100644 --- a/esm/esm/model.py +++ b/esm/esm/model.py @@ -205,9 +205,10 @@ class ESM2(nn.Module): result["attentions"] = attentions - if return_contacts: - contacts = self.contact_head(tokens, attentions) - result["contacts"] = contacts + # Compute contacts if requested + if return_contacts: + contacts = self.contact_head(tokens, attentions) + result["contacts"] = contacts return result