comments + docs

This commit is contained in:
Awni Hannun
2024-10-07 13:16:58 -07:00
parent 52ffc2f477
commit f6ff4f28b4
6 changed files with 60 additions and 31 deletions

View File

@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
### Quick Start
To generate text with an LLM use:
```bash
mlx_lm.generate --prompt "Hi!"
```
To chat with an LLM use:
```bash
mlx_lm.chat
```
This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.
Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:
```bash
mlx_lm.generate -h
```
### Python API
You can use `mlx-lm` as a module:
@@ -138,7 +163,7 @@ mlx_lm.convert \
### Long Prompts and Generations
MLX LM has some tools to scale efficiently to long prompts and generations:
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
@@ -178,7 +203,7 @@ for more usage details.
### Supported Models
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.

View File

@@ -56,7 +56,7 @@ def main():
tokenizer_config={"trust_remote_code": True},
)
print(f"Starting chat session with {args.model}. To exit, enter 'q'.")
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")

View File

@@ -14,6 +14,7 @@ DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def str2bool(string):
@@ -27,6 +28,7 @@ def setup_arg_parser():
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
default=DEFAULT_MODEL,
)
parser.add_argument(
"--adapter-path",

View File

@@ -44,12 +44,11 @@ def save_prompt_cache(
metadata (Optional[Dict[str, str]]): Optional metadata to save along
with model state.
"""
cache_data, cache_info = zip(*(c.state for c in cache))
cache_data = [c.state for c in cache]
cache_info = [c.meta_state if hasattr(c, "meta_state") else "" for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_classes, cache_info]
if metadata:
cache_metadata.append(metadata)
cache_metadata = [cache_classes, cache_info, metadata or ""]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
@@ -70,12 +69,14 @@ def load_prompt_cache(file_name, return_metadata=False):
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
classes, info = cache_metadata[:2]
classes, info, metadata = cache_metadata
cache = [globals()[c]() for c in classes]
for c, *state in zip(cache, arrays, info):
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
if hasattr(c, "meta_state"):
c.meta_state = meta_state
if return_metadata:
return cache, cache_metadata[2]
return cache, metadata
return cache
@@ -114,16 +115,16 @@ class KVCache:
@property
def state(self):
if self.offset == self.keys.shape[2]:
return (self.keys, self.values), ""
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
), ""
)
@state.setter
def state(self, v):
self.keys, self.values = v[0]
self.keys, self.values = v
self.offset = self.keys.shape[2]
@@ -236,20 +237,25 @@ class RotatingKVCache:
@property
def state(self):
if self.offset < self.keys.shape[2]:
kv_state = (self.keys[..., : self.offset], self.values[..., : self.offset])
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
kv_state = (self.keys, self.values)
extra_state = tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
return kv_state, extra_state
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v[0]
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v[1],
v,
)
@@ -267,10 +273,6 @@ class MambaCache:
def state(self):
return self.cache
@property
def state(self):
return self.cache, ""
@state.setter
def state(self, v):
self.cache = v[0]
self.cache = v

View File

@@ -240,7 +240,7 @@ def generate_step(
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
mx.eval([c.state[0] for c in cache])
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
y, logprobs = _step(y)

View File

@@ -40,8 +40,8 @@ class TestPromptCache(unittest.TestCase):
self.assertTrue(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Test with metadata
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
@@ -67,8 +67,8 @@ class TestPromptCache(unittest.TestCase):
self.assertEqual(c.keep, lc.keep)
self.assertEqual(c.max_size, lc.max_size)
self.assertEqual(c.step, lc.step)
self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Do a couple single token updates to get a rotation
for _ in range(2):