mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Merge branch 'ml-explore:main' into sets_of_hf_datasets
This commit is contained in:
commit
f9936f77da
@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
|
||||
#### Streaming
|
||||
|
||||
For streaming generation, use the `stream_generate` function. This returns a
|
||||
generator object which streams the output text. For example,
|
||||
generator object which streams the output text, token, and log probabilities.
|
||||
For example,
|
||||
|
||||
```python
|
||||
from mlx_lm import load, stream_generate
|
||||
@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
print(t, end="", flush=True)
|
||||
print()
|
||||
```
|
||||
@ -221,6 +222,7 @@ Here are a few examples of Hugging Face models that work with this example:
|
||||
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
|
||||
|
||||
Most
|
||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||
|
@ -152,6 +152,7 @@ def main():
|
||||
|
||||
model(y[:step_size][None], cache=cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
mx.metal.clear_cache()
|
||||
processed += min(y.size, step_size)
|
||||
y = y[step_size:]
|
||||
current = time.time()
|
||||
@ -165,14 +166,13 @@ def main():
|
||||
)
|
||||
|
||||
print()
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
|
||||
|
||||
print("Saving...")
|
||||
metadata = {}
|
||||
metadata["model"] = args.model
|
||||
metadata["chat_template"] = tokenizer.chat_template
|
||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||
|
||||
|
||||
|
@ -11,6 +11,7 @@ from .utils import load, stream_generate
|
||||
DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MAX_TOKENS = 256
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
|
||||
|
||||
@ -41,6 +42,13 @@ def setup_arg_parser():
|
||||
help="Set the maximum key-value cache size",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -66,10 +74,11 @@ def main():
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
for response in stream_generate(
|
||||
for response, *_ in stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_tokens,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
prompt_cache=prompt_cache,
|
||||
|
@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_MIN_P = 0.0
|
||||
DEFAULT_MIN_TOKENS_TO_KEEP = 1
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
DEFAULT_QUANTIZED_KV_START = 5000
|
||||
@ -52,6 +54,7 @@ def setup_arg_parser():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
default=DEFAULT_PROMPT,
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
@ -68,6 +71,15 @@ def setup_arg_parser():
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-tokens-to-keep",
|
||||
type=float,
|
||||
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
||||
help="Minimum tokens to keep for min-p sampling.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
parser.add_argument(
|
||||
"--ignore-chat-template",
|
||||
@ -247,6 +259,8 @@ def main():
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
min_p=args.min_p,
|
||||
min_tokens_to_keep=args.min_tokens_to_keep,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
|
@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
if cache is not None and cache[0] is not None:
|
||||
c = cache[0]
|
||||
if hasattr(c, "max_size"):
|
||||
offset = min(c.max_size - 1, c.offset)
|
||||
offset = min(c.max_size, c.offset)
|
||||
window_size = c.max_size
|
||||
else:
|
||||
offset = c.offset
|
||||
|
@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
|
||||
self.keys = self._temporal_order(self.keys)
|
||||
self.values = self._temporal_order(self.values)
|
||||
|
||||
# The largest size is self.max_size + S - 1 to ensure
|
||||
# The largest size is self.max_size + S to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self._idx - self.max_size + 1
|
||||
trim_size = self._idx - self.max_size
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += keys.shape[2]
|
||||
|
@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
|
||||
use_conv_bias: bool
|
||||
time_step_rank: int
|
||||
tie_word_embeddings: bool = True
|
||||
use_bcdt_rms: bool = False
|
||||
mixer_rms_eps: float = 1e-6
|
||||
|
||||
def __post_init__(self):
|
||||
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
||||
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
|
||||
|
||||
if self.time_step_rank == "auto":
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||
if self.model_type == "falcon_mamba":
|
||||
self.use_bcdt_rms = True
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
|
||||
self.intermediate_size = args.intermediate_size
|
||||
self.time_step_rank = int(args.time_step_rank)
|
||||
self.use_conv_bias = args.use_conv_bias
|
||||
self.use_bcdt_rms = args.use_bcdt_rms
|
||||
if self.use_bcdt_rms:
|
||||
self.mixer_norm = lambda x: mx.fast.rms_norm(
|
||||
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
|
||||
)
|
||||
|
||||
self.in_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
||||
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if self.use_bcdt_rms:
|
||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||
delta = nn.softplus(self.dt_proj(delta))
|
||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||
if state is not None:
|
||||
|
@ -1,10 +1,83 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def make_sampler(
|
||||
temp: float = 0.0,
|
||||
top_p: float = 0.0,
|
||||
min_p: float = 0.0,
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> Callable[mx.array, mx.array]:
|
||||
"""
|
||||
Make a sampler function for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
|
||||
Returns:
|
||||
Callable[mx.array, mx.array]:
|
||||
A sampler which takes log-probabilities and returns tokens.
|
||||
"""
|
||||
if temp == 0:
|
||||
return lambda x: mx.argmax(x, axis=-1)
|
||||
elif top_p > 0 and top_p < 1.0:
|
||||
return lambda x: top_p_sampling(x, top_p, temp)
|
||||
elif min_p != 0.0:
|
||||
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
|
||||
else:
|
||||
return lambda x: categorical_sampling(x, temp)
|
||||
|
||||
|
||||
def make_logits_processors(
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
):
|
||||
"""
|
||||
Make logits processors for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
|
||||
Returns:
|
||||
List[Callable[[mx.array, mx.array], mx.array]]:
|
||||
A list of logits processors. Each processor in the list is a
|
||||
callable which takes an array of tokens and an array of logits
|
||||
and returns the updated logits.
|
||||
"""
|
||||
logits_processors = []
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processors.append(logit_bias_processor)
|
||||
|
||||
if repetition_penalty and repetition_penalty != 0.0:
|
||||
logits_processors.append(
|
||||
make_repetition_penalty(repetition_penalty, repetition_context_size)
|
||||
)
|
||||
return logits_processors
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def min_p_sampling(
|
||||
logits: mx.array,
|
||||
@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def categorical_sampling(logits, temp):
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
|
||||
def make_repetition_penalty(penalty: float, context_size: int = 20):
|
||||
"""
|
||||
Make repetition penalty processor.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
context_size (int): The number of previous tokens to use.
|
||||
Default: ``20``.
|
||||
|
||||
Returns:
|
||||
Callable[[mx.array, List[int]], mx.array]:
|
||||
The repetition penalty processor.
|
||||
"""
|
||||
if penalty < 0 or not isinstance(penalty, float):
|
||||
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||
|
||||
def repetition_penalty_processor(tokens, logits):
|
||||
if len(tokens) > 0:
|
||||
tokens = tokens[-context_size:]
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0,
|
||||
selected_logits * penalty,
|
||||
selected_logits / penalty,
|
||||
)
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
return repetition_penalty_processor
|
||||
|
@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
|
||||
|
||||
from ._version import __version__
|
||||
from .models.cache import make_prompt_cache
|
||||
from .utils import generate_step, load
|
||||
from .utils import load, stream_generate
|
||||
|
||||
|
||||
def get_system_fingerprint():
|
||||
@ -64,7 +64,7 @@ def stopping_criteria(
|
||||
end if it has (`trim_length`).
|
||||
"""
|
||||
if tokens and tokens[-1] == eos_token_id:
|
||||
return StopCondition(stop_met=True, trim_length=1)
|
||||
return StopCondition(stop_met=True, trim_length=0)
|
||||
|
||||
for stop_ids in stop_id_sequences:
|
||||
if len(tokens) >= len(stop_ids):
|
||||
@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||
if self.max_tokens is None:
|
||||
self.max_tokens = self.body.get("max_tokens", 512)
|
||||
self.temperature = self.body.get("temperature", 1.0)
|
||||
self.temperature = self.body.get("temperature", 0.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||
@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
# Call endpoint specific method
|
||||
prompt = endpoints[self.path]()
|
||||
|
||||
# Call method based on response type
|
||||
method = self.handle_stream if self.stream else self.handle_completion
|
||||
method(prompt, stop_id_sequences)
|
||||
self.handle_completion(prompt, stop_id_sequences)
|
||||
|
||||
def validate_model_parameters(self):
|
||||
"""
|
||||
@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||
to the stopping_criteria function
|
||||
"""
|
||||
detokenizer = self.tokenizer.detokenizer
|
||||
detokenizer.reset()
|
||||
tokens = []
|
||||
finish_reason = "length"
|
||||
stop_sequence_suffix = None
|
||||
logging.debug(f"Starting completion:")
|
||||
if self.stream:
|
||||
self.end_headers()
|
||||
logging.debug(f"Starting stream:")
|
||||
else:
|
||||
logging.debug(f"Starting completion:")
|
||||
token_logprobs = []
|
||||
top_tokens = []
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
for _, (token, logprobs) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=mx.array(prompt),
|
||||
text = ""
|
||||
tic = time.perf_counter()
|
||||
for n, (segment, token, logprobs) in enumerate(
|
||||
stream_generate(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
logit_bias=self.logit_bias,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
tic = time.perf_counter()
|
||||
|
||||
text += segment
|
||||
logging.debug(text)
|
||||
tokens.append(token)
|
||||
|
||||
if self.logprobs > 0:
|
||||
@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_sequence_suffix = self.tokenizer.decode(
|
||||
tokens[-stop_condition.trim_length :]
|
||||
)
|
||||
text = text[: -len(stop_sequence_suffix)]
|
||||
break
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
detokenizer.finalize()
|
||||
text = (
|
||||
detokenizer.text
|
||||
if stop_sequence_suffix is None
|
||||
else detokenizer.text[: -len(stop_sequence_suffix)]
|
||||
)
|
||||
response = self.generate_response(
|
||||
text,
|
||||
finish_reason,
|
||||
len(prompt),
|
||||
len(tokens),
|
||||
token_logprobs=token_logprobs,
|
||||
top_tokens=top_tokens,
|
||||
tokens=tokens,
|
||||
)
|
||||
|
||||
response_json = json.dumps(response).encode()
|
||||
indent = "\t" # Backslashes can't be inside of f-strings
|
||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||
|
||||
# Send an additional Content-Length header when it is known
|
||||
self.send_header("Content-Length", str(len(response_json)))
|
||||
self.end_headers()
|
||||
|
||||
self.wfile.write(response_json)
|
||||
self.wfile.flush()
|
||||
|
||||
def handle_stream(
|
||||
self,
|
||||
prompt: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Generate response to prompt and foward it to the client using a Server
|
||||
Sent Events (SSE) stream.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The tokenized prompt
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
||||
the stopping_criteria function
|
||||
"""
|
||||
# No additional headers are needed, call end_headers
|
||||
self.end_headers()
|
||||
|
||||
detokenizer = self.tokenizer.detokenizer
|
||||
detokenizer.reset()
|
||||
tokens = []
|
||||
|
||||
stop_sequence_suffix = None
|
||||
logging.debug(f"Starting stream:")
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
for _, (token, _) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=mx.array(prompt),
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
tokens.append(token)
|
||||
|
||||
stop_condition = stopping_criteria(
|
||||
tokens,
|
||||
stop_id_sequences,
|
||||
self.tokenizer.eos_token_id,
|
||||
)
|
||||
if stop_condition.stop_met:
|
||||
if stop_condition.trim_length:
|
||||
stop_sequence_suffix = self.tokenizer.decode(
|
||||
tokens[-stop_condition.trim_length :]
|
||||
if self.stream:
|
||||
# If the end of tokens overlaps with a stop sequence, generate new
|
||||
# tokens until we know if the stop sequence is hit or not
|
||||
if any(
|
||||
(
|
||||
sequence_overlap(tokens, sequence)
|
||||
for sequence in stop_id_sequences
|
||||
)
|
||||
break
|
||||
|
||||
# If the end of tokens overlaps with a stop sequence, generate new
|
||||
# tokens until we know if the stop sequence is hit or not
|
||||
if any(
|
||||
(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
|
||||
):
|
||||
continue
|
||||
|
||||
new_text = detokenizer.last_segment
|
||||
if new_text:
|
||||
response = self.generate_response(new_text, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
):
|
||||
continue
|
||||
elif segment:
|
||||
response = self.generate_response(segment, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
|
||||
# check is there any remaining text to send
|
||||
detokenizer.finalize()
|
||||
last_segment = detokenizer.last_segment
|
||||
if last_segment:
|
||||
if stop_sequence_suffix is not None:
|
||||
last_segment = last_segment[: -len(stop_sequence_suffix)]
|
||||
response = self.generate_response(last_segment, "length")
|
||||
gen_time = time.perf_counter() - tic
|
||||
prompt_tps = len(prompt) / prompt_time
|
||||
gen_tps = len(tokens) / gen_time
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
||||
|
||||
if self.stream:
|
||||
response = self.generate_response(segment, finish_reason)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
if self.stream_options is not None and self.stream_options["include_usage"]:
|
||||
response = self.completion_usage_response(len(prompt), len(tokens))
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
self.wfile.write("data: [DONE]\n\n".encode())
|
||||
self.wfile.flush()
|
||||
else:
|
||||
response = self.generate_response(
|
||||
text,
|
||||
finish_reason,
|
||||
len(prompt),
|
||||
len(tokens),
|
||||
token_logprobs=token_logprobs,
|
||||
top_tokens=top_tokens,
|
||||
tokens=tokens,
|
||||
)
|
||||
response_json = json.dumps(response).encode()
|
||||
indent = "\t" # Backslashes can't be inside of f-strings
|
||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||
|
||||
if self.stream_options is not None and self.stream_options["include_usage"]:
|
||||
response = self.completion_usage_response(len(prompt), len(tokens))
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
|
||||
self.wfile.write("data: [DONE]\n\n".encode())
|
||||
self.wfile.flush()
|
||||
# Send an additional Content-Length header when it is known
|
||||
self.send_header("Content-Length", str(len(response_json)))
|
||||
self.end_headers()
|
||||
self.wfile.write(response_json)
|
||||
self.wfile.flush()
|
||||
|
||||
def completion_usage_response(
|
||||
self,
|
||||
|
@ -6,12 +6,6 @@ from transformers import AutoTokenizer
|
||||
REPLACEMENT_CHAR = "\ufffd"
|
||||
|
||||
|
||||
def _remove_space(x):
|
||||
if x and x[0] == " ":
|
||||
return x[1:]
|
||||
return x
|
||||
|
||||
|
||||
class StreamingDetokenizer:
|
||||
"""The streaming detokenizer interface so that we can detokenize one token at a time.
|
||||
|
||||
@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
|
||||
def __init__(self, tokenizer, trim_space=True):
|
||||
self.trim_space = trim_space
|
||||
self._sep = "\u2581".encode()
|
||||
|
||||
# Extract the tokens in a list from id to text
|
||||
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
|
||||
for value, tokenid in tokenizer.vocab.items():
|
||||
self.tokenmap[tokenid] = value
|
||||
|
||||
# Replace bytes with their value
|
||||
for i in range(len(self.tokenmap)):
|
||||
if self.tokenmap[i].startswith("<0x"):
|
||||
self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
|
||||
if value.startswith("<0x"):
|
||||
# Replace bytes with their value
|
||||
self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
|
||||
else:
|
||||
self.tokenmap[tokenid] = value.encode()
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.offset = 0
|
||||
self._unflushed = ""
|
||||
self._unflushed = b""
|
||||
self.text = ""
|
||||
self.tokens = []
|
||||
|
||||
def _flush(self):
|
||||
text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
|
||||
if not self.text and self.trim_space and text and text[0] == " ":
|
||||
text = text[1:]
|
||||
self.text += text
|
||||
|
||||
def add_token(self, token):
|
||||
v = self.tokenmap[token]
|
||||
if v[0] == "\u2581":
|
||||
if self.text or not self.trim_space:
|
||||
self.text += self._unflushed.replace("\u2581", " ")
|
||||
else:
|
||||
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
|
||||
if v.startswith(self._sep):
|
||||
self._flush()
|
||||
self._unflushed = v
|
||||
else:
|
||||
self._unflushed += v
|
||||
|
||||
def finalize(self):
|
||||
if self.text or not self.trim_space:
|
||||
self.text += self._unflushed.replace("\u2581", " ")
|
||||
else:
|
||||
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
|
||||
self._unflushed = ""
|
||||
self._flush()
|
||||
self._unflushed = b""
|
||||
|
||||
|
||||
class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
|
@ -285,7 +285,7 @@ def train(
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||
|
@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models import cache
|
||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
from .tuner.utils import load_adapters
|
||||
@ -29,10 +29,14 @@ from .tuner.utils import load_adapters
|
||||
MODEL_REMAPPING = {
|
||||
"mistral": "llama", # mistral is compatible with llama
|
||||
"phi-msft": "phixtral",
|
||||
"falcon_mamba": "mamba",
|
||||
}
|
||||
|
||||
MAX_FILE_SIZE_GB = 5
|
||||
|
||||
# A stream on the default device just for generation
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
def __init__(self, message):
|
||||
@ -136,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
||||
return model_path
|
||||
|
||||
|
||||
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
|
||||
"""
|
||||
Apply repetition penalty to specific logits based on the given context.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
logits (mx.array): The logits produced by the language model.
|
||||
tokens (mx.array): A list of N previous tokens.
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
|
||||
Returns:
|
||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
||||
"""
|
||||
if len(tokens) > 0:
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
||||
)
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
||||
if (
|
||||
kv_bits is not None
|
||||
@ -184,7 +165,7 @@ def generate_step(
|
||||
max_kv_size: Optional[int] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
@ -213,7 +194,7 @@ def generate_step(
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
@ -223,53 +204,9 @@ def generate_step(
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||
one token and a vector of log probabilities.
|
||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
logprobs = logits - mx.logsumexp(logits)
|
||||
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
token = top_p_sampling(logits, top_p, temp)
|
||||
elif min_p != 0.0:
|
||||
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
||||
else:
|
||||
token = categorical_sampling(logits, temp)
|
||||
|
||||
return token, logprobs
|
||||
|
||||
if repetition_penalty and (
|
||||
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
||||
):
|
||||
raise ValueError(
|
||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
||||
)
|
||||
|
||||
logits_processor = logits_processor or []
|
||||
|
||||
if repetition_penalty:
|
||||
|
||||
def repetition_penalty_processor(tokens, logits):
|
||||
return apply_repetition_penalty(
|
||||
logits, tokens[-repetition_context_size:], repetition_penalty
|
||||
)
|
||||
|
||||
logits_processor.append(repetition_penalty_processor)
|
||||
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processor.append(logit_bias_processor)
|
||||
|
||||
y = prompt
|
||||
tokens = None
|
||||
|
||||
@ -282,24 +219,31 @@ def generate_step(
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
|
||||
logits_processors = logits_processors or []
|
||||
logits_processors.extend(
|
||||
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
|
||||
)
|
||||
|
||||
def _step(y):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
if logits_processors:
|
||||
nonlocal tokens
|
||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||
|
||||
if logits_processor:
|
||||
nonlocal tokens
|
||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||
for processor in logits_processors:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
for processor in logits_processor:
|
||||
logits = processor(tokens, logits)
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
|
||||
y, logprobs = sample(logits)
|
||||
return y, logprobs.squeeze(0)
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs)
|
||||
return y, logprobs.squeeze(0)
|
||||
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||
@ -324,43 +268,51 @@ def generate_step(
|
||||
def stream_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: str,
|
||||
prompt: Union[str, List[int]],
|
||||
max_tokens: int = 100,
|
||||
**kwargs,
|
||||
) -> Union[str, Generator[str, None, None]]:
|
||||
) -> Generator[Tuple[str, int, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
max_tokens (int): The ma
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
|
||||
Tuple[str, int, mx.array]:
|
||||
The next text segment, token, and vector of log probabilities.
|
||||
"""
|
||||
if not isinstance(tokenizer, TokenizerWrapper):
|
||||
tokenizer = TokenizerWrapper(tokenizer)
|
||||
|
||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||
prompt_tokens = mx.array(
|
||||
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
|
||||
)
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
detokenizer.reset()
|
||||
for n, (token, _) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
detokenizer.add_token(token)
|
||||
with wired_limit(model, [generation_stream]):
|
||||
detokenizer.reset()
|
||||
for n, (token, logits) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
# Yield the last segment if streaming
|
||||
yield detokenizer.last_segment
|
||||
detokenizer.add_token(token)
|
||||
|
||||
detokenizer.finalize()
|
||||
yield detokenizer.last_segment
|
||||
if n == (max_tokens - 1):
|
||||
break
|
||||
|
||||
yield detokenizer.last_segment, token, logits
|
||||
|
||||
detokenizer.finalize()
|
||||
yield detokenizer.last_segment, token, logits
|
||||
|
||||
|
||||
def generate(
|
||||
@ -371,7 +323,7 @@ def generate(
|
||||
verbose: bool = False,
|
||||
formatter: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> Union[str, Generator[str, None, None]]:
|
||||
) -> str:
|
||||
"""
|
||||
Generate a complete response from the model.
|
||||
|
||||
@ -397,7 +349,7 @@ def generate(
|
||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
with wired_limit(model):
|
||||
with wired_limit(model, [generation_stream]):
|
||||
tic = time.perf_counter()
|
||||
detokenizer.reset()
|
||||
for n, (token, logprobs) in zip(
|
||||
@ -415,8 +367,7 @@ def generate(
|
||||
if formatter:
|
||||
# We have to finalize so that the prob corresponds to the last segment
|
||||
detokenizer.finalize()
|
||||
with mx.stream(mx.cpu):
|
||||
prob = mx.exp(logprobs[token]).item()
|
||||
prob = mx.exp(logprobs[token]).item()
|
||||
formatter(detokenizer.last_segment, prob)
|
||||
else:
|
||||
print(detokenizer.last_segment, end="", flush=True)
|
||||
@ -437,7 +388,7 @@ def generate(
|
||||
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||
|
||||
return detokenizer.text
|
||||
@ -622,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
f"""
|
||||
# {upload_repo}
|
||||
|
||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
|
||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
|
||||
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
|
||||
using mlx-lm version **{__version__}**.
|
||||
|
||||
## Use with mlx
|
||||
|
||||
|
@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
|
||||
"hello",
|
||||
max_tokens=5,
|
||||
verbose=False,
|
||||
logits_processor=[logits_processor],
|
||||
logits_processors=[logits_processor],
|
||||
)
|
||||
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||
|
||||
|
@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
|
||||
):
|
||||
i += 1
|
||||
self.assertEqual(tok, toks[i])
|
||||
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
|
||||
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
|
||||
text += detokenizer.last_segment
|
||||
self.assertEqual(text, expected_text)
|
||||
|
||||
tokens = tokenizer.encode("こんにちは!私の名前はAI")
|
||||
check(tokens)
|
||||
|
||||
tokens = tokenizer.encode("a ,b")
|
||||
check(tokens)
|
||||
|
||||
|
@ -25,7 +25,7 @@ pip install mlx-whisper
|
||||
|
||||
At its simplest:
|
||||
|
||||
```
|
||||
```sh
|
||||
mlx_whisper audio_file.mp3
|
||||
```
|
||||
|
||||
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
|
||||
are many other supported command line options. To see them all, run
|
||||
`mlx_whisper -h`.
|
||||
|
||||
You can also pipe the audio content of other programs via stdin:
|
||||
|
||||
```sh
|
||||
some-process | mlx_whisper -
|
||||
```
|
||||
|
||||
The default output file name will be `content.*`. You can specify the name with
|
||||
the `--output-name` flag.
|
||||
|
||||
#### API
|
||||
|
||||
Transcribe audio with:
|
||||
@ -103,7 +112,7 @@ python convert.py --help
|
||||
```
|
||||
|
||||
By default, the conversion script will make the directory `mlx_models`
|
||||
and save the converted `weights.npz` and `config.json` there.
|
||||
and save the converted `weights.npz` and `config.json` there.
|
||||
|
||||
Each time it is run, `convert.py` will overwrite any model in the provided
|
||||
path. To save different models, make sure to set `--mlx-path` to a unique
|
||||
|
@ -3,7 +3,7 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from subprocess import CalledProcessError, run
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
if from_stdin:
|
||||
cmd = ["ffmpeg", "-i", "pipe:0"]
|
||||
else:
|
||||
cmd = ["ffmpeg", "-nostdin", "-i", file]
|
||||
|
||||
# fmt: off
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
cmd.extend([
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr),
|
||||
"-"
|
||||
]
|
||||
])
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
|
@ -2,9 +2,11 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from . import audio
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
from .transcribe import transcribe
|
||||
from .writers import get_writer
|
||||
@ -27,15 +29,24 @@ def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
|
||||
)
|
||||
|
||||
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx-community/whisper-tiny",
|
||||
type=str,
|
||||
help="The model directory or hugging face repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of transcription/translation output files before "
|
||||
"--output-format extensions"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
"-o",
|
||||
@ -200,6 +211,7 @@ def main():
|
||||
path_or_hf_repo: str = args.pop("model")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
output_name: str = args.pop("output_name")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
@ -219,17 +231,25 @@ def main():
|
||||
warnings.warn("--max-line-count has no effect without --max-line-width")
|
||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||
for audio_path in args.pop("audio"):
|
||||
|
||||
for audio_obj in args.pop("audio"):
|
||||
if audio_obj == "-":
|
||||
# receive the contents from stdin rather than read a file
|
||||
audio_obj = audio.load_audio(from_stdin=True)
|
||||
|
||||
output_name = output_name or "content"
|
||||
else:
|
||||
output_name = output_name or pathlib.Path(audio_obj).stem
|
||||
try:
|
||||
result = transcribe(
|
||||
audio_path,
|
||||
audio_obj,
|
||||
path_or_hf_repo=path_or_hf_repo,
|
||||
**args,
|
||||
)
|
||||
writer(result, audio_path, **writer_args)
|
||||
writer(result, output_name, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,10 +1,8 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
|
||||
|
||||
@ -43,15 +41,13 @@ class ResultWriter:
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(
|
||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||
self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
|
||||
f".{self.extension}"
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
with output_path.open("wt", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options, **kwargs)
|
||||
|
||||
def write_result(
|
||||
|
Loading…
Reference in New Issue
Block a user