mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
1 Commits
distribute
...
flux-dist-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b9eff0d744 |
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
|||||||
- Markus Enzweiler: Added the `cvae` examples.
|
- Markus Enzweiler: Added the `cvae` examples.
|
||||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||||
- Shiyu Li: Added the `Segment Anything Model`.
|
- Shiyu Li: Added the `Segment Anything Model`.
|
||||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
|
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
|
||||||
@@ -45,7 +45,7 @@ Some more useful examples are listed below.
|
|||||||
|
|
||||||
### Hugging Face
|
### Hugging Face
|
||||||
|
|
||||||
You can directly use or download converted checkpoints from the [MLX
|
Note: You can now directly download a few converted checkpoints from the [MLX
|
||||||
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
||||||
We encourage you to join the community and [contribute new
|
We encourage you to join the community and [contribute new
|
||||||
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
||||||
|
|||||||
@@ -261,19 +261,23 @@ if __name__ == "__main__":
|
|||||||
generate_progress_images(0, flux, args)
|
generate_progress_images(0, flux, args)
|
||||||
|
|
||||||
grads = None
|
grads = None
|
||||||
losses = []
|
batch_cnt = 0
|
||||||
|
total_loss = 0
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||||
mx.eval(loss, grads, state)
|
total_loss = total_loss + loss
|
||||||
losses.append(loss.item())
|
batch_cnt += 1
|
||||||
|
mx.eval(total_loss, grads, state)
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||||
|
total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
|
||||||
|
total_loss = total_loss.item() / batch_cnt
|
||||||
print(
|
print(
|
||||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
f"Iter: {i + 1} Loss: {total_loss:.3f} "
|
||||||
f"It/s: {10 / (toc - tic):.3f} "
|
f"It/s: {batch_cnt / (toc - tic):.3f} "
|
||||||
f"Peak mem: {peak_mem:.3f} GB",
|
f"Peak mem: {peak_mem:.3f} GB",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
@@ -285,7 +289,8 @@ if __name__ == "__main__":
|
|||||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
losses = []
|
total_loss = 0
|
||||||
|
batch_cnt = 0
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
save_adapters("final_adapters.safetensors", flux, args)
|
save_adapters("final_adapters.safetensors", flux, args)
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ mlx_lm.convert \
|
|||||||
```
|
```
|
||||||
|
|
||||||
Models can also be converted and quantized directly in the
|
Models can also be converted and quantized directly in the
|
||||||
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
||||||
Face Space.
|
Face Space.
|
||||||
|
|
||||||
### Long Prompts and Generations
|
### Long Prompts and Generations
|
||||||
|
|||||||
@@ -16,25 +16,6 @@ DEFAULT_MAX_TOKENS = 256
|
|||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
|
||||||
|
|
||||||
def share_message(world, prompt):
|
|
||||||
if world.size() == 1:
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
if world.rank() == 0:
|
|
||||||
size = mx.array([len(prompt)])
|
|
||||||
else:
|
|
||||||
size = mx.array([0])
|
|
||||||
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
|
|
||||||
if size == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if world.rank() == 0:
|
|
||||||
prompt = mx.array(prompt)
|
|
||||||
else:
|
|
||||||
prompt = mx.array([0] * len(prompt))
|
|
||||||
return mx.distributed.all_sum(size, stream=mx.cpu).tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
"""Set up and return the argument parser."""
|
"""Set up and return the argument parser."""
|
||||||
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
||||||
@@ -73,7 +54,6 @@ def setup_arg_parser():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
world = mx.distributed.init()
|
|
||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -83,30 +63,16 @@ def main():
|
|||||||
args.model,
|
args.model,
|
||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config={"trust_remote_code": True},
|
tokenizer_config={"trust_remote_code": True},
|
||||||
sequential_load=mx.distributed.init().size() > 1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
||||||
print(
|
|
||||||
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
world.barrier()
|
|
||||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||||
while True:
|
while True:
|
||||||
if world.rank() == 0:
|
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
if query == "q":
|
if query == "q":
|
||||||
prompt = []
|
|
||||||
else:
|
|
||||||
messages = [{"role": "user", "content": query}]
|
|
||||||
prompt = tokenizer.apply_chat_template(
|
|
||||||
messages, add_generation_prompt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = share_message(world, prompt)
|
|
||||||
if len(prompt) == 0:
|
|
||||||
break
|
break
|
||||||
|
messages = [{"role": "user", "content": query}]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
for response in stream_generate(
|
for response in stream_generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -115,9 +81,7 @@ def main():
|
|||||||
sampler=make_sampler(args.temp, args.top_p),
|
sampler=make_sampler(args.temp, args.top_p),
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
):
|
):
|
||||||
if world.rank() == 0:
|
print(response.text, flush=True, end="")
|
||||||
print(response, flush=True, end="")
|
|
||||||
if world.rank() == 0:
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,10 @@
|
|||||||
Run with:
|
Run with:
|
||||||
|
|
||||||
```
|
```
|
||||||
mlx.launch \
|
/path/to/mpirun \
|
||||||
|
-np 2 \
|
||||||
--hostfile /path/to/hosts.txt \
|
--hostfile /path/to/hosts.txt \
|
||||||
--backend mpi \
|
python /path/to/pipeline_generate.py --prompt "hello world"
|
||||||
/path/to/pipeline_generate.py \
|
|
||||||
--prompt "hello world"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||||
@@ -18,110 +17,59 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from mlx.utils import tree_flatten
|
|
||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
from mlx_lm.utils import load_model, load_tokenizer
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||||
def download(repo: str, allow_patterns: list[str]) -> Path:
|
parser.add_argument(
|
||||||
return Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def shard_and_load(repo):
|
|
||||||
# Get model path with everything but weight safetensors
|
|
||||||
model_path = download(
|
|
||||||
args.model,
|
|
||||||
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Lazy load and shard model to figure out
|
|
||||||
# which weights we need
|
|
||||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
|
||||||
|
|
||||||
group = mx.distributed.init(backend="mpi")
|
|
||||||
rank = group.rank()
|
|
||||||
model.model.pipeline(group)
|
|
||||||
|
|
||||||
# Figure out which files we need for the local shard
|
|
||||||
with open(model_path / "model.safetensors.index.json", "r") as fid:
|
|
||||||
weight_index = json.load(fid)["weight_map"]
|
|
||||||
|
|
||||||
local_files = set()
|
|
||||||
for k, _ in tree_flatten(model.parameters()):
|
|
||||||
local_files.add(weight_index[k])
|
|
||||||
|
|
||||||
# Download weights for local shard
|
|
||||||
download(args.model, allow_patterns=local_files)
|
|
||||||
|
|
||||||
# Load and shard the model, and load the weights
|
|
||||||
tokenizer = load_tokenizer(model_path)
|
|
||||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
|
||||||
model.model.pipeline(group)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
|
|
||||||
# Synchronize processes before generation to avoid timeout if downloading
|
|
||||||
# model for the first time.
|
|
||||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
default="mlx-community/DeepSeek-R1-3bit",
|
|
||||||
help="HF repo or path to local model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
"--prompt",
|
||||||
"-p",
|
"-p",
|
||||||
default="Write a quicksort in C++.",
|
default="Write a quicksort in C++.",
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
"-m",
|
"-m",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
help="Maximum number of tokens to generate",
|
help="Maximum number of tokens to generate",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
group = mx.distributed.init(backend="mpi")
|
model_repo = "mlx-community/DeepSeek-V3-3bit"
|
||||||
rank = group.rank()
|
|
||||||
|
|
||||||
def rprint(*args, **kwargs):
|
model, tokenizer = load(model_repo, lazy=True)
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
|
group = mx.distributed.init()
|
||||||
|
rank = group.rank()
|
||||||
|
model.model.pipeline(group)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
# Synchronize processes before generation to avoid timeout if downloading
|
||||||
|
# model for the first time.
|
||||||
|
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||||
|
|
||||||
|
|
||||||
|
def rprint(*args, **kwargs):
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(*args, **kwargs)
|
print(*args, **kwargs)
|
||||||
|
|
||||||
model, tokenizer = shard_and_load(args.model)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": args.prompt}]
|
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
|
||||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
|
||||||
|
|
||||||
for response in stream_generate(
|
|
||||||
model, tokenizer, prompt, max_tokens=args.max_tokens
|
|
||||||
):
|
|
||||||
rprint(response.text, end="", flush=True)
|
rprint(response.text, end="", flush=True)
|
||||||
|
|
||||||
rprint()
|
rprint()
|
||||||
rprint("=" * 10)
|
rprint("=" * 10)
|
||||||
rprint(
|
rprint(
|
||||||
f"Prompt: {response.prompt_tokens} tokens, "
|
f"Prompt: {response.prompt_tokens} tokens, "
|
||||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
rprint(
|
rprint(
|
||||||
f"Generation: {response.generation_tokens} tokens, "
|
f"Generation: {response.generation_tokens} tokens, "
|
||||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||||
|
|||||||
@@ -191,7 +191,6 @@ def main():
|
|||||||
model_path,
|
model_path,
|
||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config=tokenizer_config,
|
tokenizer_config=tokenizer_config,
|
||||||
sequential_load=mx.distributed.init().size() > 1,
|
|
||||||
)
|
)
|
||||||
for eos_token in args.extra_eos_token:
|
for eos_token in args.extra_eos_token:
|
||||||
tokenizer.add_eos_token(eos_token)
|
tokenizer.add_eos_token(eos_token)
|
||||||
@@ -235,17 +234,13 @@ def main():
|
|||||||
else:
|
else:
|
||||||
draft_model = None
|
draft_model = None
|
||||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||||
|
|
||||||
world = mx.distributed.init()
|
|
||||||
print(f"Node {world.rank()} of {world.size()}", flush=True)
|
|
||||||
world.barrier()
|
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
|
verbose=args.verbose,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
verbose=args.verbose and world.rank() == 0,
|
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
@@ -254,10 +249,8 @@ def main():
|
|||||||
draft_model=draft_model,
|
draft_model=draft_model,
|
||||||
num_draft_tokens=args.num_draft_tokens,
|
num_draft_tokens=args.num_draft_tokens,
|
||||||
)
|
)
|
||||||
|
if not args.verbose:
|
||||||
if not args.verbose and mx.distributed.init().rank() == 0:
|
|
||||||
print(response)
|
print(response)
|
||||||
mx.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -78,7 +78,6 @@ def build_parser():
|
|||||||
"--train",
|
"--train",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Do training",
|
help="Do training",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data",
|
"--data",
|
||||||
@@ -136,7 +135,6 @@ def build_parser():
|
|||||||
"--test",
|
"--test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Evaluate on the test set after training",
|
help="Evaluate on the test set after training",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test-batches",
|
"--test-batches",
|
||||||
@@ -158,7 +156,6 @@ def build_parser():
|
|||||||
"--grad-checkpoint",
|
"--grad-checkpoint",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use gradient checkpointing to reduce memory use.",
|
help="Use gradient checkpointing to reduce memory use.",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -364,30 +364,8 @@ class DeepseekV2Model(nn.Module):
|
|||||||
DeepseekV2DecoderLayer(config, idx)
|
DeepseekV2DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
self.start_idx = 0
|
|
||||||
self.end_idx = len(self.layers)
|
|
||||||
self.num_layers = self.end_idx
|
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.pipeline_rank = 0
|
|
||||||
self.pipeline_size = 1
|
|
||||||
|
|
||||||
def pipeline(self, group):
|
|
||||||
# Split layers in reverse so rank=0 gets the last layers and
|
|
||||||
# rank=pipeline_size-1 gets the first
|
|
||||||
self.pipeline_rank = group.rank()
|
|
||||||
self.pipeline_size = group.size()
|
|
||||||
layers_per_rank = (
|
|
||||||
len(self.layers) + self.pipeline_size - 1
|
|
||||||
) // self.pipeline_size
|
|
||||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
|
||||||
self.end_idx = self.start_idx + layers_per_rank
|
|
||||||
self.num_layers = layers_per_rank
|
|
||||||
self.layers = self.layers[: self.end_idx]
|
|
||||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
|
||||||
self.num_layers = len(self.layers) - self.start_idx
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
@@ -396,31 +374,14 @@ class DeepseekV2Model(nn.Module):
|
|||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
|
|
||||||
pipeline_rank = self.pipeline_rank
|
|
||||||
pipeline_size = self.pipeline_size
|
|
||||||
# Hack to avoid time-outs during prompt-processing
|
|
||||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * self.num_layers
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
for layer, c in zip(self.layers, cache):
|
||||||
if pipeline_rank < pipeline_size - 1:
|
h = layer(h, mask, c)
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
|
||||||
|
|
||||||
for i in range(self.num_layers):
|
|
||||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
|
||||||
if pipeline_rank != 0:
|
|
||||||
h = mx.distributed.send(
|
|
||||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast h while keeping it in the graph
|
|
||||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
|
||||||
|
|
||||||
return self.norm(h)
|
return self.norm(h)
|
||||||
|
|
||||||
@@ -457,4 +418,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
return self.model.layers
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -126,12 +125,6 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# A clipped silu to prevent fp16 from overflowing
|
|
||||||
@partial(mx.compile, shapeless=True)
|
|
||||||
def clipped_silu(x):
|
|
||||||
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Attention(nn.Module):
|
class DeepseekV3Attention(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -319,10 +312,7 @@ class DeepseekV3MoE(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.num_experts_per_tok = config.num_experts_per_tok
|
self.num_experts_per_tok = config.num_experts_per_tok
|
||||||
self.switch_mlp = SwitchGLU(
|
self.switch_mlp = SwitchGLU(
|
||||||
config.hidden_size,
|
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||||
config.moe_intermediate_size,
|
|
||||||
config.n_routed_experts,
|
|
||||||
activation=clipped_silu,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = MoEGate(config)
|
self.gate = MoEGate(config)
|
||||||
@@ -369,7 +359,11 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
return h + r
|
out = h + r
|
||||||
|
# Protect against overflow for fp16
|
||||||
|
if out.dtype == mx.float16:
|
||||||
|
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Model(nn.Module):
|
class DeepseekV3Model(nn.Module):
|
||||||
@@ -381,10 +375,6 @@ class DeepseekV3Model(nn.Module):
|
|||||||
DeepseekV3DecoderLayer(config, idx)
|
DeepseekV3DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
self.start_idx = 0
|
|
||||||
self.end_idx = len(self.layers)
|
|
||||||
self.num_layers = self.end_idx
|
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.pipeline_rank = 0
|
self.pipeline_rank = 0
|
||||||
self.pipeline_size = 1
|
self.pipeline_size = 1
|
||||||
@@ -397,11 +387,8 @@ class DeepseekV3Model(nn.Module):
|
|||||||
layers_per_rank = (
|
layers_per_rank = (
|
||||||
len(self.layers) + self.pipeline_size - 1
|
len(self.layers) + self.pipeline_size - 1
|
||||||
) // self.pipeline_size
|
) // self.pipeline_size
|
||||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
self.end_idx = self.start_idx + layers_per_rank
|
self.layers = self.layers[start : start + layers_per_rank]
|
||||||
self.layers = self.layers[: self.end_idx]
|
|
||||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
|
||||||
self.num_layers = len(self.layers) - self.start_idx
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -413,30 +400,25 @@ class DeepseekV3Model(nn.Module):
|
|||||||
|
|
||||||
pipeline_rank = self.pipeline_rank
|
pipeline_rank = self.pipeline_rank
|
||||||
pipeline_size = self.pipeline_size
|
pipeline_size = self.pipeline_size
|
||||||
# Hack to avoid time-outs during prompt-processing
|
|
||||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * self.num_layers
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
# Receive from the previous process in the pipeline
|
||||||
|
|
||||||
if pipeline_rank < pipeline_size - 1:
|
if pipeline_rank < pipeline_size - 1:
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
|
||||||
|
|
||||||
for i in range(self.num_layers):
|
for layer, c in zip(self.layers, cache):
|
||||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
# Send to the next process in the pipeline
|
||||||
if pipeline_rank != 0:
|
if pipeline_rank != 0:
|
||||||
h = mx.distributed.send(
|
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
|
||||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast h while keeping it in the graph
|
# Broadcast h while keeping it in the graph
|
||||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
h = mx.distributed.all_gather(h)[: h.shape[0]]
|
||||||
|
|
||||||
return self.norm(h)
|
return self.norm(h)
|
||||||
|
|
||||||
@@ -475,4 +457,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
return self.model.layers
|
||||||
|
|||||||
@@ -1,185 +0,0 @@
|
|||||||
# Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelArgs(BaseModelArgs):
|
|
||||||
hidden_size: int
|
|
||||||
num_hidden_layers: int
|
|
||||||
intermediate_size: int
|
|
||||||
num_attention_heads: int
|
|
||||||
num_key_value_heads: int
|
|
||||||
rms_norm_eps: float
|
|
||||||
vocab_size: int
|
|
||||||
attention_bias: bool
|
|
||||||
head_dim: int
|
|
||||||
max_position_embeddings: int
|
|
||||||
mlp_bias: bool
|
|
||||||
model_type: str
|
|
||||||
rope_theta: float
|
|
||||||
tie_word_embeddings: bool
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumAttention(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
dim = args.hidden_size
|
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
|
||||||
assert args.num_key_value_heads is not None
|
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
|
||||||
|
|
||||||
head_dim = args.hidden_size // n_heads
|
|
||||||
self.scale = head_dim**-0.5
|
|
||||||
|
|
||||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
|
||||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
|
||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
|
||||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: mx.array,
|
|
||||||
mask: Optional[mx.array] = None,
|
|
||||||
cache: Optional[Any] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
B, L, D = x.shape
|
|
||||||
|
|
||||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
|
||||||
|
|
||||||
# Prepare the queries, keys and values for the attention computation
|
|
||||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
queries = self.rope(queries, offset=cache.offset)
|
|
||||||
keys = self.rope(keys, offset=cache.offset)
|
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
|
||||||
else:
|
|
||||||
queries = self.rope(queries)
|
|
||||||
keys = self.rope(keys)
|
|
||||||
|
|
||||||
output = scaled_dot_product_attention(
|
|
||||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
|
||||||
)
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
return self.o_proj(output)
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumMLP(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = args.hidden_size
|
|
||||||
self.intermediate_size = args.intermediate_size
|
|
||||||
|
|
||||||
self.gate_proj = nn.Linear(
|
|
||||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
|
||||||
)
|
|
||||||
self.up_proj = nn.Linear(
|
|
||||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
|
||||||
)
|
|
||||||
self.down_proj = nn.Linear(
|
|
||||||
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
|
||||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumDecoderLayer(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = args.hidden_size
|
|
||||||
|
|
||||||
self.self_attn = HeliumAttention(args)
|
|
||||||
self.mlp = HeliumMLP(args)
|
|
||||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = nn.RMSNorm(
|
|
||||||
args.hidden_size, eps=args.rms_norm_eps
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: mx.array,
|
|
||||||
mask: Optional[mx.array] = None,
|
|
||||||
cache: Optional[Any] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
|
||||||
h = x + r
|
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
|
||||||
out = h + r
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class HeliumModel(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.num_hidden_layers = args.num_hidden_layers
|
|
||||||
self.vocab_size = args.vocab_size
|
|
||||||
|
|
||||||
assert self.vocab_size > 0
|
|
||||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
||||||
|
|
||||||
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: mx.array,
|
|
||||||
mask: mx.array = None,
|
|
||||||
cache=None,
|
|
||||||
) -> mx.array:
|
|
||||||
h = self.embed_tokens(inputs)
|
|
||||||
|
|
||||||
if mask is None:
|
|
||||||
mask = create_attention_mask(h, cache)
|
|
||||||
|
|
||||||
if cache is None:
|
|
||||||
cache = [None] * len(self.layers)
|
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
|
||||||
h = layer(h, mask, c)
|
|
||||||
|
|
||||||
return self.norm(h)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.args = args
|
|
||||||
self.model_type = args.model_type
|
|
||||||
|
|
||||||
self.model = HeliumModel(args)
|
|
||||||
|
|
||||||
self.vocab_size = args.vocab_size
|
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
|
||||||
|
|
||||||
if not args.tie_word_embeddings:
|
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: mx.array,
|
|
||||||
mask: mx.array = None,
|
|
||||||
cache=None,
|
|
||||||
) -> mx.array:
|
|
||||||
out = self.model(inputs, mask, cache)
|
|
||||||
if self.args.tie_word_embeddings:
|
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
|
||||||
else:
|
|
||||||
out = self.lm_head(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
@property
|
|
||||||
def layers(self):
|
|
||||||
return self.model.layers
|
|
||||||
@@ -1,241 +0,0 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelArgs(BaseModelArgs):
|
|
||||||
model_type: str
|
|
||||||
hidden_size: int
|
|
||||||
num_hidden_layers: int
|
|
||||||
intermediate_size: int
|
|
||||||
num_attention_heads: int
|
|
||||||
rms_norm_eps: float
|
|
||||||
vocab_size: int
|
|
||||||
bias: bool = False
|
|
||||||
qkv_bias: bool = False
|
|
||||||
max_position_embeddings: int = 32768
|
|
||||||
num_key_value_heads: int = None
|
|
||||||
rope_theta: float = 10000
|
|
||||||
rope_traditional: bool = False
|
|
||||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
|
||||||
tie_word_embeddings: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.num_key_value_heads is None:
|
|
||||||
self.num_key_value_heads = self.num_attention_heads
|
|
||||||
|
|
||||||
if self.rope_scaling:
|
|
||||||
required_keys = {"factor", "rope_type"}
|
|
||||||
if not all(key in self.rope_scaling for key in required_keys):
|
|
||||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
|
||||||
|
|
||||||
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
|
|
||||||
raise ValueError(
|
|
||||||
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicNTKScalingRoPE(nn.Module):
|
|
||||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dims: int,
|
|
||||||
max_position_embeddings: int = 2048,
|
|
||||||
traditional: bool = False,
|
|
||||||
base: float = 10000,
|
|
||||||
scale: float = 1.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.original_base = base
|
|
||||||
self.dims = dims
|
|
||||||
self.traditional = traditional
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
|
||||||
seq_len = x.shape[1] + offset
|
|
||||||
if seq_len > self.max_position_embeddings:
|
|
||||||
base = self.original_base * (
|
|
||||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
|
||||||
) ** (self.dims / (self.dims - 2))
|
|
||||||
else:
|
|
||||||
base = self.original_base
|
|
||||||
|
|
||||||
return mx.fast.rope(
|
|
||||||
x,
|
|
||||||
self.dims,
|
|
||||||
traditional=self.traditional,
|
|
||||||
base=base,
|
|
||||||
scale=self.scale,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
dim = args.hidden_size
|
|
||||||
qkv_bias = args.qkv_bias
|
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
|
||||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
|
||||||
|
|
||||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
|
||||||
self.scale = head_dim**-0.5
|
|
||||||
|
|
||||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
|
|
||||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
|
||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
|
|
||||||
|
|
||||||
rope_scale = (
|
|
||||||
1 / args.rope_scaling["factor"]
|
|
||||||
if args.rope_scaling is not None
|
|
||||||
and args.rope_scaling["rope_type"] == "linear"
|
|
||||||
else 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
self.rope = DynamicNTKScalingRoPE(
|
|
||||||
head_dim,
|
|
||||||
max_position_embeddings=args.max_position_embeddings,
|
|
||||||
traditional=args.rope_traditional,
|
|
||||||
base=args.rope_theta,
|
|
||||||
scale=rope_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: mx.array,
|
|
||||||
mask: Optional[mx.array] = None,
|
|
||||||
cache: Optional[Any] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
B, L, D = x.shape
|
|
||||||
|
|
||||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
|
||||||
|
|
||||||
# Prepare the queries, keys and values for the attention computation
|
|
||||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
queries = self.rope(queries, offset=cache.offset)
|
|
||||||
keys = self.rope(keys, offset=cache.offset)
|
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
|
||||||
else:
|
|
||||||
queries = self.rope(queries)
|
|
||||||
keys = self.rope(keys)
|
|
||||||
|
|
||||||
output = scaled_dot_product_attention(
|
|
||||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
|
||||||
)
|
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
return self.o_proj(output)
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
|
||||||
def __init__(self, dim, hidden_dim, bias):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
|
||||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
|
|
||||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
|
||||||
|
|
||||||
def __call__(self, x) -> mx.array:
|
|
||||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = Attention(args)
|
|
||||||
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
|
|
||||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = nn.RMSNorm(
|
|
||||||
args.hidden_size, eps=args.rms_norm_eps
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
x: mx.array,
|
|
||||||
mask: Optional[mx.array] = None,
|
|
||||||
cache: Optional[Any] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
|
||||||
h = x + r
|
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
|
||||||
out = h + r
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class InternLM2Model(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
assert args.vocab_size > 0
|
|
||||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
||||||
self.layers = [
|
|
||||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
|
||||||
]
|
|
||||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: mx.array,
|
|
||||||
mask: mx.array = None,
|
|
||||||
cache=None,
|
|
||||||
):
|
|
||||||
h = self.embed_tokens(inputs)
|
|
||||||
|
|
||||||
if mask is None:
|
|
||||||
mask = create_attention_mask(h, cache)
|
|
||||||
|
|
||||||
if cache is None:
|
|
||||||
cache = [None] * len(self.layers)
|
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
|
||||||
h = layer(h, mask, cache=c)
|
|
||||||
|
|
||||||
return self.norm(h)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
|
||||||
def __init__(self, args: ModelArgs):
|
|
||||||
super().__init__()
|
|
||||||
self.args = args
|
|
||||||
self.model_type = args.model_type
|
|
||||||
self.model = InternLM2Model(args)
|
|
||||||
if not args.tie_word_embeddings:
|
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
inputs: mx.array,
|
|
||||||
mask: mx.array = None,
|
|
||||||
cache=None,
|
|
||||||
):
|
|
||||||
out = self.model(inputs, mask, cache)
|
|
||||||
if self.args.tie_word_embeddings:
|
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
|
||||||
else:
|
|
||||||
out = self.lm_head(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
# Remove unused precomputed rotary freqs
|
|
||||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def layers(self):
|
|
||||||
return self.model.layers
|
|
||||||
@@ -200,36 +200,6 @@ class Model(nn.Module):
|
|||||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||||
}
|
}
|
||||||
|
|
||||||
def shard(self, group: Optional[mx.distributed.Group] = None):
|
|
||||||
group = group or mx.distributed.init()
|
|
||||||
|
|
||||||
def all_to_sharded(l):
|
|
||||||
if isinstance(l, nn.QuantizedLinear):
|
|
||||||
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
|
|
||||||
else:
|
|
||||||
return nn.AllToShardedLinear.from_linear(l, group)
|
|
||||||
|
|
||||||
def sharded_to_all(l):
|
|
||||||
if isinstance(l, nn.QuantizedLinear):
|
|
||||||
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
|
|
||||||
else:
|
|
||||||
return nn.ShardedToAllLinear.from_linear(l, group)
|
|
||||||
|
|
||||||
N = group.size()
|
|
||||||
for layer in self.model.layers:
|
|
||||||
# Shard the self attention
|
|
||||||
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
|
|
||||||
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
|
|
||||||
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
|
|
||||||
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
|
|
||||||
layer.self_attn.n_heads //= N
|
|
||||||
layer.self_attn.n_kv_heads //= N
|
|
||||||
|
|
||||||
# Shard the MLP
|
|
||||||
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
|
|
||||||
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
|
|
||||||
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2024-2025 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -123,16 +123,17 @@ class MambaBlock(nn.Module):
|
|||||||
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
def ssm_step(self, x, A, state=None):
|
def ssm_step(self, x, state=None):
|
||||||
|
A = -mx.exp(self.A_log)
|
||||||
D = self.D
|
D = self.D
|
||||||
deltaBC = self.x_proj(x)
|
deltaBC = self.x_proj(x)
|
||||||
delta, B, C = map(
|
delta, B, C = mx.split(
|
||||||
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
|
||||||
mx.split(
|
|
||||||
deltaBC,
|
deltaBC,
|
||||||
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
|
indices_or_sections=[
|
||||||
|
self.time_step_rank,
|
||||||
|
self.time_step_rank + self.ssm_state_size,
|
||||||
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
@@ -144,40 +145,25 @@ class MambaBlock(nn.Module):
|
|||||||
y = y + D * x
|
y = y + D * x
|
||||||
return y, new_state
|
return y, new_state
|
||||||
|
|
||||||
def _process_sequence(self, x, conv_cache, state_cache):
|
def __call__(self, x, cache):
|
||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
xz = self.in_proj(x)
|
if cache is None:
|
||||||
x, z = xz.split(indices_or_sections=2, axis=-1)
|
cache = [None, None]
|
||||||
|
|
||||||
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
|
|
||||||
x = nn.silu(conv_out)
|
|
||||||
|
|
||||||
A = -mx.exp(self.A_log)
|
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
current_state = state_cache
|
|
||||||
y = []
|
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
|
xt = x[:, t, :]
|
||||||
y.append(y_t)
|
xz = self.in_proj(xt)
|
||||||
y = mx.stack(y, axis=1)
|
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
|
||||||
z = self.out_proj(nn.silu(z) * y)
|
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
||||||
return z, (new_conv_cache, current_state)
|
x_t = conv_out.squeeze(1)
|
||||||
|
x_t = nn.silu(x_t)
|
||||||
def __call__(self, x, cache):
|
y_t, cache[1] = self.ssm_step(x_t, cache[1])
|
||||||
if cache is None:
|
z_t = nn.silu(z_t)
|
||||||
conv_cache, state_cache = None, None
|
output_t = y_t * z_t
|
||||||
else:
|
output_t = self.out_proj(output_t)
|
||||||
conv_cache, state_cache = cache[0], cache[1]
|
outputs.append(output_t)
|
||||||
|
output = mx.stack(outputs, axis=1)
|
||||||
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
|
||||||
x, conv_cache, state_cache
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(cache, MambaCache):
|
|
||||||
cache[0] = new_conv_cache
|
|
||||||
cache[1] = new_state_cache
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023-2025 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|||||||
@@ -147,11 +147,11 @@ def min_p_sampling(
|
|||||||
logprobs = logprobs * (1 / temperature)
|
logprobs = logprobs * (1 / temperature)
|
||||||
|
|
||||||
# Indices sorted in decreasing order
|
# Indices sorted in decreasing order
|
||||||
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
||||||
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
sorted_logprobs = logprobs[..., sorted_indices]
|
||||||
|
|
||||||
# Top probability
|
# Top probability
|
||||||
top_logprobs = sorted_logprobs[:, 0:1]
|
top_logprobs = logprobs[..., sorted_indices[0]]
|
||||||
|
|
||||||
# Calculate the min_p threshold
|
# Calculate the min_p threshold
|
||||||
scaled_min_p = top_logprobs + math.log(min_p)
|
scaled_min_p = top_logprobs + math.log(min_p)
|
||||||
@@ -163,9 +163,9 @@ def min_p_sampling(
|
|||||||
# Create pool of tokens with probability less than scaled min_p
|
# Create pool of tokens with probability less than scaled min_p
|
||||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||||
|
|
||||||
# Return sampled tokens
|
# Return sampled token
|
||||||
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
sorted_token = mx.random.categorical(selected_logprobs)
|
||||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
return sorted_indices[sorted_token]
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
|
|
||||||
# sort probs in ascending order
|
# sort probs in ascending order
|
||||||
sorted_indices = mx.argsort(probs, axis=-1)
|
sorted_indices = mx.argsort(probs, axis=-1)
|
||||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||||
|
|
||||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||||
|
|
||||||
@@ -196,8 +196,10 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
token = sorted_indices.squeeze(0)[sorted_token]
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
|||||||
@@ -114,33 +114,6 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
|||||||
return prompt.rstrip()
|
return prompt.rstrip()
|
||||||
|
|
||||||
|
|
||||||
def process_message_content(messages):
|
|
||||||
"""
|
|
||||||
Convert message content to a format suitable for `apply_chat_template`.
|
|
||||||
|
|
||||||
The function operates on messages in place. It converts the 'content' field
|
|
||||||
to a string instead of a list of text fragments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_list (list): A list of dictionaries, where each dictionary may
|
|
||||||
have a 'content' key containing a list of dictionaries with 'type' and
|
|
||||||
'text' keys.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
|
||||||
|
|
||||||
"""
|
|
||||||
for message in messages:
|
|
||||||
content = message["content"]
|
|
||||||
if isinstance(content, list):
|
|
||||||
text_fragments = [
|
|
||||||
fragment["text"] for fragment in content if fragment["type"] == "text"
|
|
||||||
]
|
|
||||||
if len(text_fragments) != len(content):
|
|
||||||
raise ValueError("Only 'text' content type is supported.")
|
|
||||||
message["content"] = "".join(text_fragments)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptCache:
|
class PromptCache:
|
||||||
cache: List[Any] = field(default_factory=list)
|
cache: List[Any] = field(default_factory=list)
|
||||||
@@ -618,10 +591,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||||
if self.tokenizer.chat_template:
|
if self.tokenizer.chat_template:
|
||||||
messages = body["messages"]
|
|
||||||
process_message_content(messages)
|
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
prompt = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
body["messages"],
|
||||||
body.get("tools", None),
|
body.get("tools", None),
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
|||||||
if prompt_feature and completion_feature:
|
if prompt_feature and completion_feature:
|
||||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||||
elif text_feature:
|
elif text_feature:
|
||||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
return Dataset(train_ds, tokenizer, text_key=text_feature)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Specify either a prompt and completion feature or a text "
|
"Specify either a prompt and completion feature or a text "
|
||||||
|
|||||||
@@ -140,8 +140,8 @@ def evaluate(
|
|||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = mx.array(0.0)
|
all_losses = 0
|
||||||
ntokens = mx.array(0)
|
ntokens = 0
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
|
|
||||||
@@ -159,8 +159,8 @@ def evaluate(
|
|||||||
ntokens += toks
|
ntokens += toks
|
||||||
mx.eval(all_losses, ntokens)
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
all_losses = mx.distributed.all_sum(all_losses)
|
||||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
ntokens = mx.distributed.all_sum(ntokens)
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
return (all_losses / ntokens).item()
|
||||||
|
|
||||||
@@ -272,9 +272,9 @@ def train(
|
|||||||
if it % args.steps_per_report == 0 or it == args.iters:
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
|
|
||||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
train_loss = mx.distributed.all_sum(losses).item()
|
||||||
train_loss /= steps * mx.distributed.init().size()
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||||
learning_rate = optimizer.learning_rate.item()
|
learning_rate = optimizer.learning_rate.item()
|
||||||
it_sec = args.steps_per_report / (stop - start)
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / (stop - start)
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
|
|||||||
@@ -94,14 +94,12 @@ def linear_to_lora_layers(
|
|||||||
"phimoe",
|
"phimoe",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
"helium",
|
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"cohere",
|
"cohere",
|
||||||
"cohere2",
|
"cohere2",
|
||||||
"minicpm",
|
"minicpm",
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"olmo2",
|
"olmo2",
|
||||||
"internlm3",
|
|
||||||
]:
|
]:
|
||||||
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
||||||
if model.model_type in ["mixtral", "phimoe"]:
|
if model.model_type in ["mixtral", "phimoe"]:
|
||||||
|
|||||||
@@ -306,12 +306,12 @@ def generate_step(
|
|||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.eval(y, logprobs)
|
mx.async_eval(y, logprobs)
|
||||||
n = 0
|
n = 0
|
||||||
while True:
|
while True:
|
||||||
if n != max_tokens:
|
if n != max_tokens:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.eval(next_y, next_logprobs)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
if n == 0:
|
if n == 0:
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||||
@@ -398,9 +398,8 @@ def speculative_generate_step(
|
|||||||
quantize_cache_fn(cache)
|
quantize_cache_fn(cache)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
logprobs = logprobs.squeeze(0)
|
y = sampler(logprobs).squeeze(0)
|
||||||
y = sampler(logprobs)
|
return y, logprobs.squeeze(0)
|
||||||
return y, logprobs
|
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
@@ -627,8 +626,6 @@ def load_config(model_path: Path) -> dict:
|
|||||||
def load_model(
|
def load_model(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
strict: bool = True,
|
|
||||||
sequential_load: bool = False,
|
|
||||||
model_config: dict = {},
|
model_config: dict = {},
|
||||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
@@ -640,8 +637,6 @@ def load_model(
|
|||||||
lazy (bool): If False eval the model parameters to make sure they are
|
lazy (bool): If False eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
strict (bool): Whether or not to raise an exception if weights don't
|
|
||||||
match. Default: ``True``
|
|
||||||
model_config (dict, optional): Optional configuration parameters for the
|
model_config (dict, optional): Optional configuration parameters for the
|
||||||
model. Defaults to an empty dictionary.
|
model. Defaults to an empty dictionary.
|
||||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||||
@@ -664,7 +659,7 @@ def load_model(
|
|||||||
# Try weight for back-compat
|
# Try weight for back-compat
|
||||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||||
|
|
||||||
if not weight_files and strict:
|
if not weight_files:
|
||||||
logging.error(f"No safetensors found in {model_path}")
|
logging.error(f"No safetensors found in {model_path}")
|
||||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||||
|
|
||||||
@@ -698,18 +693,9 @@ def load_model(
|
|||||||
class_predicate=class_predicate,
|
class_predicate=class_predicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(list(weights.items()), strict=strict)
|
model.load_weights(list(weights.items()))
|
||||||
|
|
||||||
if mx.distributed.init().size() > 1:
|
|
||||||
if not hasattr(model, "shard"):
|
|
||||||
raise RuntimeError("Model doesn't support distributed inference.")
|
|
||||||
model.shard()
|
|
||||||
|
|
||||||
if not lazy:
|
if not lazy:
|
||||||
weights.clear()
|
|
||||||
if sequential_load:
|
|
||||||
for layer in model.layers:
|
|
||||||
mx.eval(layer.parameters())
|
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -722,7 +708,6 @@ def load(
|
|||||||
model_config={},
|
model_config={},
|
||||||
adapter_path: Optional[str] = None,
|
adapter_path: Optional[str] = None,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
sequential_load: bool = False,
|
|
||||||
) -> Tuple[nn.Module, TokenizerWrapper]:
|
) -> Tuple[nn.Module, TokenizerWrapper]:
|
||||||
"""
|
"""
|
||||||
Load the model and tokenizer from a given path or a huggingface repository.
|
Load the model and tokenizer from a given path or a huggingface repository.
|
||||||
@@ -738,8 +723,6 @@ def load(
|
|||||||
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
sequential_load (bool): If True then load each layer sequentially to
|
|
||||||
ensure that we are not wasting memory.
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
||||||
|
|
||||||
@@ -749,7 +732,7 @@ def load(
|
|||||||
"""
|
"""
|
||||||
model_path = get_model_path(path_or_hf_repo)
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
|
model, config = load_model(model_path, lazy)
|
||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = load_adapters(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -763,7 +746,7 @@ def load(
|
|||||||
def fetch_from_hub(
|
def fetch_from_hub(
|
||||||
model_path: Path, lazy: bool = False
|
model_path: Path, lazy: bool = False
|
||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
model, config = load_model(model_path, lazy=lazy)
|
model, config = load_model(model_path, lazy)
|
||||||
tokenizer = load_tokenizer(
|
tokenizer = load_tokenizer(
|
||||||
model_path, eos_token_ids=config.get("eos_token_id", None)
|
model_path, eos_token_ids=config.get("eos_token_id", None)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def swapped_with_identity(obj, func):
|
def swapped_with_identity(obj, func):
|
||||||
old_func = getattr(obj, func)
|
old_func = getattr(obj, func)
|
||||||
setattr(obj, func, lambda x, **kwargs: x)
|
setattr(obj, func, lambda x: x)
|
||||||
yield
|
yield
|
||||||
setattr(obj, func, old_func)
|
setattr(obj, func, old_func)
|
||||||
|
|
||||||
|
|||||||
@@ -927,23 +927,6 @@ class TestModels(unittest.TestCase):
|
|||||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_internlm3(self):
|
|
||||||
from mlx_lm.models import internlm3
|
|
||||||
|
|
||||||
args = internlm3.ModelArgs(
|
|
||||||
model_type="internlm3",
|
|
||||||
hidden_size=1024,
|
|
||||||
num_hidden_layers=4,
|
|
||||||
intermediate_size=2048,
|
|
||||||
num_attention_heads=4,
|
|
||||||
rms_norm_eps=1e-5,
|
|
||||||
vocab_size=10_000,
|
|
||||||
)
|
|
||||||
model = internlm3.Model(args)
|
|
||||||
self.model_test_runner(
|
|
||||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -28,12 +28,6 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = top_p_sampling(logits, 0.95, temperature).item()
|
token = top_p_sampling(logits, 0.95, temperature).item()
|
||||||
self.assertTrue(token in (1, 2, 3))
|
self.assertTrue(token in (1, 2, 3))
|
||||||
|
|
||||||
# Batch mode works
|
|
||||||
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
|
||||||
logits = mx.log(probs)
|
|
||||||
tokens = top_p_sampling(logits, 0.5, temperature)
|
|
||||||
self.assertEqual(tokens.tolist(), [0, 1])
|
|
||||||
|
|
||||||
def test_min_p_sampling(self):
|
def test_min_p_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
@@ -48,12 +42,6 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = min_p_sampling(logits, 0.05)
|
token = min_p_sampling(logits, 0.05)
|
||||||
self.assertTrue(token in (0, 3))
|
self.assertTrue(token in (0, 3))
|
||||||
|
|
||||||
# Batch mode works
|
|
||||||
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
|
||||||
logits = mx.log(probs)
|
|
||||||
tokens = min_p_sampling(logits, 0.7)
|
|
||||||
self.assertEqual(tokens.tolist(), [0, 1])
|
|
||||||
|
|
||||||
def test_top_k_sampling(self):
|
def test_top_k_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
|
|||||||
@@ -80,29 +80,6 @@ class TestServer(unittest.TestCase):
|
|||||||
self.assertIn("id", response_body)
|
self.assertIn("id", response_body)
|
||||||
self.assertIn("choices", response_body)
|
self.assertIn("choices", response_body)
|
||||||
|
|
||||||
def test_handle_chat_completions_with_content_fragments(self):
|
|
||||||
url = f"http://localhost:{self.port}/v1/chat/completions"
|
|
||||||
chat_post_data = {
|
|
||||||
"model": "chat_model",
|
|
||||||
"max_tokens": 10,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"top_p": 0.85,
|
|
||||||
"repetition_penalty": 1.2,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "You are a helpful assistant."}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
response = requests.post(url, json=chat_post_data)
|
|
||||||
response_body = response.text
|
|
||||||
self.assertIn("id", response_body)
|
|
||||||
self.assertIn("choices", response_body)
|
|
||||||
|
|
||||||
def test_handle_models(self):
|
def test_handle_models(self):
|
||||||
url = f"http://localhost:{self.port}/v1/models"
|
url = f"http://localhost:{self.port}/v1/models"
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
|
|||||||
Reference in New Issue
Block a user