mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-02 05:04:37 +08:00
fix llava (#1149)
This commit is contained in:
@@ -79,10 +79,10 @@ def load_image(image_source):
|
||||
def prepare_inputs(processor, image, prompt):
|
||||
if isinstance(image, str):
|
||||
image = load_image(image)
|
||||
inputs = processor(prompt, image, return_tensors="np")
|
||||
inputs = processor(image, prompt, return_tensors="np")
|
||||
pixel_values = mx.array(inputs["pixel_values"])
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
return input_ids, pixel_values
|
||||
return pixel_values, input_ids
|
||||
|
||||
|
||||
def load_model(model_path, tokenizer_config={}):
|
||||
@@ -126,8 +126,7 @@ def main():
|
||||
processor, model = load_model(args.model, tokenizer_config)
|
||||
|
||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||
|
||||
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
|
||||
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
|
||||
|
||||
print(prompt)
|
||||
generated_text = generate_text(
|
||||
|
@@ -104,31 +104,21 @@ class LlavaModel(nn.Module):
|
||||
self, image_features, inputs_embeds, input_ids
|
||||
):
|
||||
image_token_index = self.config.image_token_index
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, num_image_patches, embed_dim = image_features.shape
|
||||
|
||||
# Positions of <image> tokens in input_ids, assuming batch size is 1
|
||||
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
|
||||
image_positions = mx.array(
|
||||
np.where(input_ids[0] == image_token_index)[0], mx.uint32
|
||||
)
|
||||
|
||||
if len(image_positions) != num_images:
|
||||
if len(image_positions) != num_image_patches:
|
||||
raise ValueError(
|
||||
f"The number of image tokens ({len(image_positions)}) does not "
|
||||
f" match the number of image inputs ({num_images})."
|
||||
f" match the number of image patches ({num_image_patches})."
|
||||
)
|
||||
|
||||
text_segments = []
|
||||
start_idx = 0
|
||||
|
||||
for position in image_positions:
|
||||
text_segments.append(inputs_embeds[:, start_idx:position])
|
||||
start_idx = position + 1
|
||||
|
||||
image_embeddings = mx.split(image_features, image_features.shape[0])
|
||||
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
|
||||
final_embeddings += [inputs_embeds[:, start_idx:]]
|
||||
|
||||
# Create a final embedding of shape
|
||||
# (1, num_image_patches*num_images + sequence_len, embed_dim)
|
||||
return mx.concatenate(final_embeddings, axis=1)
|
||||
inputs_embeds[0, image_positions] = image_features
|
||||
return inputs_embeds
|
||||
|
||||
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
|
||||
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
|
||||
|
Reference in New Issue
Block a user