# Copyright © 2024 Apple Inc.
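
# Example usage (assuming this script is saved as generate.py next to the
# llava.py that defines LlavaModel):
#
#     python generate.py --image "http://images.cocodataset.org/val2017/000000039769.jpg" \
#         --prompt "USER: <image>\nWhat are these?\nASSISTANT:"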
import argparse
import codecs
from pathlib import Path

import mlx.core as mx
import requests
from PIL import Image
from transformers import AutoProcessor

from llava import LlavaModel


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Generate text from an image using a model."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="llava-hf/llava-1.5-7b-hf",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--image",
        type=str,
        default="http://images.cocodataset.org/val2017/000000039769.jpg",
        help="URL or path of the image to process.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="USER: <image>\nWhat are these?\nASSISTANT:",
        help="Message to be processed by the model.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temp", type=float, default=0.3, help="Temperature for sampling."
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default=None,
        help="End-of-sequence token for the tokenizer.",
    )
    return parser.parse_args()


def load_image(image_source):
    """
    Helper function to load an image from either a URL or a local file.
    """
    if image_source.startswith(("http://", "https://")):
        try:
            response = requests.get(image_source, stream=True)
            response.raise_for_status()
            return Image.open(response.raw)
        except Exception as e:
            raise ValueError(
                f"Failed to load image from URL {image_source}: {e}"
            )
    elif Path(image_source).is_file():
        try:
            return Image.open(image_source)
        except IOError as e:
            raise ValueError(f"Failed to load image {image_source}: {e}")
    else:
        raise ValueError(
            f"The image {image_source} must be a valid URL or existing file."
        )


def prepare_inputs(processor, image, prompt):
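    # Accept either an already-loaded PIL image or a URL/path string.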
    if isinstance(image, str):
        image = load_image(image)
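    # A single processor call tokenizes the prompt and preprocesses the image;
    # return_tensors="np" yields NumPy arrays, converted to MLX arrays below.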
    inputs = processor(image, prompt, return_tensors="np")
    pixel_values = mx.array(inputs["pixel_values"])
    input_ids = mx.array(inputs["input_ids"])
    return pixel_values, input_ids


def load_model(model_path, tokenizer_config=None):
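    # Load the Hugging Face processor (tokenizer + image preprocessor) and the
    # LLaVA weights from the same local path or Hub repo.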
    if tokenizer_config is None:
        tokenizer_config = {}
    processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config)
    model = LlavaModel.from_pretrained(model_path)
    return processor, model


def sample(logits, temperature=0.0):
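    # Greedy decoding at temperature 0; otherwise sample from the categorical
    # distribution over the temperature-scaled logits.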
    if temperature == 0:
        return mx.argmax(logits, axis=-1)
    else:
        return mx.random.categorical(logits * (1 / temperature))


def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
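    # Prefill: one forward pass over the image and full prompt builds the KV
    # cache and produces the logits for the first generated token.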
    logits, cache = model(input_ids, pixel_values)
    logits = logits[:, -1, :]
    y = sample(logits, temperature=temperature)
    tokens = [y.item()]
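
    # Decode autoregressively: feed back only the newest token through the
    # language model, reusing the KV cache, and stop early at the EOS token.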
    for _ in range(max_tokens - 1):
        logits, cache = model.language_model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = sample(logits, temperature)
        token = y.item()
        if token == processor.tokenizer.eos_token_id:
            break
        tokens.append(token)

    return processor.tokenizer.decode(tokens)


def main():
    args = parse_arguments()

    tokenizer_config = {}
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    processor, model = load_model(args.model, tokenizer_config)
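
    # Interpret backslash escapes typed on the command line (e.g. a literal
    # "\n" in the prompt) as real characters before tokenizing.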
    prompt = codecs.decode(args.prompt, "unicode_escape")

    pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)

    print(prompt)
    generated_text = generate_text(
        input_ids, pixel_values, model, processor, args.max_tokens, args.temp
    )
    print(generated_text)


if __name__ == "__main__":
    main()