mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
starting
This commit is contained in:
parent
ed9e81dd58
commit
3783156072
@ -238,6 +238,8 @@ def main():
|
||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
sampler = make_sampler(
|
||||
args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -245,8 +247,7 @@ def main():
|
||||
args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
sampler=sampler,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
|
@ -5,6 +5,63 @@ from functools import partial
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def make_sampler(
|
||||
temp: float = 0.0,
|
||||
top_p: float = 0.0,
|
||||
min_p: float = 0.0,
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> Callable[mx.array, mx.array]:
|
||||
"""
|
||||
Make a sampler function for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
|
||||
Returns:
|
||||
Callabel[mx.array, mx.array]:
|
||||
A sampler which takes log-probabilities and returns tokens.
|
||||
"""
|
||||
if temp == 0:
|
||||
return lambda x: mx.argmax(x, axis=-1)
|
||||
elif top_p > 0 and top_p < 1.0:
|
||||
return lambda x: top_p_sampling(x, top_p, temp)
|
||||
elif min_p != 0.0:
|
||||
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
|
||||
else:
|
||||
return lambda x: categorical_sampling(x, temp)
|
||||
|
||||
|
||||
def make_logits_processors():
|
||||
"""
|
||||
Make logits processors for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
"""
|
||||
|
||||
logits_processors = []
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processors.append(logit_bias_processor)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def min_p_sampling(
|
||||
logits: mx.array,
|
||||
@ -100,3 +157,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def categorical_sampling(logits, temp):
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
|
||||
def repetition_penalty(penalty: float, context_size: int = 20):
|
||||
"""
|
||||
Make repetition penalty processor.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
context_size (int): The number of previous tokens to use.
|
||||
Default: ``20``.
|
||||
|
||||
Returns:
|
||||
Callable[[mx.array, List[int]], mx.array]:
|
||||
The repetition penalty processor.
|
||||
"""
|
||||
if penalty < 0 or not isinstance(penalty, float):
|
||||
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||
|
||||
def repetition_penalty_processor(logits, tokens):
|
||||
if len(tokens) > 0:
|
||||
tokens = tokens[-context_size:]
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0,
|
||||
selected_logits * penalty,
|
||||
selected_logits / penalty,
|
||||
)
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
return repetition_penalty_processor
|
||||
|
@ -461,12 +461,13 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
token_logprobs = []
|
||||
top_tokens = []
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
prompt = mx.array(self.get_prompt_cache(prompt))
|
||||
|
||||
tic = time.perf_counter()
|
||||
for _, (token, logprobs) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=mx.array(prompt),
|
||||
prompt=prompt,
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
@ -476,6 +477,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
tic = time.perf_counter()
|
||||
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
tokens.append(token)
|
||||
@ -507,6 +512,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
if stop_sequence_suffix is None
|
||||
else detokenizer.text[: -len(stop_sequence_suffix)]
|
||||
)
|
||||
gen_time = time.perf_counter() - tic
|
||||
prompt_tps = len(prompt) / prompt_time
|
||||
gen_tps = len(tokens) / gen_time
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
response = self.generate_response(
|
||||
text,
|
||||
finish_reason,
|
||||
@ -517,6 +526,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
tokens=tokens,
|
||||
)
|
||||
|
||||
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
||||
response_json = json.dumps(response).encode()
|
||||
indent = "\t" # Backslashes can't be inside of f-strings
|
||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||
@ -552,12 +564,12 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_sequence_suffix = None
|
||||
logging.debug(f"Starting stream:")
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
prompt = mx.array(self.get_prompt_cache(prompt))
|
||||
|
||||
for _, (token, _) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=mx.array(prompt),
|
||||
prompt=prompt,
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
|
@ -34,6 +34,9 @@ MODEL_REMAPPING = {
|
||||
|
||||
MAX_FILE_SIZE_GB = 5
|
||||
|
||||
# A stream on the default device just for generation
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
def __init__(self, message):
|
||||
@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
||||
return model_path
|
||||
|
||||
|
||||
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
|
||||
"""
|
||||
Apply repetition penalty to specific logits based on the given context.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
logits (mx.array): The logits produced by the language model.
|
||||
tokens (mx.array): A list of N previous tokens.
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
|
||||
Returns:
|
||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
||||
"""
|
||||
if len(tokens) > 0:
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
||||
)
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
||||
if (
|
||||
kv_bits is not None
|
||||
@ -175,17 +155,11 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
temp: float = 0.0,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
top_p: float = 1.0,
|
||||
min_p: float = 0.0,
|
||||
min_tokens_to_keep: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
max_kv_size: Optional[int] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
@ -196,25 +170,15 @@ def generate_step(
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||
entries (except the first 4 tokens) will be overwritten.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
sampler (Callable[mx.array, mx.array], optional). A function which
|
||||
takes log probabilities and returns tokens. If ``None`` then the
|
||||
argmax is used. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
@ -228,49 +192,6 @@ def generate_step(
|
||||
one token and a vector of log probabilities.
|
||||
"""
|
||||
|
||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||
logprobs = logits - mx.logsumexp(logits)
|
||||
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
token = top_p_sampling(logits, top_p, temp)
|
||||
elif min_p != 0.0:
|
||||
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
||||
else:
|
||||
token = categorical_sampling(logits, temp)
|
||||
|
||||
return token, logprobs
|
||||
|
||||
if repetition_penalty and (
|
||||
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
||||
):
|
||||
raise ValueError(
|
||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
||||
)
|
||||
|
||||
logits_processor = logits_processor or []
|
||||
|
||||
if repetition_penalty:
|
||||
|
||||
def repetition_penalty_processor(tokens, logits):
|
||||
return apply_repetition_penalty(
|
||||
logits, tokens[-repetition_context_size:], repetition_penalty
|
||||
)
|
||||
|
||||
logits_processor.append(repetition_penalty_processor)
|
||||
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processor.append(logit_bias_processor)
|
||||
|
||||
y = prompt
|
||||
tokens = None
|
||||
|
||||
@ -283,23 +204,26 @@ def generate_step(
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
def _step(y):
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
def _step(y):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if logits_processor:
|
||||
if logits_processors:
|
||||
nonlocal tokens
|
||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||
|
||||
for processor in logits_processor:
|
||||
for processor in logits_processors:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
|
||||
y, logprobs = sample(logits)
|
||||
logprobs = logits - mx.logsumexp(logits)
|
||||
y = sampler(logprobs)
|
||||
return y, logprobs.squeeze(0)
|
||||
|
||||
while y.size > prefill_step_size:
|
||||
@ -416,7 +340,6 @@ def generate(
|
||||
if formatter:
|
||||
# We have to finalize so that the prob corresponds to the last segment
|
||||
detokenizer.finalize()
|
||||
with mx.stream(mx.cpu):
|
||||
prob = mx.exp(logprobs[token]).item()
|
||||
formatter(detokenizer.last_segment, prob)
|
||||
else:
|
||||
@ -623,7 +546,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
f"""
|
||||
# {upload_repo}
|
||||
|
||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
|
||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
|
||||
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
|
||||
using mlx-lm version **{__version__}**.
|
||||
|
||||
## Use with mlx
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user