mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 12:26:07 +08:00
Remove unnecessary changes
This commit is contained in:
parent
72269c306c
commit
ebea6928a3
@ -26,10 +26,7 @@ def make_prompt_cache(
|
|||||||
if hasattr(model, "make_cache"):
|
if hasattr(model, "make_cache"):
|
||||||
return model.make_cache()
|
return model.make_cache()
|
||||||
|
|
||||||
if hasattr(model, "layers"):
|
|
||||||
num_layers = len(model.layers)
|
num_layers = len(model.layers)
|
||||||
else:
|
|
||||||
num_layers = len(model.model.layers)
|
|
||||||
if max_kv_size is not None:
|
if max_kv_size is not None:
|
||||||
return [
|
return [
|
||||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||||
|
@ -19,6 +19,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
List,
|
List,
|
||||||
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
@ -43,6 +44,7 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models import cache
|
from .models import cache
|
||||||
|
from .sample_utils import make_logits_processors, make_sampler
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
from .tuner.utils import load_adapters, nparams
|
from .tuner.utils import load_adapters, nparams
|
||||||
@ -1048,7 +1050,6 @@ def convert(
|
|||||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||||
|
|
||||||
weights = dict(tree_flatten(model.parameters()))
|
weights = dict(tree_flatten(model.parameters()))
|
||||||
|
|
||||||
dtype = getattr(mx, dtype)
|
dtype = getattr(mx, dtype)
|
||||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user