Remove async eval and add sequential load

Author: Angelos Katharopoulos, 2024-11-05 13:04:07 -08:00 (committed by Awni Hannun)
Parent: a0ce0594f6
Commit: 026362e0f8
2 changed files with 17 additions and 4 deletions

File 1 of 2:

@@ -191,6 +191,7 @@ def main():
         model_path,
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
+        sequential_load=mx.distributed.init().size() > 1,
     )
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
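
The new keyword only switches sequential loading on when the run is actually
distributed, since a single process reports a world size of one. A quick
sketch of that check, assuming nothing beyond mlx.core:

    import mlx.core as mx

    world = mx.distributed.init()
    # A plain single-process run prints "1 False", so sequential_load stays off.
    print(world.size(), world.size() > 1)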
@@ -234,13 +235,17 @@ def main():
     else:
         draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
     response = generate(
         model,
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
         sampler=sampler,
-        verbose=args.verbose and mx.distributed.init().rank() == 0,
+        verbose=args.verbose and world.rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
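
Hoisting the group into `world` avoids repeated `mx.distributed.init()` calls,
and the barrier keeps all nodes in lockstep before generation begins. A minimal
sketch of the same pattern (how the script is launched is left out; any MLX
distributed launcher works):

    import mlx.core as mx

    world = mx.distributed.init()
    print(f"Node {world.rank()} of {world.size()}", flush=True)
    world.barrier()  # wait until every node has reached this point

    if world.rank() == 0:
        # Mirror the verbose gating above: only rank 0 produces output.
        print("all nodes ready")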

File 2 of 2:

@@ -306,12 +306,12 @@ def generate_step(
     y, logprobs = _step(y)
-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
     n = 0
     while True:
         if n != max_tokens:
             next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
         if n == 0:
             mx.eval(y)
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
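
For readers unfamiliar with the two calls: `mx.async_eval` schedules a lazy
graph for evaluation and returns immediately, which lets the next token's
graph be built while the current one runs, whereas `mx.eval` blocks until its
arrays are materialized. A rough sketch of the difference (shapes are
arbitrary):

    import mlx.core as mx

    a = mx.ones((1024, 1024))
    b = a @ a          # lazy: nothing is computed yet

    mx.async_eval(b)   # schedule b and return immediately
    c = b @ b          # graph construction overlaps with evaluating b

    mx.eval(c)         # block until c (and therefore b) is materialized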
@@ -628,6 +628,7 @@ def load_model(
     model_path: Path,
     lazy: bool = False,
     strict: bool = True,
+    sequential_load: bool = False,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -705,6 +706,10 @@ def load_model(
         model.shard()

     if not lazy:
+        weights.clear()
+        if sequential_load:
+            for layer in model.layers:
+                mx.eval(layer.parameters())
         mx.eval(model.parameters())

     model.eval()
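
Evaluating one layer at a time bounds peak memory during loading near a single
layer's footprint rather than the whole model's, which is presumably why the
CLI above enables it only for distributed runs. A sketch of the pattern on a
stand-in model (`TinyModel` is hypothetical; only the eval loop mirrors the
change):

    import mlx.core as mx
    import mlx.nn as nn

    class TinyModel(nn.Module):
        def __init__(self, dim: int = 64, n_layers: int = 4):
            super().__init__()
            self.layers = [nn.Linear(dim, dim) for _ in range(n_layers)]

    model = TinyModel()

    # Materialize parameters layer by layer so only one layer's pending
    # loads is in flight at any moment.
    for layer in model.layers:
        mx.eval(layer.parameters())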
@@ -717,6 +722,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    sequential_load: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -732,6 +738,8 @@ def load(
         lazy (bool): If ``False`` eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        sequential_load (bool): If ``True``, load the model one layer at a time
+            to reduce peak memory use during loading. Default: ``False``

     Returns:
         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
@@ -741,7 +749,7 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy, sequential_load=sequential_load)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
     model.eval()
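
Putting it together, library callers opt in through the new keyword. A usage
sketch (the model repo is only an example; any MLX-format model works):

    import mlx.core as mx
    from mlx_lm import load, generate

    model, tokenizer = load(
        "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        sequential_load=mx.distributed.init().size() > 1,  # same default as the CLI
    )
    print(generate(model, tokenizer, prompt="Hello", max_tokens=16))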