From 77b42b7c8bf27272f0263cd04ce962d20295504f Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 12 Dec 2024 10:37:26 -0800
Subject: [PATCH] fix llava (#1149)

---
 llava/generate.py       |  7 +++----
 llava/llava.py          | 26 ++++++++------------------
 llms/mlx_lm/generate.py |  7 ++++---
 3 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/llava/generate.py b/llava/generate.py
index 8067839e..64313858 100644
--- a/llava/generate.py
+++ b/llava/generate.py
@@ -79,10 +79,10 @@ def load_image(image_source):
 def prepare_inputs(processor, image, prompt):
     if isinstance(image, str):
         image = load_image(image)
-    inputs = processor(prompt, image, return_tensors="np")
+    inputs = processor(image, prompt, return_tensors="np")
     pixel_values = mx.array(inputs["pixel_values"])
     input_ids = mx.array(inputs["input_ids"])
-    return input_ids, pixel_values
+    return pixel_values, input_ids
 
 
 def load_model(model_path, tokenizer_config={}):
@@ -126,8 +126,7 @@ def main():
     processor, model = load_model(args.model, tokenizer_config)
 
     prompt = codecs.decode(args.prompt, "unicode_escape")
-
-    input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
+    pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
 
     print(prompt)
     generated_text = generate_text(
diff --git a/llava/llava.py b/llava/llava.py
index 9e6b7511..c5f190f8 100644
--- a/llava/llava.py
+++ b/llava/llava.py
@@ -104,31 +104,21 @@ class LlavaModel(nn.Module):
         self, image_features, inputs_embeds, input_ids
     ):
         image_token_index = self.config.image_token_index
-        num_images, num_image_patches, embed_dim = image_features.shape
+        batch_size, num_image_patches, embed_dim = image_features.shape
 
         # Positions of tokens in input_ids, assuming batch size is 1
-        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+        image_positions = mx.array(
+            np.where(input_ids[0] == image_token_index)[0], mx.uint32
+        )
 
-        if len(image_positions) != num_images:
+        if len(image_positions) != num_image_patches:
             raise ValueError(
                 f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
+                f" match the number of image patches ({num_image_patches})."
             )
 
-        text_segments = []
-        start_idx = 0
-
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
+        inputs_embeds[0, image_positions] = image_features
+        return inputs_embeds
 
     def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 0c1b4acd..84dc63ca 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import codecs
 import json
 import sys
 
@@ -188,6 +189,8 @@ def main():
     elif using_cache:
         tokenizer.chat_template = metadata["chat_template"]
 
+    prompt = codecs.decode(args.prompt, "unicode_escape")
+
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
@@ -199,7 +202,7 @@ def main():
         messages.append(
             {
                 "role": "user",
-                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+                "content": sys.stdin.read() if prompt == "-" else prompt,
             }
         )
         prompt = tokenizer.apply_chat_template(
@@ -216,8 +219,6 @@ def main():
                 add_generation_prompt=True,
             )
             prompt = prompt[test_prompt.index("") :]
-    else:
-        prompt = args.prompt
 
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(