mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
comments + docs
This commit is contained in:
@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
|
||||
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
||||
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
||||
|
||||
### Quick Start
|
||||
|
||||
To generate text with an LLM use:
|
||||
|
||||
```bash
|
||||
mlx_lm.generate --prompt "Hi!"
|
||||
```
|
||||
|
||||
To chat with an LLM use:
|
||||
|
||||
```bash
|
||||
mlx_lm.chat
|
||||
```
|
||||
|
||||
This will give you a chat REPL that you can use to interact with the LLM. The
|
||||
chat context is preserved during the lifetime of the REPL.
|
||||
|
||||
Commands in `mlx-lm` typically take command-line options that let you specify
|
||||
the model, sampling parameters, and more. Use `-h` to see a list of available
|
||||
options for a command, e.g.:
|
||||
|
||||
```bash
|
||||
mlx_lm.generate -h
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
You can use `mlx-lm` as a module:
|
||||
@@ -138,7 +163,7 @@ mlx_lm.convert \
|
||||
|
||||
### Long Prompts and Generations
|
||||
|
||||
MLX LM has some tools to scale efficiently to long prompts and generations:
|
||||
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
|
||||
|
||||
- A rotating fixed-size key-value cache.
|
||||
- Prompt caching.
|
||||
@@ -178,7 +203,7 @@ for more usage details.
|
||||
|
||||
### Supported Models
|
||||
|
||||
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
|
||||
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
|
||||
run is not supported, file an
|
||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or, better yet,
|
||||
submit a pull request.
|
||||
|
@@ -56,7 +56,7 @@ def main():
|
||||
tokenizer_config={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
print(f"Starting chat sessiong with {args.model}. To exit, enter 'q'.")
|
||||
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
|
||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||
while True:
|
||||
query = input(">> ")
|
||||
|
@@ -14,6 +14,7 @@ DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.6
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
@@ -27,6 +28,7 @@ def setup_arg_parser():
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
|
@@ -44,12 +44,11 @@ def save_prompt_cache(
|
||||
metadata (Optional[Dict[str, str]]): Optional metadata to save along
|
||||
with model state.
|
||||
"""
|
||||
cache_data, cache_info = zip(*(c.state for c in cache))
|
||||
cache_data = [c.state for c in cache]
|
||||
cache_info = [c.meta_state if hasattr(c, "meta_state") else "" for c in cache]
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
cache_classes = [type(c).__name__ for c in cache]
|
||||
cache_metadata = [cache_classes, cache_info]
|
||||
if metadata:
|
||||
cache_metadata.append(metadata)
|
||||
cache_metadata = [cache_classes, cache_info, metadata or ""]
|
||||
cache_metadata = dict(tree_flatten(cache_metadata))
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
@@ -70,12 +69,14 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
|
||||
arrays = tree_unflatten(list(arrays.items()))
|
||||
cache_metadata = tree_unflatten(list(cache_metadata.items()))
|
||||
classes, info = cache_metadata[:2]
|
||||
classes, info, metadata = cache_metadata
|
||||
cache = [globals()[c]() for c in classes]
|
||||
for c, *state in zip(cache, arrays, info):
|
||||
for c, state, meta_state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
if hasattr(c, "meta_state"):
|
||||
c.meta_state = meta_state
|
||||
if return_metadata:
|
||||
return cache, cache_metadata[2]
|
||||
return cache, metadata
|
||||
return cache
|
||||
|
||||
|
||||
@@ -114,16 +115,16 @@ class KVCache:
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys.shape[2]:
|
||||
return (self.keys, self.values), ""
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return (
|
||||
self.keys[..., : self.offset, :],
|
||||
self.values[..., : self.offset, :],
|
||||
), ""
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v[0]
|
||||
self.keys, self.values = v
|
||||
self.offset = self.keys.shape[2]
|
||||
|
||||
|
||||
@@ -236,20 +237,25 @@ class RotatingKVCache:
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset < self.keys.shape[2]:
|
||||
kv_state = (self.keys[..., : self.offset], self.values[..., : self.offset])
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
else:
|
||||
kv_state = (self.keys, self.values)
|
||||
extra_state = tuple(
|
||||
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
|
||||
)
|
||||
return kv_state, extra_state
|
||||
return self.keys, self.values
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v[0]
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(
|
||||
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
|
||||
)
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.keep, self.max_size, self.step, self.offset, self._idx = map(
|
||||
int,
|
||||
v[1],
|
||||
v,
|
||||
)
|
||||
|
||||
|
||||
@@ -267,10 +273,6 @@ class MambaCache:
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache, ""
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v[0]
|
||||
self.cache = v
|
||||
|
@@ -240,7 +240,7 @@ def generate_step(
|
||||
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=cache)
|
||||
mx.eval([c.state[0] for c in cache])
|
||||
mx.eval([c.state for c in cache])
|
||||
y = y[prefill_step_size:]
|
||||
|
||||
y, logprobs = _step(y)
|
||||
|
@@ -40,8 +40,8 @@ class TestPromptCache(unittest.TestCase):
|
||||
self.assertTrue(len(cache), len(loaded_cache))
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
|
||||
self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
|
||||
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
|
||||
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
|
||||
|
||||
# Test with metadata
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
@@ -67,8 +67,8 @@ class TestPromptCache(unittest.TestCase):
|
||||
self.assertEqual(c.keep, lc.keep)
|
||||
self.assertEqual(c.max_size, lc.max_size)
|
||||
self.assertEqual(c.step, lc.step)
|
||||
self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
|
||||
self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
|
||||
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
|
||||
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
|
||||
|
||||
# Do a couple single token updates to get a rotation
|
||||
for _ in range(2):
|
||||
|
Reference in New Issue
Block a user