diff --git a/llms/README.md b/llms/README.md
index 4b18ed1f..497c0277 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -38,6 +38,8 @@ To see a description of all the arguments you can do:
 >>> help(generate)
 ```
 
+Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
+
 The `mlx-lm` package also comes with functionality to quantize and optionally
 upload models to the Hugging Face Hub.
 
diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py
new file mode 100644
index 00000000..af599c1b
--- /dev/null
+++ b/llms/mlx_lm/examples/generate_response.py
@@ -0,0 +1,40 @@
+from mlx_lm import generate, load
+
+# Specify the checkpoint
+checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
+
+# Load the corresponding model and tokenizer
+model, tokenizer = load(path_or_hf_repo=checkpoint)
+
+# Specify the prompt and conversation history
+prompt = "Why is the sky blue?"
+conversation = [{"role": "user", "content": prompt}]
+
+# Transform the prompt into the chat template
+prompt = tokenizer.apply_chat_template(
+    conversation=conversation, tokenize=False, add_generation_prompt=True
+)
+
+# Specify the maximum number of tokens
+max_tokens = 1_000
+
+# Specify if tokens and timing information will be printed
+verbose = True
+
+# Some optional arguments for causal language model generation
+generation_args = {
+    "temp": 0.7,
+    "repetition_penalty": 1.2,
+    "repetition_context_size": 20,
+    "top_p": 0.95,
+}
+
+# Generate a response with the specified settings
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=max_tokens,
+    verbose=verbose,
+    **generation_args,
+)
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 32454335..4875f931 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,6 +1,6 @@
 mlx>=0.14.1
 numpy
-transformers>=4.39.3
+transformers[sentencepiece]>=4.39.3
 protobuf
 pyyaml
 jinja2
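
For reference, a minimal sketch of the same flow using the package's streaming API, so output appears token by token instead of all at once. This assumes `stream_generate` is exported by `mlx_lm` and yields plain text segments in this version; neither appears in the diff, so treat it as an assumption rather than part of the change:

```python
# Streaming variant of generate_response.py (a sketch; assumes mlx_lm
# exports `stream_generate` and that it yields text segments).
from mlx_lm import load, stream_generate

model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.3")

# Build the chat-formatted prompt exactly as in generate_response.py
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Why is the sky blue?"}],
    tokenize=False,
    add_generation_prompt=True,
)

# Print each segment as it is produced instead of waiting for the full response
for segment in stream_generate(model, tokenizer, prompt, max_tokens=256):
    print(segment, end="", flush=True)
print()
```

The committed example can also be run directly with `python llms/mlx_lm/examples/generate_response.py` once the requirements are installed; the `transformers[sentencepiece]` change in this diff presumably pulls in the sentencepiece dependency that the Mistral tokenizer needs.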