mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
starting
This commit is contained in:
parent
ed9e81dd58
commit
3783156072
@ -238,6 +238,8 @@ def main():
|
|||||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||||
formatter = colorprint_by_t0 if args.colorize else None
|
formatter = colorprint_by_t0 if args.colorize else None
|
||||||
|
|
||||||
|
sampler = make_sampler(
|
||||||
|
args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -245,8 +247,7 @@ def main():
|
|||||||
args.max_tokens,
|
args.max_tokens,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
temp=args.temp,
|
sampler=sampler,
|
||||||
top_p=args.top_p,
|
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
|
@ -5,6 +5,63 @@ from functools import partial
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def make_sampler(
|
||||||
|
temp: float = 0.0,
|
||||||
|
top_p: float = 0.0,
|
||||||
|
min_p: float = 0.0,
|
||||||
|
min_tokens_to_keep: int = 1,
|
||||||
|
) -> Callable[mx.array, mx.array]:
|
||||||
|
"""
|
||||||
|
Make a sampler function for use with ``generate_step``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
|
Default: ``0``.
|
||||||
|
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||||
|
more less likely words.
|
||||||
|
min_p (float, optional): The minimum value (scaled by the top token's
|
||||||
|
probability) that a token probability must have to be considered.
|
||||||
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
|
be filtered by min_p sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callabel[mx.array, mx.array]:
|
||||||
|
A sampler which takes log-probabilities and returns tokens.
|
||||||
|
"""
|
||||||
|
if temp == 0:
|
||||||
|
return lambda x: mx.argmax(x, axis=-1)
|
||||||
|
elif top_p > 0 and top_p < 1.0:
|
||||||
|
return lambda x: top_p_sampling(x, top_p, temp)
|
||||||
|
elif min_p != 0.0:
|
||||||
|
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
|
||||||
|
else:
|
||||||
|
return lambda x: categorical_sampling(x, temp)
|
||||||
|
|
||||||
|
|
||||||
|
def make_logits_processors():
|
||||||
|
"""
|
||||||
|
Make logits processors for use with ``generate_step``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repetition_penalty (float, optional): The penalty factor for repeating
|
||||||
|
tokens.
|
||||||
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
|
consider for repetition penalty. Default: ``20``.
|
||||||
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits_processors = []
|
||||||
|
if logit_bias:
|
||||||
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
|
values = mx.array(list(logit_bias.values()))
|
||||||
|
|
||||||
|
def logit_bias_processor(_, logits):
|
||||||
|
logits[:, indices] += values
|
||||||
|
return logits
|
||||||
|
|
||||||
|
logits_processors.append(logit_bias_processor)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def min_p_sampling(
|
def min_p_sampling(
|
||||||
logits: mx.array,
|
logits: mx.array,
|
||||||
@ -100,3 +157,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def categorical_sampling(logits, temp):
|
def categorical_sampling(logits, temp):
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
|
||||||
|
def repetition_penalty(penalty: float, context_size: int = 20):
|
||||||
|
"""
|
||||||
|
Make repetition penalty processor.
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1909.05858
|
||||||
|
|
||||||
|
Args:
|
||||||
|
penalty (float): The repetition penalty factor to be applied.
|
||||||
|
context_size (int): The number of previous tokens to use.
|
||||||
|
Default: ``20``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[mx.array, List[int]], mx.array]:
|
||||||
|
The repetition penalty processor.
|
||||||
|
"""
|
||||||
|
if penalty < 0 or not isinstance(penalty, float):
|
||||||
|
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||||
|
|
||||||
|
def repetition_penalty_processor(logits, tokens):
|
||||||
|
if len(tokens) > 0:
|
||||||
|
tokens = tokens[-context_size:]
|
||||||
|
selected_logits = logits[:, tokens]
|
||||||
|
selected_logits = mx.where(
|
||||||
|
selected_logits < 0,
|
||||||
|
selected_logits * penalty,
|
||||||
|
selected_logits / penalty,
|
||||||
|
)
|
||||||
|
logits[:, tokens] = selected_logits
|
||||||
|
return logits
|
||||||
|
|
||||||
|
return repetition_penalty_processor
|
||||||
|
@ -461,12 +461,13 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
token_logprobs = []
|
token_logprobs = []
|
||||||
top_tokens = []
|
top_tokens = []
|
||||||
|
|
||||||
prompt = self.get_prompt_cache(prompt)
|
prompt = mx.array(self.get_prompt_cache(prompt))
|
||||||
|
|
||||||
|
tic = time.perf_counter()
|
||||||
for _, (token, logprobs) in zip(
|
for _, (token, logprobs) in zip(
|
||||||
range(self.max_tokens),
|
range(self.max_tokens),
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=mx.array(prompt),
|
prompt=prompt,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
@ -476,6 +477,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
prompt_cache=self.prompt_cache.cache,
|
prompt_cache=self.prompt_cache.cache,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
if n == 0:
|
||||||
|
prompt_time = time.perf_counter() - tic
|
||||||
|
tic = time.perf_counter()
|
||||||
|
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
logging.debug(detokenizer.text)
|
logging.debug(detokenizer.text)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
@ -507,6 +512,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
if stop_sequence_suffix is None
|
if stop_sequence_suffix is None
|
||||||
else detokenizer.text[: -len(stop_sequence_suffix)]
|
else detokenizer.text[: -len(stop_sequence_suffix)]
|
||||||
)
|
)
|
||||||
|
gen_time = time.perf_counter() - tic
|
||||||
|
prompt_tps = len(prompt) / prompt_time
|
||||||
|
gen_tps = len(tokens) / gen_time
|
||||||
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
response = self.generate_response(
|
response = self.generate_response(
|
||||||
text,
|
text,
|
||||||
finish_reason,
|
finish_reason,
|
||||||
@ -517,6 +526,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
|
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
|
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
response_json = json.dumps(response).encode()
|
response_json = json.dumps(response).encode()
|
||||||
indent = "\t" # Backslashes can't be inside of f-strings
|
indent = "\t" # Backslashes can't be inside of f-strings
|
||||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||||
@ -552,12 +564,12 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_sequence_suffix = None
|
stop_sequence_suffix = None
|
||||||
logging.debug(f"Starting stream:")
|
logging.debug(f"Starting stream:")
|
||||||
|
|
||||||
prompt = self.get_prompt_cache(prompt)
|
prompt = mx.array(self.get_prompt_cache(prompt))
|
||||||
|
|
||||||
for _, (token, _) in zip(
|
for _, (token, _) in zip(
|
||||||
range(self.max_tokens),
|
range(self.max_tokens),
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=mx.array(prompt),
|
prompt=prompt,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
|
@ -34,6 +34,9 @@ MODEL_REMAPPING = {
|
|||||||
|
|
||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
|
||||||
|
# A stream on the default device just for generation
|
||||||
|
generation_stream = mx.new_stream(mx.default_device())
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundError(Exception):
|
class ModelNotFoundError(Exception):
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
|
|
||||||
"""
|
|
||||||
Apply repetition penalty to specific logits based on the given context.
|
|
||||||
|
|
||||||
Paper: https://arxiv.org/abs/1909.05858
|
|
||||||
|
|
||||||
Args:
|
|
||||||
logits (mx.array): The logits produced by the language model.
|
|
||||||
tokens (mx.array): A list of N previous tokens.
|
|
||||||
penalty (float): The repetition penalty factor to be applied.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
|
||||||
"""
|
|
||||||
if len(tokens) > 0:
|
|
||||||
selected_logits = logits[:, tokens]
|
|
||||||
selected_logits = mx.where(
|
|
||||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
|
||||||
)
|
|
||||||
logits[:, tokens] = selected_logits
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
||||||
if (
|
if (
|
||||||
kv_bits is not None
|
kv_bits is not None
|
||||||
@ -175,17 +155,11 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
|||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
temp: float = 0.0,
|
|
||||||
repetition_penalty: Optional[float] = None,
|
|
||||||
repetition_context_size: Optional[int] = 20,
|
|
||||||
top_p: float = 1.0,
|
|
||||||
min_p: float = 0.0,
|
|
||||||
min_tokens_to_keep: int = 1,
|
|
||||||
prefill_step_size: int = 512,
|
prefill_step_size: int = 512,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
kv_bits: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
quantized_kv_start: int = 0,
|
quantized_kv_start: int = 0,
|
||||||
@ -196,25 +170,15 @@ def generate_step(
|
|||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
|
||||||
Default: ``0``.
|
|
||||||
repetition_penalty (float, optional): The penalty factor for repeating
|
|
||||||
tokens.
|
|
||||||
repetition_context_size (int, optional): The number of tokens to
|
|
||||||
consider for repetition penalty. Default: ``20``.
|
|
||||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
|
||||||
more less likely words.
|
|
||||||
min_p (float, optional): The minimum value (scaled by the top token's
|
|
||||||
probability) that a token probability must have to be considered.
|
|
||||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
|
||||||
be filtered by min_p sampling.
|
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
entries (except the first 4 tokens) will be overwritten.
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||||
provided, the cache will be updated in place.
|
provided, the cache will be updated in place.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
sampler (Callable[mx.array, mx.array], optional). A function which
|
||||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
takes log probabilities and returns tokens. If ``None`` then the
|
||||||
|
argmax is used. Default: ``None``.
|
||||||
|
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
A list of functions that take tokens and logits and return the processed
|
A list of functions that take tokens and logits and return the processed
|
||||||
logits. Default: ``None``.
|
logits. Default: ``None``.
|
||||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||||
@ -228,49 +192,6 @@ def generate_step(
|
|||||||
one token and a vector of log probabilities.
|
one token and a vector of log probabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
|
||||||
logprobs = logits - mx.logsumexp(logits)
|
|
||||||
|
|
||||||
if temp == 0:
|
|
||||||
token = mx.argmax(logits, axis=-1)
|
|
||||||
else:
|
|
||||||
if top_p > 0 and top_p < 1.0:
|
|
||||||
token = top_p_sampling(logits, top_p, temp)
|
|
||||||
elif min_p != 0.0:
|
|
||||||
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
|
||||||
else:
|
|
||||||
token = categorical_sampling(logits, temp)
|
|
||||||
|
|
||||||
return token, logprobs
|
|
||||||
|
|
||||||
if repetition_penalty and (
|
|
||||||
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_processor = logits_processor or []
|
|
||||||
|
|
||||||
if repetition_penalty:
|
|
||||||
|
|
||||||
def repetition_penalty_processor(tokens, logits):
|
|
||||||
return apply_repetition_penalty(
|
|
||||||
logits, tokens[-repetition_context_size:], repetition_penalty
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_processor.append(repetition_penalty_processor)
|
|
||||||
|
|
||||||
if logit_bias:
|
|
||||||
indices = mx.array(list(logit_bias.keys()))
|
|
||||||
values = mx.array(list(logit_bias.values()))
|
|
||||||
|
|
||||||
def logit_bias_processor(_, logits):
|
|
||||||
logits[:, indices] += values
|
|
||||||
return logits
|
|
||||||
|
|
||||||
logits_processor.append(logit_bias_processor)
|
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
tokens = None
|
tokens = None
|
||||||
|
|
||||||
@ -283,24 +204,27 @@ def generate_step(
|
|||||||
elif len(prompt_cache) != len(model.layers):
|
elif len(prompt_cache) != len(model.layers):
|
||||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||||
|
|
||||||
|
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
|
with mx.stream(generation_stream):
|
||||||
|
logits = model(y[None], cache=prompt_cache)
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
logits = model(y[None], cache=prompt_cache)
|
if logits_processors:
|
||||||
logits = logits[:, -1, :]
|
nonlocal tokens
|
||||||
|
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||||
|
|
||||||
if logits_processor:
|
for processor in logits_processors:
|
||||||
nonlocal tokens
|
logits = processor(tokens, logits)
|
||||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
|
||||||
|
|
||||||
for processor in logits_processor:
|
maybe_quantize_kv_cache(
|
||||||
logits = processor(tokens, logits)
|
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||||
|
)
|
||||||
|
|
||||||
maybe_quantize_kv_cache(
|
logprobs = logits - mx.logsumexp(logits)
|
||||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
y = sampler(logprobs)
|
||||||
)
|
return y, logprobs.squeeze(0)
|
||||||
|
|
||||||
y, logprobs = sample(logits)
|
|
||||||
return y, logprobs.squeeze(0)
|
|
||||||
|
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||||
@ -416,8 +340,7 @@ def generate(
|
|||||||
if formatter:
|
if formatter:
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
# We have to finalize so that the prob corresponds to the last segment
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
with mx.stream(mx.cpu):
|
prob = mx.exp(logprobs[token]).item()
|
||||||
prob = mx.exp(logprobs[token]).item()
|
|
||||||
formatter(detokenizer.last_segment, prob)
|
formatter(detokenizer.last_segment, prob)
|
||||||
else:
|
else:
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
print(detokenizer.last_segment, end="", flush=True)
|
||||||
@ -623,7 +546,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
|||||||
f"""
|
f"""
|
||||||
# {upload_repo}
|
# {upload_repo}
|
||||||
|
|
||||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
|
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
|
||||||
|
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
|
||||||
|
using mlx-lm version **{__version__}**.
|
||||||
|
|
||||||
## Use with mlx
|
## Use with mlx
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user