This commit is contained in:
Awni Hannun 2024-12-12 10:37:26 -08:00 committed by GitHub
parent 135c5818c1
commit 77b42b7c8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 25 deletions

View File

@ -79,10 +79,10 @@ def load_image(image_source):
def prepare_inputs(processor, image, prompt): def prepare_inputs(processor, image, prompt):
if isinstance(image, str): if isinstance(image, str):
image = load_image(image) image = load_image(image)
inputs = processor(prompt, image, return_tensors="np") inputs = processor(image, prompt, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"]) pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
return input_ids, pixel_values return pixel_values, input_ids
def load_model(model_path, tokenizer_config={}): def load_model(model_path, tokenizer_config={}):
@ -126,8 +126,7 @@ def main():
processor, model = load_model(args.model, tokenizer_config) processor, model = load_model(args.model, tokenizer_config)
prompt = codecs.decode(args.prompt, "unicode_escape") prompt = codecs.decode(args.prompt, "unicode_escape")
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
print(prompt) print(prompt)
generated_text = generate_text( generated_text = generate_text(

View File

@ -104,31 +104,21 @@ class LlavaModel(nn.Module):
self, image_features, inputs_embeds, input_ids self, image_features, inputs_embeds, input_ids
): ):
image_token_index = self.config.image_token_index image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape batch_size, num_image_patches, embed_dim = image_features.shape
# Positions of <image> tokens in input_ids, assuming batch size is 1 # Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() image_positions = mx.array(
np.where(input_ids[0] == image_token_index)[0], mx.uint32
)
if len(image_positions) != num_images: if len(image_positions) != num_image_patches:
raise ValueError( raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not " f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image inputs ({num_images})." f" match the number of image patches ({num_image_patches})."
) )
text_segments = [] inputs_embeds[0, image_positions] = image_features
start_idx = 0 return inputs_embeds
for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embddings = self.get_input_embeddings(input_ids, pixel_values) input_embddings = self.get_input_embeddings(input_ids, pixel_values)

View File

@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import argparse import argparse
import codecs
import json import json
import sys import sys
@ -188,6 +189,8 @@ def main():
elif using_cache: elif using_cache:
tokenizer.chat_template = metadata["chat_template"] tokenizer.chat_template = metadata["chat_template"]
prompt = codecs.decode(args.prompt, "unicode_escape")
if not args.ignore_chat_template and ( if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template") hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None and tokenizer.chat_template is not None
@ -199,7 +202,7 @@ def main():
messages.append( messages.append(
{ {
"role": "user", "role": "user",
"content": sys.stdin.read() if args.prompt == "-" else args.prompt, "content": sys.stdin.read() if prompt == "-" else prompt,
} }
) )
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
@ -216,8 +219,6 @@ def main():
add_generation_prompt=True, add_generation_prompt=True,
) )
prompt = prompt[test_prompt.index("<query>") :] prompt = prompt[test_prompt.index("<query>") :]
else:
prompt = args.prompt
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate( response = generate(