Mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-10 13:07:28 +08:00

Commit: comments + docs
@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
 
+### Quick Start
+
+To generate text with an LLM use:
+
+```bash
+mlx_lm.generate --prompt "Hi!"
+```
+
+To chat with an LLM use:
+
+```bash
+mlx_lm.chat
+```
+
+This will give you a chat REPL that you can use to interact with the LLM. The
+chat context is preserved during the lifetime of the REPL.
+
+Commands in `mlx-lm` typically take command line options which let you specify
+the model, sampling parameters, and more. Use `-h` to see a list of available
+options for a command, e.g.:
+
+```bash
+mlx_lm.generate -h
+```
+
 ### Python API
 
 You can use `mlx-lm` as a module:
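The `### Python API` section continues past the hunk boundary, so its example is not shown here. For context, a minimal sketch of module usage, assuming the `load` and `generate` helpers the package exports (the repo id is the `DEFAULT_MODEL` added elsewhere in this commit, not necessarily the README's own example):

```python
# Minimal sketch of using mlx-lm as a module; assumes the public
# `load`/`generate` helpers, with an example model id.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
text = generate(model, tokenizer, prompt="Hi!", verbose=True)
```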
@@ -138,7 +163,7 @@ mlx_lm.convert \
 
 ### Long Prompts and Generations
 
-MLX LM has some tools to scale efficiently to long prompts and generations:
+`mlx-lm` has some tools to scale efficiently to long prompts and generations:
 
 - A rotating fixed-size key-value cache.
 - Prompt caching
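Both tools in that list surface in this commit: `make_prompt_cache(model, max_kv_size)` builds the rotating cache used by the `chat.py` hunk below, and the `save_prompt_cache`/`load_prompt_cache` changes further down handle prompt caching on disk. A sketch of how the rotating cache keeps memory bounded across calls; the `prompt_cache` keyword on `generate` is an assumption about the surrounding API, not shown in this diff:

```python
# Sketch only: `make_prompt_cache` is taken from this commit's chat.py
# hunk; `prompt_cache=` on `generate` is assumed.
from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Fixed-size rotating KV cache: memory stays bounded no matter how long
# the generation runs.
prompt_cache = make_prompt_cache(model, 512)

for question in ("What is MLX?", "And how do I install it?"):
    # Reusing the cache carries conversation state between calls.
    generate(model, tokenizer, prompt=question, prompt_cache=prompt_cache)
```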
@@ -178,7 +203,7 @@ for more usage details.
 
 ### Supported Models
 
-MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
 run is not supported, file an
 [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
 submit a pull request.
@@ -56,7 +56,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
@@ -14,6 +14,7 @@ DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.6
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
 def str2bool(string):
@@ -27,6 +28,7 @@ def setup_arg_parser():
         "--model",
         type=str,
         help="The path to the local model directory or Hugging Face repo.",
+        default=DEFAULT_MODEL,
     )
     parser.add_argument(
         "--adapter-path",
@@ -44,12 +44,11 @@ def save_prompt_cache(
         metadata (Optional[Dict[str, str]]): Optional metadata to save along
             with model state.
     """
-    cache_data, cache_info = zip(*(c.state for c in cache))
+    cache_data = [c.state for c in cache]
+    cache_info = [c.meta_state if hasattr(c, "meta_state") else "" for c in cache]
     cache_data = dict(tree_flatten(cache_data))
     cache_classes = [type(c).__name__ for c in cache]
-    cache_metadata = [cache_classes, cache_info]
-    if metadata:
-        cache_metadata.append(metadata)
+    cache_metadata = [cache_classes, cache_info, metadata or ""]
     cache_metadata = dict(tree_flatten(cache_metadata))
     mx.save_safetensors(file_name, cache_data, cache_metadata)
 
@@ -70,12 +69,14 @@ def load_prompt_cache(file_name, return_metadata=False):
     arrays, cache_metadata = mx.load(file_name, return_metadata=True)
     arrays = tree_unflatten(list(arrays.items()))
     cache_metadata = tree_unflatten(list(cache_metadata.items()))
-    classes, info = cache_metadata[:2]
+    classes, info, metadata = cache_metadata
     cache = [globals()[c]() for c in classes]
-    for c, *state in zip(cache, arrays, info):
+    for c, state, meta_state in zip(cache, arrays, info):
         c.state = state
+        if hasattr(c, "meta_state"):
+            c.meta_state = meta_state
     if return_metadata:
-        return cache, cache_metadata[2]
+        return cache, metadata
     return cache
 
 
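Net effect of these two hunks: each cache's `state` is arrays only, `meta_state` (when present) is strings only, and the user-metadata slot is always written (defaulting to `""`), so `load_prompt_cache` can unpack all three fields unconditionally. A round-trip sketch; the model id and file name are examples, and the model's `__call__(inputs, cache=...)` signature is assumed from the surrounding code:

```python
# Round-trip sketch for the refactored prompt-cache serialization.
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Run the prompt through the model once so the cache holds its KV state.
cache = make_prompt_cache(model)
model(mx.array(tokenizer.encode("Hi!"))[None], cache=cache)

# Arrays land in the safetensors payload; all metadata, including the
# optional user dict, is stored as strings.
save_prompt_cache("prompt_cache.safetensors", cache, {"prompt": "Hi!"})

# return_metadata=True also returns the user-supplied dict.
cache, metadata = load_prompt_cache("prompt_cache.safetensors", return_metadata=True)
```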
@@ -114,16 +115,16 @@ class KVCache:
     @property
     def state(self):
         if self.offset == self.keys.shape[2]:
-            return (self.keys, self.values), ""
+            return self.keys, self.values
         else:
             return (
                 self.keys[..., : self.offset, :],
                 self.values[..., : self.offset, :],
-            ), ""
+            )
 
     @state.setter
     def state(self, v):
-        self.keys, self.values = v[0]
+        self.keys, self.values = v
         self.offset = self.keys.shape[2]
 
 
@@ -236,20 +237,25 @@ class RotatingKVCache:
     @property
     def state(self):
         if self.offset < self.keys.shape[2]:
-            kv_state = (self.keys[..., : self.offset], self.values[..., : self.offset])
+            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
         else:
-            kv_state = (self.keys, self.values)
-        extra_state = tuple(
-            map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
-        )
-        return kv_state, extra_state
+            return self.keys, self.values
 
     @state.setter
     def state(self, v):
-        self.keys, self.values = v[0]
+        self.keys, self.values = v
 
+    @property
+    def meta_state(self):
+        return tuple(
+            map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
+        )
+
+    @meta_state.setter
+    def meta_state(self, v):
         self.keep, self.max_size, self.step, self.offset, self._idx = map(
             int,
-            v[1],
+            v,
         )
 
 
@@ -267,10 +273,6 @@ class MambaCache:
     def state(self):
         return self.cache
 
-    @property
-    def state(self):
-        return self.cache, ""
-
     @state.setter
     def state(self, v):
-        self.cache = v[0]
+        self.cache = v
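Across `KVCache`, `RotatingKVCache`, and `MambaCache`, the commit settles on a single protocol: `state` carries only `mx.array`s (what `save_prompt_cache` serializes and the prefill loop below can `mx.eval` directly), while classes with extra bookkeeping expose it separately as `meta_state`, a tuple of strings, since safetensors metadata cannot store integers. A hypothetical minimal cache (not from the repo) written against that protocol:

```python
# Hypothetical cache class illustrating the state/meta_state split
# introduced by this commit.
import mlx.core as mx


class ToyKVCache:
    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0  # bookkeeping: an int, so it cannot live in `state`

    @property
    def state(self):
        # Arrays only: serialized into the safetensors payload.
        return self.keys, self.values

    @state.setter
    def state(self, v):
        self.keys, self.values = v

    @property
    def meta_state(self):
        # Strings only: stored in the safetensors metadata.
        return (str(self.offset),)

    @meta_state.setter
    def meta_state(self, v):
        (self.offset,) = map(int, v)
```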
@@ -240,7 +240,7 @@ def generate_step(
 
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=cache)
-        mx.eval([c.state[0] for c in cache])
+        mx.eval([c.state for c in cache])
         y = y[prefill_step_size:]
 
     y, logprobs = _step(y)
@@ -40,8 +40,8 @@ class TestPromptCache(unittest.TestCase):
         self.assertTrue(len(cache), len(loaded_cache))
         for c, lc in zip(cache, loaded_cache):
             self.assertEqual(c.offset, lc.offset)
-            self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
-            self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
+            self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
+            self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
 
         # Test with metadata
         cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
@@ -67,8 +67,8 @@ class TestPromptCache(unittest.TestCase):
             self.assertEqual(c.keep, lc.keep)
             self.assertEqual(c.max_size, lc.max_size)
             self.assertEqual(c.step, lc.step)
-            self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
-            self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
+            self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
+            self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
 
         # Do a couple single token updates to get a rotation
         for _ in range(2):