diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 245ad155..1b7ea521 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -191,6 +191,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
@@ -234,13 +235,17 @@ def main():
     else:
         draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
+
+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
     response = generate(
         model,
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
         sampler=sampler,
-        verbose=args.verbose and mx.distributed.init().rank() == 0,
+        verbose=args.verbose and world.rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 557c4316..b275ef3d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -306,12 +306,12 @@ def generate_step(

         y, logprobs = _step(y)

-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
     n = 0
     while True:
         if n != max_tokens:
             next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
         if n == 0:
             mx.eval(y)
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -628,6 +628,7 @@ def load_model(
     model_path: Path,
     lazy: bool = False,
     strict: bool = True,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -705,6 +706,10 @@ def load_model(
     model.shard()

     if not lazy:
+        weights.clear()
+        if sequential_load:
+            for layer in model.layers:
+                mx.eval(layer.parameters())
         mx.eval(model.parameters())

     model.eval()
@@ -717,6 +722,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -732,6 +738,8 @@ def load(
         lazy (bool): If ``False`` eval the model parameters to make sure they
            are loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
+        sequential_load (bool): If ``True``, evaluate the weights one layer at a
+           time so that peak memory stays low during loading. Default: ``False``

     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -741,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
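
A minimal usage sketch of the new flag, assuming the `load` signature added in this patch; the model repository name and the launcher shown in the comment are illustrative only and not part of the change:

# Launched on several hosts, e.g. with: mlx.launch --hosts host1,host2 example.py
import mlx.core as mx
from mlx_lm import load, generate

world = mx.distributed.init()

# On multi-node runs, evaluate the weights layer by layer so each rank
# keeps peak memory low while the shards stream in.
model, tokenizer = load(
    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",  # example repo only
    sequential_load=world.size() > 1,
)

# Every rank participates in generation; only rank 0 prints the result.
text = generate(model, tokenizer, prompt="Write a haiku about the sea.", max_tokens=64)
if world.rank() == 0:
    print(text)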