Merge branch 'ml-explore:main' into sets_of_hf_datasets

This commit is contained in:
Chime Ogbuji 2024-11-08 18:47:27 -05:00 committed by GitHub
commit f9936f77da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 357 additions and 291 deletions

View File

@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
#### Streaming
For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text. For example,
generator object which streams the output text, token, and log probabilities.
For example,
```python
from mlx_lm import load, stream_generate
@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
print()
```
@ -221,6 +222,7 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),

View File

@ -152,6 +152,7 @@ def main():
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
current = time.time()
@ -165,14 +166,13 @@ def main():
)
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
print("Saving...")
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
save_prompt_cache(args.prompt_cache_file, cache, metadata)

View File

@ -11,6 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@ -41,6 +42,13 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
return parser
@ -66,10 +74,11 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate(
for response, *_ in stream_generate(
model,
tokenizer,
prompt,
args.max_tokens,
temp=args.temp,
top_p=args.top_p,
prompt_cache=prompt_cache,

View File

@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
@ -52,6 +54,7 @@ def setup_arg_parser():
)
parser.add_argument(
"--prompt",
"-p",
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
@ -68,6 +71,15 @@ def setup_arg_parser():
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument(
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
)
parser.add_argument(
"--min-tokens-to-keep",
type=float,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
@ -247,6 +259,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,

View File

@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
if cache is not None and cache[0] is not None:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size - 1, c.offset)
offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset

View File

@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S - 1 to ensure
# The largest size is self.max_size + S to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
trim_size = self._idx - self.max_size
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]

View File

@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
],
axis=-1,
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:

View File

@ -1,10 +1,83 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
from typing import Callable, Dict, Optional
import mlx.core as mx
def make_sampler(
temp: float = 0.0,
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
) -> Callable[mx.array, mx.array]:
"""
Make a sampler function for use with ``generate_step``.
Args:
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
Returns:
Callable[mx.array, mx.array]:
A sampler which takes log-probabilities and returns tokens.
"""
if temp == 0:
return lambda x: mx.argmax(x, axis=-1)
elif top_p > 0 and top_p < 1.0:
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
else:
return lambda x: categorical_sampling(x, temp)
def make_logits_processors(
logit_bias: Optional[Dict[int, float]] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
):
"""
Make logits processors for use with ``generate_step``.
Args:
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
logit_bias (dictionary, optional): Additive logit bias.
Returns:
List[Callable[[mx.array, mx.array], mx.array]]:
A list of logits processors. Each processor in the list is a
callable which takes an array of tokens and an array of logits
and returns the updated logits.
"""
logits_processors = []
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
logits_processors.append(logit_bias_processor)
if repetition_penalty and repetition_penalty != 0.0:
logits_processors.append(
make_repetition_penalty(repetition_penalty, repetition_context_size)
)
return logits_processors
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logits: mx.array,
@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp):
return mx.random.categorical(logits * (1 / temp))
def make_repetition_penalty(penalty: float, context_size: int = 20):
"""
Make repetition penalty processor.
Paper: https://arxiv.org/abs/1909.05858
Args:
penalty (float): The repetition penalty factor to be applied.
context_size (int): The number of previous tokens to use.
Default: ``20``.
Returns:
Callable[[mx.array, List[int]], mx.array]:
The repetition penalty processor.
"""
if penalty < 0 or not isinstance(penalty, float):
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
def repetition_penalty_processor(tokens, logits):
if len(tokens) > 0:
tokens = tokens[-context_size:]
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0,
selected_logits * penalty,
selected_logits / penalty,
)
logits[:, tokens] = selected_logits
return logits
return repetition_penalty_processor

View File

@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load
from .utils import load, stream_generate
def get_system_fingerprint():
@ -64,7 +64,7 @@ def stopping_criteria(
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=1)
return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Call endpoint specific method
prompt = endpoints[self.path]()
# Call method based on response type
method = self.handle_stream if self.stream else self.handle_completion
method(prompt, stop_id_sequences)
self.handle_completion(prompt, stop_id_sequences)
def validate_model_parameters(self):
"""
@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler):
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
tokens = []
finish_reason = "length"
stop_sequence_suffix = None
logging.debug(f"Starting completion:")
if self.stream:
self.end_headers()
logging.debug(f"Starting stream:")
else:
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=mx.array(prompt),
text = ""
tic = time.perf_counter()
for n, (segment, token, logprobs) in enumerate(
stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
text += segment
logging.debug(text)
tokens.append(token)
if self.logprobs > 0:
@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
)
text = text[: -len(stop_sequence_suffix)]
break
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
if stop_sequence_suffix is None
else detokenizer.text[: -len(stop_sequence_suffix)]
)
response = self.generate_response(
text,
finish_reason,
len(prompt),
len(tokens),
token_logprobs=token_logprobs,
top_tokens=top_tokens,
tokens=tokens,
)
response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
def handle_stream(
self,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate response to prompt and foward it to the client using a Server
Sent Events (SSE) stream.
Args:
prompt (mx.array): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
# No additional headers are needed, call end_headers
self.end_headers()
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
tokens = []
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
stop_condition = stopping_criteria(
tokens,
stop_id_sequences,
self.tokenizer.eos_token_id,
)
if stop_condition.stop_met:
if stop_condition.trim_length:
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
if self.stream:
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(
sequence_overlap(tokens, sequence)
for sequence in stop_id_sequences
)
break
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
):
continue
new_text = detokenizer.last_segment
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
):
continue
elif segment:
response = self.generate_response(segment, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check is there any remaining text to send
detokenizer.finalize()
last_segment = detokenizer.last_segment
if last_segment:
if stop_sequence_suffix is not None:
last_segment = last_segment[: -len(stop_sequence_suffix)]
response = self.generate_response(last_segment, "length")
gen_time = time.perf_counter() - tic
prompt_tps = len(prompt) / prompt_time
gen_tps = len(tokens) / gen_time
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
if self.stream:
response = self.generate_response(segment, finish_reason)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if self.stream_options is not None and self.stream_options["include_usage"]:
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
else:
response = self.generate_response(
text,
finish_reason,
len(prompt),
len(tokens),
token_logprobs=token_logprobs,
top_tokens=top_tokens,
tokens=tokens,
)
response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
if self.stream_options is not None and self.stream_options["include_usage"]:
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
def completion_usage_response(
self,

View File

@ -6,12 +6,6 @@ from transformers import AutoTokenizer
REPLACEMENT_CHAR = "\ufffd"
def _remove_space(x):
if x and x[0] == " ":
return x[1:]
return x
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
def __init__(self, tokenizer, trim_space=True):
self.trim_space = trim_space
self._sep = "\u2581".encode()
# Extract the tokens in a list from id to text
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value
# Replace bytes with their value
for i in range(len(self.tokenmap)):
if self.tokenmap[i].startswith("<0x"):
self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
if value.startswith("<0x"):
# Replace bytes with their value
self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
else:
self.tokenmap[tokenid] = value.encode()
self.reset()
def reset(self):
self.offset = 0
self._unflushed = ""
self._unflushed = b""
self.text = ""
self.tokens = []
def _flush(self):
text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
if not self.text and self.trim_space and text and text[0] == " ":
text = text[1:]
self.text += text
def add_token(self, token):
v = self.tokenmap[token]
if v[0] == "\u2581":
if self.text or not self.trim_space:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
if v.startswith(self._sep):
self._flush()
self._unflushed = v
else:
self._unflushed += v
def finalize(self):
if self.text or not self.trim_space:
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = ""
self._flush()
self._unflushed = b""
class BPEStreamingDetokenizer(StreamingDetokenizer):

View File

@ -285,7 +285,7 @@ def train(
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 2**30
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "

View File

@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models import cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters
@ -29,10 +29,14 @@ from .tuner.utils import load_adapters
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
}
MAX_FILE_SIZE_GB = 5
# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())
class ModelNotFoundError(Exception):
def __init__(self, message):
@ -136,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
"""
Apply repetition penalty to specific logits based on the given context.
Paper: https://arxiv.org/abs/1909.05858
Args:
logits (mx.array): The logits produced by the language model.
tokens (mx.array): A list of N previous tokens.
penalty (float): The repetition penalty factor to be applied.
Returns:
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(tokens) > 0:
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
logits[:, tokens] = selected_logits
return logits
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if (
kv_bits is not None
@ -184,7 +165,7 @@ def generate_step(
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
@ -213,7 +194,7 @@ def generate_step(
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
@ -223,53 +204,9 @@ def generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
elif min_p != 0.0:
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
return token, logprobs
if repetition_penalty and (
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
):
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
logits_processor = logits_processor or []
if repetition_penalty:
def repetition_penalty_processor(tokens, logits):
return apply_repetition_penalty(
logits, tokens[-repetition_context_size:], repetition_penalty
)
logits_processor.append(repetition_penalty_processor)
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
logits_processor.append(logit_bias_processor)
y = prompt
tokens = None
@ -282,24 +219,31 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = logits_processors or []
logits_processors.extend(
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
)
def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processors:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
for processor in logits_processors:
logits = processor(tokens, logits)
for processor in logits_processor:
logits = processor(tokens, logits)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
y, logprobs = sample(logits)
return y, logprobs.squeeze(0)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
@ -324,43 +268,51 @@ def generate_step(
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
prompt: Union[str, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[Tuple[str, int, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The ma
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
Tuple[str, int, mx.array]:
The next text segment, token, and vector of log probabilities.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array(tokenizer.encode(prompt))
prompt_tokens = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for n, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
with wired_limit(model, [generation_stream]):
detokenizer.reset()
for n, (token, logits) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
# Yield the last segment if streaming
yield detokenizer.last_segment
detokenizer.add_token(token)
detokenizer.finalize()
yield detokenizer.last_segment
if n == (max_tokens - 1):
break
yield detokenizer.last_segment, token, logits
detokenizer.finalize()
yield detokenizer.last_segment, token, logits
def generate(
@ -371,7 +323,7 @@ def generate(
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> str:
"""
Generate a complete response from the model.
@ -397,7 +349,7 @@ def generate(
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
with wired_limit(model):
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
@ -415,8 +367,7 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
@ -437,7 +388,7 @@ def generate(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
@ -622,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
f"""
# {upload_repo}
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
using mlx-lm version **{__version__}**.
## Use with mlx

View File

@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
"hello",
max_tokens=5,
verbose=False,
logits_processor=[logits_processor],
logits_processors=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)

View File

@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
):
i += 1
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
if __name__ == "__main__":

View File

@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
text += detokenizer.last_segment
self.assertEqual(text, expected_text)
tokens = tokenizer.encode("こんにちは私の名前はAI")
check(tokens)
tokens = tokenizer.encode("a ,b")
check(tokens)

View File

@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest:
```
```sh
mlx_whisper audio_file.mp3
```
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run
`mlx_whisper -h`.
You can also pipe the audio content of other programs via stdin:
```sh
some-process | mlx_whisper -
```
The default output file name will be `content.*`. You can specify the name with
the `--output-name` flag.
#### API
Transcribe audio with:

View File

@ -3,7 +3,7 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Union
from typing import Optional, Union
import mlx.core as mx
import numpy as np
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
"""
Open an audio file and read as mono waveform, resampling as necessary
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
cmd.extend([
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
])
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout

View File

@ -2,9 +2,11 @@
import argparse
import os
import pathlib
import traceback
import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer
@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
)
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument(
"--model",
default="mlx-community/whisper-tiny",
type=str,
help="The model directory or hugging face repo",
)
parser.add_argument(
"--output-name",
type=str,
default=None,
help=(
"The name of transcription/translation output files before "
"--output-format extensions"
),
)
parser.add_argument(
"--output-dir",
"-o",
@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir)
@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True)
output_name = output_name or "content"
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try:
result = transcribe(
audio_path,
audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
writer(result, audio_path, **writer_args)
writer(result, output_name, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc.
import json
import os
import pathlib
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO
@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
f".{self.extension}"
)
with open(output_path, "w", encoding="utf-8") as f:
with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(