Fix llava model when using text-only prompt

This commit is contained in:
Cheng
2024-09-25 15:31:43 +09:00
parent 9bb2dd62f3
commit 5731995652

View File

@@ -68,11 +68,10 @@ class LlavaModel(nn.Module):
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
if pixel_values is None:
return inputs_embeds
# Get the ouptut hidden states from the vision model
*_, hidden_states = self.vision_tower(