diff --git a/llava/llava.py b/llava/llava.py index 06e56059..9e6b7511 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -68,11 +68,10 @@ class LlavaModel(nn.Module): input_ids: Optional[mx.array] = None, pixel_values: Optional[mx.array] = None, ): - if pixel_values is None: - return self.language_model(input_ids) - # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) + if pixel_values is None: + return inputs_embeds # Get the output hidden states from the vision model *_, hidden_states = self.vision_tower(