mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
8 Commits
flux-dist-
...
dist-eval
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f787c08585 | ||
|
|
d5f49d65b9 | ||
|
|
4385363c0f | ||
|
|
9a3ddc3e65 | ||
|
|
df1406735b | ||
|
|
07f88f8057 | ||
|
|
50f0a7f6d9 | ||
|
|
6ae6c72c2e |
@@ -261,23 +261,19 @@ if __name__ == "__main__":
|
|||||||
generate_progress_images(0, flux, args)
|
generate_progress_images(0, flux, args)
|
||||||
|
|
||||||
grads = None
|
grads = None
|
||||||
batch_cnt = 0
|
losses = []
|
||||||
total_loss = 0
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||||
total_loss = total_loss + loss
|
mx.eval(loss, grads, state)
|
||||||
batch_cnt += 1
|
losses.append(loss.item())
|
||||||
mx.eval(total_loss, grads, state)
|
|
||||||
|
|
||||||
if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
|
if (i + 1) % 10 == 0:
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||||
total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
|
|
||||||
total_loss = total_loss.item() / batch_cnt
|
|
||||||
print(
|
print(
|
||||||
f"Iter: {i + 1} Loss: {total_loss:.3f} "
|
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||||
f"It/s: {batch_cnt / (toc - tic):.3f} "
|
f"It/s: {10 / (toc - tic):.3f} "
|
||||||
f"Peak mem: {peak_mem:.3f} GB",
|
f"Peak mem: {peak_mem:.3f} GB",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
@@ -289,8 +285,7 @@ if __name__ == "__main__":
|
|||||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
total_loss = 0
|
losses = []
|
||||||
batch_cnt = 0
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
save_adapters("final_adapters.safetensors", flux, args)
|
save_adapters("final_adapters.safetensors", flux, args)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import lm_eval
|
import lm_eval
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -20,11 +20,10 @@ from lm_eval.api.model import LM
|
|||||||
from lm_eval.api.registry import register_model
|
from lm_eval.api.registry import register_model
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .models.base import create_causal_mask
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
from .utils import load, stream_generate
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
PAD = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _len_longest_common_prefix(a, b):
|
def _len_longest_common_prefix(a, b):
|
||||||
l = 0
|
l = 0
|
||||||
@@ -43,31 +42,14 @@ def _rstrip_until(s, untils):
|
|||||||
return s[: min(f)]
|
return s[: min(f)]
|
||||||
|
|
||||||
|
|
||||||
def _pad_inputs(
|
def _pad_inputs(inputs):
|
||||||
inputs,
|
lengths = np.array([len(x) for x in inputs])
|
||||||
maxlen,
|
maxlen = lengths.max()
|
||||||
genlen=0,
|
padded = np.stack(
|
||||||
pad_left=False,
|
[np.pad(x, (0, maxlen - len(x))) for x in inputs],
|
||||||
pad_multiple=32,
|
|
||||||
truncate=False,
|
|
||||||
):
|
|
||||||
# pad the prompts to the left with at least genlen tokens.
|
|
||||||
actual_maxlen = max(len(p) for p in inputs) + genlen
|
|
||||||
if actual_maxlen > maxlen:
|
|
||||||
if not truncate:
|
|
||||||
raise ValueError("Inputs are too long.")
|
|
||||||
else: # drop begining
|
|
||||||
actual_maxlen = maxlen
|
|
||||||
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
|
|
||||||
if pad_multiple > 0:
|
|
||||||
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
|
|
||||||
maxlen *= pad_multiple
|
|
||||||
assert PAD == 0
|
|
||||||
lr = np.array((1, 0) if pad_left else (0, 1))
|
|
||||||
return np.stack(
|
|
||||||
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
|
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
|
return mx.array(padded), mx.array(lengths)
|
||||||
|
|
||||||
|
|
||||||
@register_model("mlxlm")
|
@register_model("mlxlm")
|
||||||
@@ -83,32 +65,33 @@ class MLXLM(LM):
|
|||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._model, self.tokenizer = load(path_or_hf_repo)
|
self._model, self.tokenizer = load(path_or_hf_repo)
|
||||||
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
||||||
self.use_chat_template = use_chat_template or (
|
self.use_chat_template = use_chat_template and (
|
||||||
self.tokenizer.chat_template is not None
|
self.tokenizer.chat_template is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
def _score_fn(self, inputs, tokenize=True, step_size=32):
|
def _score_fn(self, inputs, step_size: int = 64):
|
||||||
if tokenize:
|
inputs, lengths = _pad_inputs(inputs)
|
||||||
inputs = self._tokenize(inputs)
|
|
||||||
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
|
|
||||||
inputs = mx.array(inputs)
|
|
||||||
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
||||||
|
|
||||||
cache = make_prompt_cache(self._model)
|
cache = make_prompt_cache(self._model)
|
||||||
|
|
||||||
mask = targets != PAD
|
|
||||||
|
|
||||||
scores, is_greedy = [], []
|
scores, is_greedy = [], []
|
||||||
for i in range(0, inputs.shape[1], step_size):
|
for i in range(0, inputs.shape[1], step_size):
|
||||||
logits = self._model(inputs[:, i : i + step_size], cache=cache)
|
inp = inputs[:, i : i + step_size]
|
||||||
|
T = inp.shape[1]
|
||||||
|
|
||||||
|
offset = cache[0].offset
|
||||||
|
mask = create_causal_mask(T, offset, lengths=lengths)
|
||||||
|
mask = mask == 0
|
||||||
|
|
||||||
|
logits = self._model(inp, cache=cache, mask=mask)
|
||||||
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
||||||
|
|
||||||
score = mx.take_along_axis(
|
score = mx.take_along_axis(
|
||||||
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
||||||
)[..., 0]
|
)[..., 0]
|
||||||
ig = mask[:, i : i + step_size] * (
|
ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
||||||
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
|
||||||
)
|
|
||||||
|
|
||||||
mx.eval(score, ig)
|
mx.eval(score, ig)
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
@@ -119,38 +102,32 @@ class MLXLM(LM):
|
|||||||
scores = mx.concatenate(scores, axis=1)
|
scores = mx.concatenate(scores, axis=1)
|
||||||
is_greedy = mx.concatenate(is_greedy, axis=1)
|
is_greedy = mx.concatenate(is_greedy, axis=1)
|
||||||
|
|
||||||
return scores, mask.sum(axis=-1), is_greedy
|
return scores, lengths, is_greedy
|
||||||
|
|
||||||
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
|
def _loglikelihood(self, texts, score_spans=None):
|
||||||
# sort by length to get batches with little padding.
|
all_scores = mx.zeros(len(texts))
|
||||||
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
|
all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
|
||||||
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
|
for i in tqdm(range(0, len(texts), self._batch_size)):
|
||||||
sorted_spans = None
|
batch = texts[i : i + self._batch_size]
|
||||||
|
scores, lengths, is_greedy = self._score_fn(batch)
|
||||||
|
|
||||||
|
ind = np.arange(scores.shape[-1])
|
||||||
if score_spans is not None:
|
if score_spans is not None:
|
||||||
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
|
spans = score_spans[i : i + self._batch_size]
|
||||||
|
lengths = [end - start for start, end in spans]
|
||||||
|
masks = mx.array(
|
||||||
|
np.array([(ind >= start) & (ind < end) for start, end in spans])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
masks = ind[None] < lengths[:, None]
|
||||||
|
|
||||||
results = []
|
scores = (masks * scores).sum(axis=-1)
|
||||||
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
|
is_greedy = (masks * is_greedy).sum(axis=-1)
|
||||||
batch = sorted_inputs[i : i + self._batch_size]
|
|
||||||
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
|
|
||||||
for j in range(len(batch)):
|
|
||||||
if sorted_spans is None: # full sequence score
|
|
||||||
mask = mx.arange(scores[j].shape[-1]) < length
|
|
||||||
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
|
|
||||||
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
|
|
||||||
else: # subsequence score
|
|
||||||
start, end = sorted_spans[i + j]
|
|
||||||
score = scores[j][start:end].astype(mx.float32).sum()
|
|
||||||
ig = is_greedy[j][start:end].astype(mx.int32).sum()
|
|
||||||
length = end - start
|
|
||||||
|
|
||||||
results.append((score.item(), ig.item(), length))
|
all_scores[i : i + self._batch_size] = scores
|
||||||
|
all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
|
||||||
|
|
||||||
# reorder the outputs
|
return all_scores, all_is_greedy
|
||||||
inv_sort = np.argsort(sorted_indices)
|
|
||||||
results = [results[inv_sort[i]] for i in range(len(results))]
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _tokenize(self, texts):
|
def _tokenize(self, texts):
|
||||||
return [
|
return [
|
||||||
@@ -222,16 +199,53 @@ class MLXLM(LM):
|
|||||||
+ "completion longer than context."
|
+ "completion longer than context."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_results = len(shortened)
|
||||||
|
|
||||||
|
# sort by length to get batches with little padding.
|
||||||
|
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
|
||||||
|
shortened = [shortened[i] for i in sorted_indices]
|
||||||
|
completion_spans = [completion_spans[i] for i in sorted_indices]
|
||||||
|
|
||||||
|
group = mx.distributed.init()
|
||||||
|
|
||||||
|
# split strided so we have approximately the same lengths on each node
|
||||||
|
shortened = shortened[group.rank() :: group.size()]
|
||||||
|
completion_spans = completion_spans[group.rank() :: group.size()]
|
||||||
|
|
||||||
# model scoring, returns num_requests x (logp, is_greedy, length).
|
# model scoring, returns num_requests x (logp, is_greedy, length).
|
||||||
results = self._loglikelihood(
|
scores, is_greedy = self._loglikelihood(
|
||||||
shortened,
|
shortened,
|
||||||
score_spans=completion_spans,
|
score_spans=completion_spans,
|
||||||
tokenize=False,
|
|
||||||
)
|
)
|
||||||
return [(r[0], r[1] == r[2]) for r in results]
|
|
||||||
|
# all gather the results across groups
|
||||||
|
if group.size() > 1:
|
||||||
|
per_group = int(np.ceil(num_results / group.size()))
|
||||||
|
scores = mx.pad(scores, ((0, per_group - len(scores)),))
|
||||||
|
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
|
||||||
|
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
|
||||||
|
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
|
||||||
|
scores = scores.T.reshape(-1)
|
||||||
|
is_greedy = is_greedy.T.reshape(-1)
|
||||||
|
|
||||||
|
scores = np.array(scores[:num_results])
|
||||||
|
is_greedy = np.array(is_greedy[:num_results])
|
||||||
|
|
||||||
|
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
|
||||||
|
inv_sort = np.argsort(sorted_indices)
|
||||||
|
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
|
||||||
|
return results
|
||||||
|
|
||||||
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
||||||
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||||
|
) -> str:
|
||||||
|
if len(chat_history) == 0:
|
||||||
|
return ""
|
||||||
|
return lm_eval.models.huggingface.HFLM.apply_chat_template(
|
||||||
|
chat_history, add_generation_prompt
|
||||||
|
)
|
||||||
|
|
||||||
def loglikelihood_rolling(self, requests) -> list[float]:
|
def loglikelihood_rolling(self, requests) -> list[float]:
|
||||||
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
||||||
@@ -268,8 +282,9 @@ class MLXLM(LM):
|
|||||||
logging.info(
|
logging.info(
|
||||||
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
||||||
)
|
)
|
||||||
inputs = [req.args[0] for req in requests]
|
inputs = self._tokenize([req.args[0] for req in requests])
|
||||||
return [t[0] for t in self._loglikelihood(inputs)]
|
scores, _ = self._loglikelihood(inputs)
|
||||||
|
return scores.tolist()
|
||||||
|
|
||||||
def generate_until(self, requests) -> list[str]:
|
def generate_until(self, requests) -> list[str]:
|
||||||
"""Generate greedily until a stopping sequence
|
"""Generate greedily until a stopping sequence
|
||||||
@@ -332,7 +347,7 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--limit",
|
"--limit",
|
||||||
default=1.0,
|
default=None,
|
||||||
help="Limit the number of examples per task.",
|
help="Limit the number of examples per task.",
|
||||||
type=float,
|
type=float,
|
||||||
)
|
)
|
||||||
@@ -346,11 +361,8 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--apply-chat-template",
|
"--apply-chat-template",
|
||||||
action=argparse.BooleanOptionalAction,
|
action="store_true",
|
||||||
help="Specifies whether to apply a chat template to the prompt. If "
|
help="Specifies whether to apply a chat template to the prompt.",
|
||||||
"the model has a chat template, this defaults to `True`, "
|
|
||||||
"otherwise `False`.",
|
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,11 @@ import mlx.core as mx
|
|||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
default="mlx-community/DeepSeek-R1-3bit",
|
||||||
|
help="HF repo or path to local model.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
"-p",
|
"-p",
|
||||||
@@ -37,9 +42,7 @@ parser.add_argument(
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_repo = "mlx-community/DeepSeek-V3-3bit"
|
model, tokenizer = load(args.model, lazy=True)
|
||||||
|
|
||||||
model, tokenizer = load(model_repo, lazy=True)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": args.prompt}]
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ def build_parser():
|
|||||||
"--train",
|
"--train",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Do training",
|
help="Do training",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data",
|
"--data",
|
||||||
@@ -135,6 +136,7 @@ def build_parser():
|
|||||||
"--test",
|
"--test",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Evaluate on the test set after training",
|
help="Evaluate on the test set after training",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test-batches",
|
"--test-batches",
|
||||||
@@ -156,6 +158,7 @@ def build_parser():
|
|||||||
"--grad-checkpoint",
|
"--grad-checkpoint",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use gradient checkpointing to reduce memory use.",
|
help="Use gradient checkpointing to reduce memory use.",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -400,6 +400,8 @@ class DeepseekV3Model(nn.Module):
|
|||||||
|
|
||||||
pipeline_rank = self.pipeline_rank
|
pipeline_rank = self.pipeline_rank
|
||||||
pipeline_size = self.pipeline_size
|
pipeline_size = self.pipeline_size
|
||||||
|
# Hack to avoid time-outs during prompt-processing
|
||||||
|
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
@@ -407,18 +409,21 @@ class DeepseekV3Model(nn.Module):
|
|||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
# Receive from the previous process in the pipeline
|
||||||
|
|
||||||
if pipeline_rank < pipeline_size - 1:
|
if pipeline_rank < pipeline_size - 1:
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
for layer, c in zip(self.layers, cache):
|
||||||
h = layer(h, mask, c)
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
# Send to the next process in the pipeline
|
||||||
if pipeline_rank != 0:
|
if pipeline_rank != 0:
|
||||||
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
|
h = mx.distributed.send(
|
||||||
|
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||||
|
)
|
||||||
|
|
||||||
# Broadcast h while keeping it in the graph
|
# Broadcast h while keeping it in the graph
|
||||||
h = mx.distributed.all_gather(h)[: h.shape[0]]
|
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||||
|
|
||||||
return self.norm(h)
|
return self.norm(h)
|
||||||
|
|
||||||
|
|||||||
241
llms/mlx_lm/models/internlm3.py
Normal file
241
llms/mlx_lm/models/internlm3.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
hidden_size: int
|
||||||
|
num_hidden_layers: int
|
||||||
|
intermediate_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
rms_norm_eps: float
|
||||||
|
vocab_size: int
|
||||||
|
bias: bool = False
|
||||||
|
qkv_bias: bool = False
|
||||||
|
max_position_embeddings: int = 32768
|
||||||
|
num_key_value_heads: int = None
|
||||||
|
rope_theta: float = 10000
|
||||||
|
rope_traditional: bool = False
|
||||||
|
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||||
|
tie_word_embeddings: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.num_key_value_heads is None:
|
||||||
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
|
if self.rope_scaling:
|
||||||
|
required_keys = {"factor", "rope_type"}
|
||||||
|
if not all(key in self.rope_scaling for key in required_keys):
|
||||||
|
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||||
|
|
||||||
|
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
|
||||||
|
raise ValueError(
|
||||||
|
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicNTKScalingRoPE(nn.Module):
|
||||||
|
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
max_position_embeddings: int = 2048,
|
||||||
|
traditional: bool = False,
|
||||||
|
base: float = 10000,
|
||||||
|
scale: float = 1.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.original_base = base
|
||||||
|
self.dims = dims
|
||||||
|
self.traditional = traditional
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||||
|
|
||||||
|
def __call__(self, x, offset: int = 0):
|
||||||
|
seq_len = x.shape[1] + offset
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.original_base * (
|
||||||
|
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||||
|
) ** (self.dims / (self.dims - 2))
|
||||||
|
else:
|
||||||
|
base = self.original_base
|
||||||
|
|
||||||
|
return mx.fast.rope(
|
||||||
|
x,
|
||||||
|
self.dims,
|
||||||
|
traditional=self.traditional,
|
||||||
|
base=base,
|
||||||
|
scale=self.scale,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
qkv_bias = args.qkv_bias
|
||||||
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||||
|
|
||||||
|
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
|
||||||
|
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||||
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||||
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
|
||||||
|
|
||||||
|
rope_scale = (
|
||||||
|
1 / args.rope_scaling["factor"]
|
||||||
|
if args.rope_scaling is not None
|
||||||
|
and args.rope_scaling["rope_type"] == "linear"
|
||||||
|
else 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rope = DynamicNTKScalingRoPE(
|
||||||
|
head_dim,
|
||||||
|
max_position_embeddings=args.max_position_embeddings,
|
||||||
|
traditional=args.rope_traditional,
|
||||||
|
base=args.rope_theta,
|
||||||
|
scale=rope_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Any] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = scaled_dot_product_attention(
|
||||||
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim, bias):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||||
|
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
|
||||||
|
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||||
|
|
||||||
|
def __call__(self, x) -> mx.array:
|
||||||
|
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = Attention(args)
|
||||||
|
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
|
||||||
|
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = nn.RMSNorm(
|
||||||
|
args.hidden_size, eps=args.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Any] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
|
h = x + r
|
||||||
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
|
out = h + r
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class InternLM2Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
assert args.vocab_size > 0
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.layers = [
|
||||||
|
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||||
|
]
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, cache=c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.model = InternLM2Model(args)
|
||||||
|
if not args.tie_word_embeddings:
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.model(inputs, mask, cache)
|
||||||
|
if self.args.tie_word_embeddings:
|
||||||
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
|
else:
|
||||||
|
out = self.lm_head(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
# Remove unused precomputed rotary freqs
|
||||||
|
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
||||||
@@ -170,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
|||||||
if prompt_feature and completion_feature:
|
if prompt_feature and completion_feature:
|
||||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||||
elif text_feature:
|
elif text_feature:
|
||||||
return Dataset(train_ds, tokenizer, text_key=text_feature)
|
return Dataset(ds, tokenizer, text_key=text_feature)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Specify either a prompt and completion feature or a text "
|
"Specify either a prompt and completion feature or a text "
|
||||||
|
|||||||
@@ -159,8 +159,8 @@ def evaluate(
|
|||||||
ntokens += toks
|
ntokens += toks
|
||||||
mx.eval(all_losses, ntokens)
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
all_losses = mx.distributed.all_sum(all_losses)
|
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||||
ntokens = mx.distributed.all_sum(ntokens)
|
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||||
|
|
||||||
return (all_losses / ntokens).item()
|
return (all_losses / ntokens).item()
|
||||||
|
|
||||||
@@ -272,9 +272,9 @@ def train(
|
|||||||
if it % args.steps_per_report == 0 or it == args.iters:
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
|
|
||||||
train_loss = mx.distributed.all_sum(losses).item()
|
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||||
train_loss /= steps * mx.distributed.init().size()
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||||
learning_rate = optimizer.learning_rate.item()
|
learning_rate = optimizer.learning_rate.item()
|
||||||
it_sec = args.steps_per_report / (stop - start)
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / (stop - start)
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ def linear_to_lora_layers(
|
|||||||
"minicpm",
|
"minicpm",
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"olmo2",
|
"olmo2",
|
||||||
|
"internlm3",
|
||||||
]:
|
]:
|
||||||
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
||||||
if model.model_type in ["mixtral", "phimoe"]:
|
if model.model_type in ["mixtral", "phimoe"]:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def swapped_with_identity(obj, func):
|
def swapped_with_identity(obj, func):
|
||||||
old_func = getattr(obj, func)
|
old_func = getattr(obj, func)
|
||||||
setattr(obj, func, lambda x: x)
|
setattr(obj, func, lambda x, **kwargs: x)
|
||||||
yield
|
yield
|
||||||
setattr(obj, func, old_func)
|
setattr(obj, func, old_func)
|
||||||
|
|
||||||
|
|||||||
@@ -927,6 +927,23 @@ class TestModels(unittest.TestCase):
|
|||||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_internlm3(self):
|
||||||
|
from mlx_lm.models import internlm3
|
||||||
|
|
||||||
|
args = internlm3.ModelArgs(
|
||||||
|
model_type="internlm3",
|
||||||
|
hidden_size=1024,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
intermediate_size=2048,
|
||||||
|
num_attention_heads=4,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
vocab_size=10_000,
|
||||||
|
)
|
||||||
|
model = internlm3.Model(args)
|
||||||
|
self.model_test_runner(
|
||||||
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user