Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)

Commit c26e188417: Merge branch 'ml-explore:main' into adding-support-for-mamba2
@@ -1,10 +1,10 @@
 repos:
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 25.1.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 6.0.0
    hooks:
      - id: isort
        args:
@@ -45,7 +45,7 @@ Some more useful examples are listed below.
 
 ### Hugging Face
 
-Note: You can now directly download a few converted checkpoints from the [MLX
+You can directly use or download converted checkpoints from the [MLX
 Community](https://huggingface.co/mlx-community) organization on Hugging Face.
 We encourage you to join the community and [contribute new
 models](https://github.com/ml-explore/mlx-examples/issues/155).
@@ -164,7 +164,7 @@ mlx_lm.convert \
 ```
 
 Models can also be converted and quantized directly in the
-[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
+[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
 Face Space.
 
 ### Long Prompts and Generations
@@ -101,6 +101,14 @@ You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
 
+#### Prompt Masking
+
+The default training computes a loss for every token in the sample. You can
+ignore the prompt and compute loss for just the completion by passing
+`--mask-prompt`. Note this is only supported for `chat` and `completion`
+datasets. For `chat` datasets the final message in the message list is
+considered the completion. See the [dataset section](#Data) for more details.
+
 ### Evaluate
 
 To compute test set perplexity use:
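To make the masking concrete, here is a minimal illustrative sketch (toy numbers, not code from this commit) of how a prompt offset and a total length become a per-token loss mask, mirroring the `default_loss` change further down in this diff:

```python
import mlx.core as mx

# Toy values: a 10-token sample whose prompt occupies the first 3 tokens.
prompt_offset, total_length = 3, 10
lengths = mx.array([[prompt_offset, total_length]])

# Mirrors the trainer change later in this diff: positions before the prompt
# offset (and any padding past the true length) contribute no loss.
steps = mx.arange(1, total_length + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)  # the leading prompt positions are False, the completion is True
```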
@@ -315,11 +323,27 @@ hf_dataset:
 
 - Use `prompt_feature` and `completion_feature` to specify keys for a
   `completions` dataset. Use `text_feature` to specify the key for a `text`
-  dataset.
+  dataset. Use `chat_feature` to specify the key for a chat dataset.
 
 - To specify the train, valid, or test splits, set the corresponding
   `{train,valid,test}_split` argument.
 
+You can specify a list of Hugging Face datasets with a list of records each
+with the same structure as above. For example:
+
+```yaml
+hf_dataset:
+  - name: "Open-Orca/OpenOrca"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    prompt_feature: "question"
+    completion_feature: "response"
+  - name: "trl-lib/ultrafeedback_binarized"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    chat_feature: "chosen"
+```
+
 - Arguments specified in `config` will be passed as keyword arguments to
   [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
 
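When more than one dataset is listed, each entry is loaded separately and the matching splits are joined together (see the `ConcatenatedDataset` added later in this diff). A toy illustration of the concatenation step, with made-up samples:

```python
import itertools

# Made-up tokenized samples standing in for two datasets' train splits.
train_a = [[1, 2, 3], [4, 5]]
train_b = [[6, 7], [8, 9, 10]]

# ConcatenatedDataset chains the per-dataset splits into one flat dataset.
combined_train = list(itertools.chain(train_a, train_b))
print(len(combined_train))  # 4 samples, drawn from both datasets
```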
@@ -152,7 +152,7 @@ def main():
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
-    metadata["chat_template"] = tokenizer.chat_template
+    metadata["chat_template"] = json.dumps(tokenizer.chat_template)
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
 
@@ -295,7 +295,9 @@ class MLXLM(LM):
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            context = self._tokenize(context)
+            context = self.tokenizer.encode(
+                context, add_special_tokens=not self.use_chat_template
+            )
             max_tokens = min(
                 self._max_tokens,
                 self.tokenizer.model_max_length - len(context),
@@ -23,7 +23,6 @@ response = generate(
     tokenizer,
     prompt=prompt,
     verbose=True,
-    temp=0.0,
     prompt_cache=prompt_cache,
 )
 
@@ -4,10 +4,11 @@
 Run with:
 
 ```
-/path/to/mpirun \
- -np 2 \
+mlx.launch \
  --hostfile /path/to/hosts.txt \
- python /path/to/pipeline_generate.py --prompt "hello world"
+ --backend mpi \
+ /path/to/pipeline_generate.py \
+ --prompt "hello world"
 ```
 
 Make sure you can run MLX over MPI on two hosts. For more information see the
@@ -17,62 +18,110 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
 """
 
 import argparse
+import json
+from pathlib import Path
 
 import mlx.core as mx
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
 from mlx_lm import load, stream_generate
+from mlx_lm.utils import load_model, load_tokenizer
 
-parser = argparse.ArgumentParser(description="LLM pipelined inference example")
-parser.add_argument(
-    "--model",
-    default="mlx-community/DeepSeek-R1-3bit",
-    help="HF repo or path to local model.",
-)
-parser.add_argument(
-    "--prompt",
-    "-p",
-    default="Write a quicksort in C++.",
-    help="Message to be processed by the model ('-' reads from stdin)",
-)
-parser.add_argument(
-    "--max-tokens",
-    "-m",
-    type=int,
-    default=256,
-    help="Maximum number of tokens to generate",
-)
-args = parser.parse_args()
-
-model, tokenizer = load(args.model, lazy=True)
-
-messages = [{"role": "user", "content": args.prompt}]
-prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-group = mx.distributed.init()
-rank = group.rank()
-model.model.pipeline(group)
-mx.eval(model.parameters())
-
-# Synchronize processes before generation to avoid timeout if downloading
-# model for the first time.
-mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
-
-
-def rprint(*args, **kwargs):
-    if rank == 0:
-        print(*args, **kwargs)
-
-
-for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
-    rprint(response.text, end="", flush=True)
-
-rprint()
-rprint("=" * 10)
-rprint(
-    f"Prompt: {response.prompt_tokens} tokens, "
-    f"{response.prompt_tps:.3f} tokens-per-sec"
-)
-rprint(
-    f"Generation: {response.generation_tokens} tokens, "
-    f"{response.generation_tps:.3f} tokens-per-sec"
-)
-rprint(f"Peak memory: {response.peak_memory:.3f} GB")
+
+def download(repo: str, allow_patterns: list[str]) -> Path:
+    return Path(
+        snapshot_download(
+            repo,
+            allow_patterns=allow_patterns,
+        )
+    )
+
+
+def shard_and_load(repo):
+    # Get model path with everything but weight safetensors
+    model_path = download(
+        args.model,
+        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
+    )
+
+    # Lazy load and shard model to figure out
+    # which weights we need
+    model, _ = load_model(model_path, lazy=True, strict=False)
+
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+    model.model.pipeline(group)
+
+    # Figure out which files we need for the local shard
+    with open(model_path / "model.safetensors.index.json", "r") as fid:
+        weight_index = json.load(fid)["weight_map"]
+
+    local_files = set()
+    for k, _ in tree_flatten(model.parameters()):
+        local_files.add(weight_index[k])
+
+    # Download weights for local shard
+    download(args.model, allow_patterns=local_files)
+
+    # Load and shard the model, and load the weights
+    tokenizer = load_tokenizer(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
+
+    # Synchronize processes before generation to avoid timeout if downloading
+    # model for the first time.
+    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="LLM pipelined inference example")
+    parser.add_argument(
+        "--model",
+        default="mlx-community/DeepSeek-R1-3bit",
+        help="HF repo or path to local model.",
+    )
+    parser.add_argument(
+        "--prompt",
+        "-p",
+        default="Write a quicksort in C++.",
+        help="Message to be processed by the model ('-' reads from stdin)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=256,
+        help="Maximum number of tokens to generate",
+    )
+    args = parser.parse_args()
+
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+
+    def rprint(*args, **kwargs):
+        if rank == 0:
+            print(*args, **kwargs)
+
+    model, tokenizer = shard_and_load(args.model)
+
+    messages = [{"role": "user", "content": args.prompt}]
+    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+    for response in stream_generate(
+        model, tokenizer, prompt, max_tokens=args.max_tokens
+    ):
+        rprint(response.text, end="", flush=True)
+
+    rprint()
+    rprint("=" * 10)
+    rprint(
+        f"Prompt: {response.prompt_tokens} tokens, "
+        f"{response.prompt_tps:.3f} tokens-per-sec"
+    )
+    rprint(
+        f"Generation: {response.generation_tokens} tokens, "
+        f"{response.generation_tps:.3f} tokens-per-sec"
+    )
+    rprint(f"Peak memory: {response.peak_memory:.3f} GB")
@@ -93,6 +93,12 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--chat-template-config",
+        help="Additional config for `apply_chat_template`. Should be a dictionary of"
+        " string keys to values represented as a JSON decodable string.",
+        default=None,
+    )
     parser.add_argument(
         "--verbose",
         type=str2bool,
@@ -149,7 +155,6 @@ def setup_arg_parser():
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
-
     mx.random.seed(args.seed)
 
     # Load the prompt cache and metadata if a cache file is provided
@@ -195,11 +200,15 @@ def main():
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)
 
+    template_kwargs = {}
+    if args.chat_template_config is not None:
+        template_kwargs = json.loads(args.chat_template_config)
+
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
     elif using_cache:
-        tokenizer.chat_template = metadata["chat_template"]
+        tokenizer.chat_template = json.loads(metadata["chat_template"])
 
     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
@@ -209,8 +218,12 @@ def main():
         else:
             messages = []
         messages.append({"role": "user", "content": prompt})
 
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            **template_kwargs,
         )
 
         # Treat the prompt as a suffix assuming that the prefix is in the
@@ -94,6 +94,14 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+
+    parser.add_argument(
+        "--mask-prompt",
+        action="store_true",
+        help="Mask the prompt in the loss when training",
+        default=False,
+    )
+
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -219,6 +227,7 @@ def train_model(
             build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
         )
     )
+
    # Train model
    train(
        model=model,
@@ -282,12 +282,12 @@ class MoEGate(nn.Module):
         if self.topk_method == "group_limited_greedy":
             bsz, seq_len = x.shape[:2]
             scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-            group_scores = scores.max(axis=-1)
+            group_scores = scores.max(axis=-1, keepdims=True)
             k = self.n_group - self.topk_group
-            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-            scores[batch_idx, seq_idx, group_idx] = 0.0
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+            scores = mx.put_along_axis(
+                scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
+            )
             scores = scores.reshape(bsz, seq_len, -1)
 
         k = self.top_k
@@ -364,8 +364,32 @@ class DeepseekV2Model(nn.Module):
             DeepseekV2DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.pipeline_rank = 0
+        self.pipeline_size = 1
+
+    def pipeline(self, group):
+        # Split layers in reverse so rank=0 gets the last layers and
+        # rank=pipeline_size-1 gets the first
+        self.pipeline_rank = group.rank()
+        self.pipeline_size = group.size()
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
+
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.num_layers = layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
+
     def __call__(
         self,
         x: mx.array,
@@ -374,14 +398,31 @@ class DeepseekV2Model(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(x)
 
+        pipeline_rank = self.pipeline_rank
+        pipeline_size = self.pipeline_size
+        # Hack to avoid time-outs during prompt-processing
+        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
         if mask is None:
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        # Receive from the previous process in the pipeline
+        if pipeline_rank < pipeline_size - 1:
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
+
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
+
+        # Send to the next process in the pipeline
+        if pipeline_rank != 0:
+            h = mx.distributed.send(
+                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
+            )
+
+        # Broadcast h while keeping it in the graph
+        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
+
         return self.norm(h)
 
@@ -418,4 +459,4 @@ class Model(nn.Module):
 
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
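For intuition, a small standalone sketch of the layer partitioning added above, assuming a hypothetical 32-layer model split across 4 pipeline ranks (rank 0 ends up with the last block of layers):

```python
def layer_range(num_layers, pipeline_size, rank):
    # Mirrors pipeline() above: earlier ranks may take one extra layer when the
    # split is uneven, and ranks are assigned in reverse so rank 0 holds the end.
    layers_per_rank = num_layers // pipeline_size
    extra = num_layers - layers_per_rank * pipeline_size
    if rank < extra:
        layers_per_rank += 1
    start = (pipeline_size - rank - 1) * layers_per_rank
    return start, start + layers_per_rank


for rank in range(4):
    print(rank, layer_range(32, 4, rank))
# 0 (24, 32)
# 1 (16, 24)
# 2 (8, 16)
# 3 (0, 8)
```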
@@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
         return down_proj
 
 
+@mx.compile
+def group_expert_select(
+    gates,
+    e_score_correction_bias,
+    top_k,
+    n_group,
+    topk_group,
+    routed_scaling_factor,
+    norm_topk_prob,
+):
+
+    k = top_k
+    scores = mx.sigmoid(gates.astype(mx.float32))
+    scores = scores + e_score_correction_bias
+    scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
+    group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
+    k = n_group - topk_group
+    group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+    scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
+    scores = mx.flatten(scores, -2, -1)
+
+    k = top_k
+    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
+    scores = mx.take_along_axis(scores, inds, axis=-1)
+    if top_k > 1 and norm_topk_prob:
+        denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
+        scores = scores / denominator
+    scores = scores * routed_scaling_factor
+
+    return inds, scores
+
+
 class MoEGate(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
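For intuition, a pure-Python walk-through of the same "pick the best groups, then the best experts" routing idea on toy numbers (an illustration only, not the MLX implementation above):

```python
# 8 experts arranged as 4 groups of 2; keep the top-2 groups, then the top-2
# experts overall.
scores = [0.1, 0.9, 0.2, 0.3, 0.8, 0.7, 0.0, 0.4]
n_group, topk_group, top_k = 4, 2, 2

groups = [scores[i * 2 : (i + 1) * 2] for i in range(n_group)]
# A group's score is the sum of its top-2 members (here, the whole group).
group_scores = [sum(sorted(g, reverse=True)[:2]) for g in groups]
keep = sorted(range(n_group), key=lambda g: group_scores[g], reverse=True)[:topk_group]

# Zero out experts in dropped groups, then take the global top-k.
masked = [s if i // 2 in keep else 0.0 for i, s in enumerate(scores)]
top_experts = sorted(range(8), key=lambda i: masked[i], reverse=True)[:top_k]
print(top_experts)  # [1, 4]: the strongest experts from the two kept groups
```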
@@ -279,38 +311,22 @@ class MoEGate(nn.Module):
         self.norm_topk_prob = config.norm_topk_prob
         self.n_routed_experts = config.n_routed_experts
         self.routed_scaling_factor = config.routed_scaling_factor
-        self.topk_method = config.topk_method
         self.n_group = config.n_group
         self.topk_group = config.topk_group
         self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
         self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
+        assert config.topk_method == "noaux_tc", "Unsupported topk method."
 
     def __call__(self, x):
-        gates = x @ self.weight.T
-        scores = mx.sigmoid(gates.astype(mx.float32))
-        assert self.topk_method == "noaux_tc", "Unsupported topk method."
-        bsz, seq_len = x.shape[:2]
-        scores = scores + self.e_score_correction_bias
-        scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-        group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
-        k = self.n_group - self.topk_group
-        group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-        batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-        seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-        scores[batch_idx, seq_idx, group_idx] = 0.0
-        scores = scores.reshape(bsz, seq_len, -1)
-
-        k = self.top_k
-        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
-        scores = mx.take_along_axis(scores, inds, axis=-1)
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
-            scores = scores / denominator
-        scores = scores * self.routed_scaling_factor
-
-        return inds, scores
+        return group_expert_select(
+            x @ self.weight.T,
+            self.e_score_correction_bias,
+            self.top_k,
+            self.n_group,
+            self.topk_group,
+            self.routed_scaling_factor,
+            self.norm_topk_prob,
+        )
 
 
 class DeepseekV3MoE(nn.Module):
@@ -381,6 +397,10 @@ class DeepseekV3Model(nn.Module):
             DeepseekV3DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pipeline_rank = 0
         self.pipeline_size = 1
@@ -390,11 +410,15 @@ class DeepseekV3Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
-        self.layers = self.layers[start : start + layers_per_rank]
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
@@ -412,15 +436,15 @@ class DeepseekV3Model(nn.Module):
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
             h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
 
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
@@ -468,4 +492,4 @@ class Model(nn.Module):
 
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]
llms/mlx_lm/models/granite.py (new file, 195 lines)
@@ -0,0 +1,195 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    logits_scaling: float
+    attention_multiplier: float
+    embedding_multiplier: float
+    residual_multiplier: float
+    max_position_embeddings: int
+    num_key_value_heads: int
+    attention_bias: bool
+    mlp_bias: bool
+    rope_theta: float
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = True
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        self.head_dim = head_dim = args.hidden_size // n_heads
+
+        self.scale = args.attention_multiplier
+        attention_bias = args.attention_bias
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        self.rope = initialize_rope(
+            self.head_dim,
+            args.rope_theta,
+            False,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+        if hasattr(args, "mlp_bias"):
+            mlp_bias = args.mlp_bias
+        else:
+            mlp_bias = False
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.residual_multiplier = args.residual_multiplier
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r * self.residual_multiplier
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r * self.residual_multiplier
+        return out
+
+
+class GraniteModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.embedding_multiplier = args.embedding_multiplier
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+    ):
+        h = self.embed_tokens(inputs) * self.embedding_multiplier
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, cache=c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = GraniteModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.logits_scaling = args.logits_scaling
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+    ):
+        out = self.model(inputs, mask, cache)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out / self.logits_scaling
+
+    @property
+    def layers(self):
+        return self.model.layers
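Once this module is in place, Granite checkpoints load through the usual `mlx_lm` entry points; a hedged usage sketch (the repository id below is illustrative, not taken from this commit):

```python
from mlx_lm import generate, load

# Hypothetical checkpoint id; any model whose config.json declares
# model_type "granite" is routed to the module above.
model, tokenizer = load("ibm-granite/granite-3.1-2b-instruct")
print(generate(model, tokenizer, prompt="Write a haiku about autumn.", max_tokens=64))
```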
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
@@ -76,7 +76,6 @@ class Attention(nn.Module):
 
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
-
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
         if kv_proj:
             self.k_proj = nn.Linear(
@@ -107,7 +106,6 @@ class Attention(nn.Module):
         B, L, D = x.shape
 
         queries = self.q_proj(x)
-
         if kv_states is None:
             keys, values = self.k_proj(x), self.v_proj(x)
             kv_states = keys, values
@@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(kv_proj, args)
-        self.mlp = MoeBlock(args)
+        if args.num_experts == 1:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        else:
+            self.mlp = MoeBlock(args)
 
         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.post_attention_layernorm = nn.RMSNorm(
@@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
+            DecoderLayer(
+                args=args,
+                kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
+            )
             for i in range(args.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
             cache = [None] * len(self.layers)
 
         for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            if i % self.args.cla_share_factor == 0:
+            if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
                 shared_kv_states = None
             h, shared_kv_states = layer(h, mask, c, shared_kv_states)
 
@@ -275,6 +279,29 @@ class Model(nn.Module):
         return self.model.embed_tokens.as_linear(out)
 
     def sanitize(self, weights):
+
+        if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
+            new_weights = {}
+            D = self.args.hidden_size
+            n_kv_heads = self.args.num_key_value_heads
+            n_kv_groups = self.args.num_attention_heads // n_kv_heads
+            head_dim = D // self.args.num_attention_heads
+            for k, v in weights.items():
+                if "qkv_proj" in k:
+                    v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
+                    splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
+                    for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
+                        k_new = k.replace("qkv_proj", k_up)
+                        new_weights[k_new] = mx.flatten(v_new, 0, 2)
+                elif "gate_and_up_proj" in k:
+                    splits = v.split(2, axis=0)
+                    for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
+                        k_new = k.replace("gate_and_up_proj", k_up)
+                        new_weights[k_new] = v_new
+                else:
+                    new_weights[k] = v
+            weights = new_weights
+
         if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
             return weights
         for l in range(self.args.num_hidden_layers):
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
         y = y + D * x
         return y, new_state
 
-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
+
         outputs = []
+        current_state = state_cache
+        y = []
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output
 
 
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
@@ -1,5 +1,6 @@
 import json
 from functools import partial
+from typing import List
 
 from transformers import AutoTokenizer
 
@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
         detokenizer_class,
         eos_token_ids=eos_token_ids,
     )
+
+
+def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
+    removed_bos = sequence if sequence[0] != bos else sequence[1:]
+    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
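A quick sanity check of the helper introduced above (restated so the snippet runs on its own):

```python
from typing import List


def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
    removed_bos = sequence if sequence[0] != bos else sequence[1:]
    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos


print(no_bos_or_eos([1, 5, 6, 7, 2], bos=1, eos=2))  # [5, 6, 7]
print(no_bos_or_eos([5, 6, 7], bos=1, eos=2))  # unchanged: [5, 6, 7]
```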
@@ -1,6 +1,8 @@
+import itertools
 import json
+import types
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from transformers import PreTrainedTokenizer
 
@@ -34,14 +36,24 @@ class ChatDataset:
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
-        self._data = [
-            tokenizer.apply_chat_template(
-                d["messages"],
-                tools=d.get("tools", None),
-            )
-            for d in data
-        ]
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+        mask_prompt: bool = False,
+    ):
+        self._data = []
+        for d in data:
+            messages = d[chat_key]
+            tools = d.get("tools", None)
+            tokens = tokenizer.apply_chat_template(messages, tools=tools)
+            if mask_prompt:
+                messages = messages[:-1]
+                offset = len(tokenizer.apply_chat_template(messages, tools=tools))
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -63,16 +75,36 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        mask_prompt: bool,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
+        self._data = []
+        for d in data:
+            tokens = tokenizer.apply_chat_template(
                 [
                     {"role": "user", "content": d[prompt_key]},
                     {"role": "assistant", "content": d[completion_key]},
                 ],
             )
-            for d in data
-        ]
+            if mask_prompt:
+                offset = len(
+                    tokenizer.apply_chat_template(
+                        [{"role": "user", "content": d[prompt_key]}]
+                    )
+                )
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
+
+    def __getitem__(self, idx: int):
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)
+
+
+class ConcatenatedDataset:
+    def __init__(self, data: List[Any]):
+        self._data = list(itertools.chain(*data))
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -84,18 +116,26 @@ class CompletionsDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config,
 ):
-    prompt_feature = prompt_feature or "prompt"
-    completion_feature = completion_feature or "completion"
+    mask_prompt = getattr(config, "mask_prompt", False)
+    prompt_feature = getattr(config, "prompt_feature", "prompt")
+    text_feature = getattr(config, "text_feature", "text")
+    completion_feature = getattr(config, "completion_feature", "completion")
+    chat_feature = getattr(config, "chat_feature", "messages")
     sample = data[0]
-    if "messages" in sample:
-        return ChatDataset(data, tokenizer)
-    elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
-    elif "text" in sample:
-        return Dataset(data, tokenizer)
+    if prompt_feature in sample and completion_feature in sample:
+        return CompletionsDataset(
+            data, tokenizer, prompt_feature, completion_feature, mask_prompt
+        )
+    elif chat_feature in sample:
+        return ChatDataset(
+            data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+        )
+    elif text_feature in sample:
+        if mask_prompt:
+            raise ValueError("Prompt masking not supported for text dataset.")
+        return Dataset(data, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -106,15 +146,14 @@ def create_dataset(
 def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config,
 ):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(data, tokenizer, config)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -124,8 +163,7 @@ def load_local_dataset(
 def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config,
 ):
     from datasets import exceptions, load_dataset
 
@@ -136,9 +174,7 @@ def load_hf_dataset(
 
     train, valid, test = [
         (
-            create_dataset(
-                dataset[n], tokenizer, prompt_feature, completion_feature
-            )
+            create_dataset(dataset[n], tokenizer, config)
             if n in dataset.keys()
             else []
         )
@@ -154,42 +190,61 @@ def load_hf_dataset(
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
-    hf_args = args.hf_dataset
-    dataset_name = hf_args["name"]
-    print(f"Loading Hugging Face dataset {dataset_name}.")
-    text_feature = hf_args.get("text_feature")
-    prompt_feature = hf_args.get("prompt_feature")
-    completion_feature = hf_args.get("completion_feature")
-
-    def create_hf_dataset(split: str = None):
+    def create_hf_dataset(dataset_name, config, split, hf_config):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
-            **hf_args.get("config", {}),
+            **hf_config,
         )
-        if prompt_feature and completion_feature:
-            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
-        elif text_feature:
-            return Dataset(ds, tokenizer, text_key=text_feature)
-        else:
-            raise ValueError(
-                "Specify either a prompt and completion feature or a text "
-                "feature for the Hugging Face dataset."
-            )
+        return create_dataset(ds, tokenizer, config)
 
-    if args.train:
-        train_split = hf_args.get("train_split", "train[:80%]")
-        valid_split = hf_args.get("valid_split", "train[-10%:]")
-        train = create_hf_dataset(split=train_split)
-        valid = create_hf_dataset(split=valid_split)
-    else:
-        train, valid = [], []
-    if args.test:
-        test = create_hf_dataset(split=hf_args.get("test_split"))
-    else:
-        test = []
+    dataset_collection = args.hf_dataset
+    if isinstance(dataset_collection, dict):
+        dataset_collection = [dataset_collection]
 
-    return train, valid, test
+    collection = []
+    for ds in dataset_collection:
+        ds_name = ds["name"]
+        print(f"Loading Hugging Face dataset {ds_name}.")
+        ds["mask_prompt"] = getattr(args, "mask_prompt", False)
+        config = types.SimpleNamespace(**ds)
+        hf_config = ds.get("config", {})
+        if args.train:
+            train_split = ds.get("train_split", "train[:80%]")
+            valid_split = ds.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(
+                ds_name,
+                config,
+                train_split,
+                hf_config,
+            )
+            valid = create_hf_dataset(
+                ds_name,
+                config,
+                valid_split,
+                hf_config,
+            )
+        else:
+            train, valid = [], []
+
+        if args.test:
+            test_split = ds.get("test_split")
+            test = create_hf_dataset(
+                ds_name,
+                config,
+                test_split,
+                hf_config,
+            )
+        else:
+            test = []
+
+        collection.append((train, valid, test))
+
+    if len(collection) == 1:
+        return collection[0]
+
+    # Otherwise concatenate them
+    return tuple(map(ConcatenatedDataset, zip(*collection)))
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
@@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
-
-        prompt_feature = getattr(args, "prompt_feature", None)
-        completion_feature = getattr(args, "completion_feature", None)
-
         if data_path.exists():
-            train, valid, test = load_local_dataset(
-                data_path, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_local_dataset(data_path, tokenizer, args)
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(
-                args.data, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_hf_dataset(args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(
@@ -5,13 +5,16 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
+from transformers import PreTrainedTokenizer
+
+from .datasets import CompletionsDataset
 
 
 def grad_checkpoint(layer):
@@ -63,20 +66,30 @@ class TrainingArgs:
     )


-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)

-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks

     return ce, ntoks


-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset,
+    tokenizer,
+    batch_size,
+    max_seq_length,
+    train=False,
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
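In the new `default_loss`, `lengths` is no longer a flat vector of sequence lengths: each row carries an (offset, length) pair so that the loss can be restricted to the completion tokens (prompt masking). A toy illustration of the mask arithmetic with made-up numbers, assuming the first entry marks where the loss should start and the second the sequence length:

```python
import mlx.core as mx

# Two sequences: the first starts contributing loss at step 3 and runs for
# 6 target steps in total, the second starts at step 1 and runs for 4.
lengths = mx.array([[3, 6], [1, 4]])

num_target_steps = 6
steps = mx.arange(1, num_target_steps + 1)

# Keep a target position when it lies at or past the offset and at or
# before the end of the sequence.
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)
# Roughly:
# [[False, False, True, True, True, True],
#  [True,  True,  True, True, False, False]]
```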
@@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
                     truncated_length  # Update lengths to match truncated lengths
                 )
             batch = mx.array(batch_arr)
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))

         if not train:
             break
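With this change, each yielded batch is paired with an array of (offset, length) rows instead of a bare length vector. A small sketch of that pairing, with toy samples rather than real tokenized data:

```python
import mlx.core as mx

# Completion-style samples may come back as (tokens, offset) pairs; plain
# samples get a zero offset so the whole sequence contributes to the loss.
batch = [([1, 2, 3, 4], 2), ([5, 6, 7], 1)]
if len(batch[0]) == 2:
    batch, offsets = zip(*batch)
else:
    offsets = [0] * len(batch)

lengths = [len(x) for x in batch]
print(mx.array(list(zip(offsets, lengths))))
# [[2, 4],
#  [1, 3]]
```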
@@ -140,8 +156,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

@@ -217,8 +233,8 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
-    start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
         iterate_batches(
@@ -229,10 +245,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -243,7 +260,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -260,24 +277,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)

-            start = time.perf_counter()
+            tic = time.perf_counter()

         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -306,7 +322,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0

         # Save adapter weights
         if it % args.steps_per_save == 0:
@@ -89,11 +89,13 @@ def linear_to_lora_layers(
         "mixtral",
         "nemotron",
         "stablelm",
+        "hunyuan",
         "qwen2",
         "qwen2_moe",
         "phimoe",
         "gemma",
         "gemma2",
+        "granite",
         "helium",
         "starcoder2",
         "cohere",
@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
    prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@@ -338,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.

@@ -365,7 +378,8 @@ def speculative_generate_step(
           when ``kv_bits`` is non-None. Default: ``0``.

     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+          and a bool indicating if the token was generated by the draft model
     """

     y = prompt
@@ -450,12 +464,12 @@ def speculative_generate_step(
                     break
                 n += 1
                 ntoks += 1
-                yield tn, lpn
+                yield tn, lpn, True
                 if ntoks == max_tokens:
                     break
             if ntoks < max_tokens:
                 ntoks += 1
-                yield tokens[n], logprobs[n]
+                yield tokens[n], logprobs[n], False

             if ntoks == max_tokens:
                 break
@@ -463,7 +477,7 @@ def speculative_generate_step(
             y = mx.array([tokens[n]], mx.uint32)
             draft_y = y

-            # If we accpeted all the draft tokens, include the last
+            # If we accepted all the draft tokens, include the last
             # draft token in the next draft step since it hasn't been
             # processed yet by the draft model
             if n == num_draft:
@@ -518,6 +532,10 @@ def stream_generate(
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
@@ -526,7 +544,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@ def stream_generate(
                 text=detokenizer.last_segment,
                 token=token,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -553,6 +572,7 @@ def stream_generate(
             text=detokenizer.last_segment,
             token=token,
             logprobs=logprobs,
+            from_draft=from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
             generation_tokens=n + 1,
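With `from_draft` now surfaced on every `GenerationResponse`, callers of `stream_generate` can see how many tokens the draft model contributed. A hedged usage sketch: the checkpoint name is the one used in the repo's tests, and reusing the target model as its own draft is only for illustration, not something you would do for speed.

```python
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

drafted = 0
total = 0
for response in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Write a haiku about the ocean.",
    max_tokens=64,
    draft_model=draft_model,
    num_draft_tokens=2,
):
    total += 1
    drafted += int(response.from_draft)
    print(response.text, end="", flush=True)

print(f"\n{drafted}/{total} tokens came from the draft model")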
@@ -627,6 +647,7 @@ def load_config(model_path: Path) -> dict:
 def load_model(
     model_path: Path,
     lazy: bool = False,
+    strict: bool = True,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -638,6 +659,8 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        strict (bool): Whether or not to raise an exception if weights don't
+            match. Default: ``True``
         model_config (dict, optional): Optional configuration parameters for the
             model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -660,7 +683,7 @@ def load_model(
     # Try weight for back-compat
     weight_files = glob.glob(str(model_path / "weight*.safetensors"))

-    if not weight_files:
+    if not weight_files and strict:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")

@@ -694,7 +717,7 @@ def load_model(
         class_predicate=class_predicate,
     )

-    model.load_weights(list(weights.items()))
+    model.load_weights(list(weights.items()), strict=strict)

     if not lazy:
         mx.eval(model.parameters())
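The new `strict` flag lets `load_model` construct a model even when weight files are missing or only partially match, e.g. when weights will be attached later. A hedged sketch of how it might be called; the path is a placeholder and the directory is assumed to contain a valid config:

```python
from pathlib import Path

from mlx_lm.utils import load_model

# Assumption: /path/to/model holds a config.json; with strict=False the
# loader will not raise if safetensors are absent or only partially match.
model = load_model(Path("/path/to/model"), strict=False)
```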
@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(isinstance(train, datasets.ChatDataset))

     def test_hf(self):
+        hf_args = {
+            "name": "billsum",
+            "prompt_feature": "text",
+            "completion_feature": "summary",
+            "train_split": "train[:2%]",
+            "valid_split": "train[-2%:]",
+        }
         args = types.SimpleNamespace(
-            hf_dataset={
-                "name": "billsum",
-                "prompt_feature": "text",
-                "completion_feature": "summary",
-                "train_split": "train[:2%]",
-                "valid_split": "train[-2%:]",
-            },
+            hf_dataset=hf_args,
             test=False,
             train=True,
         )
@@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(len(valid[0]) > 0)
         self.assertEqual(len(test), 0)

+        args = types.SimpleNamespace(
+            hf_dataset=[hf_args, hf_args],
+            test=False,
+            train=True,
+        )
+        train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
+        self.assertEqual(2 * len(train), len(train_double))
+        self.assertEqual(2 * len(valid), len(valid_double))
+        self.assertEqual(2 * len(test), len(test_double))
+

 if __name__ == "__main__":
     unittest.main()
@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.

 import unittest
+from typing import List

 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)


 class TestGenerate(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)

+    def test_stream_generate_speculative(self):
+        # Use same model as draft model, this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a determinate sampler
+        sampler = make_sampler(temp=0.0)
+
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+        # since num_draft_tokens is 2 and draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and last two should be drafts
+        self.assertEqual(drafted, [True, True, False, True, True])
+

 if __name__ == "__main__":
     unittest.main()
@@ -134,9 +134,7 @@ def find_alignment(
     logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
     # consider only the logits associated with predicting text
     sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
-    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
-        sampled_logits.dtype
-    )
+    token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
     text_token_probs = mx.take_along_axis(
         token_probs, mx.array(text_tokens)[:, None], axis=1
     ).squeeze(1)
@@ -144,10 +142,11 @@ def find_alignment(

     # heads * tokens * frames
     weights = mx.stack(
-        [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+        [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
     )
     weights = weights[:, :, : num_frames // 2]
-    weights = mx.softmax(weights * qk_scale, axis=-1)
+    weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
+    weights = weights.astype(mx.float32)
     mean = mx.mean(weights, axis=-2, keepdims=True)
     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std
@@ -195,6 +195,8 @@ def transcribe(
         seek_points.append(0)
     if len(seek_points) % 2 == 1:
         seek_points.append(content_frames)
+    else:
+        seek_points[-1] = min(content_frames, seek_points[-1])
     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

     punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
         w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk.astype(mx.float32)
+        return out, qk


 class ResidualAttentionBlock(nn.Module):
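Several of the Whisper changes above replace manual float32 upcasting with `mx.softmax(..., precise=True)`, which performs the internal reduction in higher precision while returning the input dtype. A small sketch of the difference with toy values; the exact numerical gap depends on the MLX version:

```python
import mlx.core as mx

x = (mx.arange(8) * 1.5).astype(mx.float16)

fast = mx.softmax(x, axis=-1)
precise = mx.softmax(x, axis=-1, precise=True)

# Both results keep the float16 dtype, but the precise variant accumulates
# in higher precision before casting back.
print(fast.dtype, precise.dtype)
print(mx.abs(fast - precise).max())
```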