Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Remove async eval and add sequential load

commit 026362e0f8
parent a0ce0594f6
@@ -191,6 +191,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
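The flag is derived from the size of the distributed group, so single-process runs keep the existing loading path. A small sketch of how that expression behaves, assuming MLX's mx.distributed API; group and use_sequential are illustrative names, not code from the repository:

    import mlx.core as mx

    # With no distributed backend available, init() returns a singleton group,
    # so size() == 1 and the flag stays False; under a multi-process launch it
    # becomes True and the loader switches to layer-by-layer evaluation.
    group = mx.distributed.init()
    use_sequential = group.size() > 1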
@@ -234,13 +235,17 @@ def main():
     else:
         draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
+
+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
     response = generate(
         model,
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
         sampler=sampler,
-        verbose=args.verbose and mx.distributed.init().rank() == 0,
+        verbose=args.verbose and world.rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
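This hunk makes the script initialize one communication group, have each rank report itself, and synchronize at a barrier before any rank starts generating; verbose output is restricted to rank 0. A stripped-down sketch of that pattern (verbose_on_this_rank is an illustrative name):

    import mlx.core as mx

    world = mx.distributed.init()   # one communication group, reused below
    print(f"Node {world.rank()} of {world.size()}", flush=True)

    world.barrier()                 # all ranks wait here until loading is done

    verbose_on_this_rank = world.rank() == 0   # only rank 0 prints streamed text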
@@ -306,12 +306,12 @@ def generate_step(

     y, logprobs = _step(y)

-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
     n = 0
     while True:
         if n != max_tokens:
             next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
         if n == 0:
             mx.eval(y)
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
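Swapping mx.async_eval for mx.eval makes the token loop synchronous: each step's outputs, and any distributed communication they depend on, must finish before the next step is issued, instead of letting the host run one step ahead. A tiny sketch of the difference between the two calls on generic MLX arrays, not the repository code:

    import mlx.core as mx

    a = mx.random.normal((4,))
    b = a * 2                # lazy: no computation has happened yet

    # mx.async_eval(b) would only schedule the computation and return at once,
    # allowing the host to queue the next decoding step while this one runs.
    mx.eval(b)               # blocks until b is fully materialized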
@@ -628,6 +628,7 @@ def load_model(
     model_path: Path,
     lazy: bool = False,
     strict: bool = True,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -705,6 +706,10 @@ def load_model(
         model.shard()

     if not lazy:
+        weights.clear()
+        if sequential_load:
+            for layer in model.layers:
+                mx.eval(layer.parameters())
         mx.eval(model.parameters())

     model.eval()
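With sequential loading enabled, the loader evaluates one transformer block's parameters at a time, so the peak working set stays close to a single layer rather than the whole model being materialized at once; the final full eval then only picks up what the per-layer loop did not cover. A minimal sketch of this loading pattern, assuming a model that exposes a .layers list of nn.Module blocks as the mlx_lm model definitions do; materialize_sequentially is a hypothetical helper name:

    import mlx.core as mx

    def materialize_sequentially(model):
        # Evaluate one block at a time so only a single layer's parameters
        # are being materialized at any given moment.
        for layer in model.layers:           # assumes the model exposes .layers
            mx.eval(layer.parameters())
        # Pick up everything outside the block list (embeddings, final norm,
        # head); the layers evaluated above are already concrete, so this is cheap.
        mx.eval(model.parameters())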
@@ -717,6 +722,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -732,6 +738,8 @@ def load(
         lazy (bool): If ``False`` eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        sequential_load (bool): If True then load each layer sequentially to
+            ensure that we are not wasting memory.
     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

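Given the documented keyword, a call along the following lines is one way to exercise the option directly; the import path and repository id are assumptions, not taken from this diff:

    from mlx_lm import load   # assumed import path for the load() documented above

    # With the default lazy=False, sequential_load=True makes the loader evaluate
    # the parameters layer by layer instead of all at once.
    model, tokenizer = load(
        "mlx-community/Meta-Llama-3-8B-Instruct-4bit",   # example repo id
        sequential_load=True,
    )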
@@ -741,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, sequential_load, lazy)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
     model.eval()