Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 06:54:39 +08:00)
fixed full dequantization mem leak
This commit is contained in:
parent 5b7581f41c, commit 231f5e870e
@@ -5,20 +5,61 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
 def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
     """
-    Resizes model embeddings to accommodate new tokens
+    Resizes model embeddings to accommodate new tokens, minimizing dequantization.
     """
     old_embedding = model.model.embed_tokens
 
     old_vocab_size = old_embedding.num_embeddings
     new_vocab_size = len(tokenizer._tokenizer)
 
-    if old_vocab_size != new_vocab_size:
-        if new_vocab_size < old_vocab_size:
-            print(
-                "Warning: New vocab size is smaller than original. Proceeding with trim."
-            )
+    if old_vocab_size == new_vocab_size:
+        print("Vocab already sized right.")
+        return model
+
+    if new_vocab_size < old_vocab_size:
+        print("Warning: New vocab size is smaller than original. Proceeding with trim.")
+
+    if (
+        hasattr(old_embedding, "weight")
+        and hasattr(old_embedding, "scales")
+        and hasattr(old_embedding, "biases")
+        and hasattr(old_embedding, "group_size")
+        and hasattr(old_embedding, "bits")
+    ):
+        # quantized embedding case: minimize dequantization
+        new_embedding = nn.QuantizedEmbedding(
+            new_vocab_size,
+            old_embedding.dims,
+            group_size=old_embedding.group_size,
+            bits=old_embedding.bits,
+        )
+        if new_vocab_size > old_vocab_size:
+            # Add new rows
+            new_row_count = new_vocab_size - old_vocab_size
+            new_rows = mx.random.normal((new_row_count, old_embedding.dims), scale=0.02)
+            new_rows_q, new_rows_scales, new_rows_biases = mx.quantize(
+                new_rows, old_embedding.group_size, old_embedding.bits
+            )
+
+            # append quantized new rows without dequantizing the existing table
+            new_embedding.weight = mx.concatenate(
+                [old_embedding.weight, new_rows_q], axis=0
+            )
+            new_embedding.scales = mx.concatenate(
+                [old_embedding.scales, new_rows_scales], axis=0
+            )
+            new_embedding.biases = mx.concatenate(
+                [old_embedding.biases, new_rows_biases], axis=0
+            )
+        else:  # new_vocab_size < old_vocab_size: Slice existing
+            new_embedding.weight = old_embedding.weight[:new_vocab_size]
+            new_embedding.scales = old_embedding.scales[:new_vocab_size]
+            new_embedding.biases = old_embedding.biases[:new_vocab_size]
+    else:
+        # non-quantized embedding case (fallback, less efficient)
+        # dequantize ONLY if necessary
+        # should ideally be avoided entirely for quantized models.
         try:
             dequantized_weights = mx.dequantize(
                 old_embedding.weight,
@@ -27,14 +68,13 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
                 group_size=old_embedding.group_size,
                 bits=old_embedding.bits,
             )
-        except AttributeError as e:
-            print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}")
+        # handle missing quantization attributes
+        except (AttributeError, TypeError):
             print("Falling back to random weights for embed_tokens.")
             dequantized_weights = mx.random.normal(
                 (old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
             )
 
         # resize embed_tokens
         new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
         new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
         min_vocab_size = min(old_vocab_size, new_vocab_size)
@@ -46,88 +86,81 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
                 scale=0.02,
             )
         new_embedding.weight = new_weights
-        model.model.embed_tokens = new_embedding
-
-        # attention layers handling
-        if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
-            model.model.embed_tokens.weight = new_weights
-        elif hasattr(model, "lm_head"):
-            old_lm_head = model.lm_head
-            if isinstance(old_lm_head, nn.QuantizedLinear):
-                # resize nn.QuantizedLinear
-                output_dims, compressed_input_dims = old_lm_head.weight.shape
-                bits = old_lm_head.bits
-                input_dims = compressed_input_dims * (32 // bits)
-
-                # dequantize lm_head weights
-                try:
-                    dequantized_lm_weights = mx.dequantize(
-                        old_lm_head.weight,
-                        scales=old_lm_head.scales,
-                        biases=old_lm_head.biases,
-                        group_size=old_lm_head.group_size,
-                        bits=old_lm_head.bits,
-                    )
-                except AttributeError as e:
-                    print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}")
-                    print("Falling back to random weights for lm_head.")
-                    dequantized_lm_weights = mx.random.normal(
-                        (output_dims, input_dims), loc=0.0, scale=0.02
-                    )
-
-                new_lm_head = nn.QuantizedLinear(
-                    input_dims=input_dims,
-                    output_dims=new_vocab_size,
-                    bias="bias" in old_lm_head,
-                    group_size=old_lm_head.group_size,
-                    bits=old_lm_head.bits,
-                )
-
-                new_weights_lm = mx.zeros((new_vocab_size, input_dims))
-                new_weights_lm[:min_vocab_size] = dequantized_lm_weights[
-                    :min_vocab_size
-                ]
-                if new_vocab_size > output_dims:
-                    new_weights_lm[output_dims:] = mx.random.normal(
-                        (new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02
-                    )
-                new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = (
-                    mx.quantize(
-                        new_weights_lm, new_lm_head.group_size, new_lm_head.bits
-                    )
-                )
-                if "bias" in old_lm_head:
-                    new_lm_head.bias = mx.zeros((new_vocab_size,))
-                    new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
-                        :min_vocab_size
-                    ]
-            else:
-                # resize nn.Linear
-                new_lm_head = nn.Linear(
-                    old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
-                )
-                new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
-                min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size)
-                new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
-                if new_vocab_size > old_lm_head.weight.shape[0]:
-                    new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal(
-                        (
-                            new_vocab_size - old_lm_head.weight.shape[0],
-                            old_lm_head.input_dims,
-                        ),
-                        loc=0.0,
-                        scale=0.02,
-                    )
-                new_lm_head.weight = new_weights_lm
-                # todo typechecking
-                if "bias" in old_lm_head:
-                    new_lm_head.bias = mx.zeros((new_vocab_size,))
-                    new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
-                        :min_vocab_size
-                    ]
-
-            model.lm_head = new_lm_head
-
-    else:
-        print("Vocab already sized right.")
+    model.model.embed_tokens = new_embedding
+
+    # handle lm_head
+    if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
+        if hasattr(new_embedding, "weight") and not isinstance(
+            new_embedding, nn.QuantizedEmbedding
+        ):
+            model.model.embed_tokens.weight = new_embedding.weight
+    elif hasattr(model, "lm_head"):
+        old_lm_head = model.lm_head
+        if isinstance(old_lm_head, nn.QuantizedLinear):
+            output_dims, compressed_input_dims = old_lm_head.weight.shape
+            bits = old_lm_head.bits
+            input_dims = compressed_input_dims * (32 // bits)
+            group_size = old_lm_head.group_size
+
+            new_lm_head = nn.QuantizedLinear(
+                input_dims=input_dims,
+                output_dims=new_vocab_size,
+                bias="bias" in old_lm_head,
+                group_size=group_size,
+                bits=bits,
+            )
+
+            if new_vocab_size > old_vocab_size:
+                new_row_count = new_vocab_size - old_vocab_size
+                new_rows = mx.random.normal((new_row_count, input_dims), scale=0.02)
+                new_rows_q, new_rows_scales, new_rows_biases = mx.quantize(
+                    new_rows, group_size, bits
+                )
+                new_lm_head.weight = mx.concatenate(
+                    [old_lm_head.weight, new_rows_q], axis=0
+                )
+                new_lm_head.scales = mx.concatenate(
+                    [old_lm_head.scales, new_rows_scales], axis=0
+                )
+                new_lm_head.biases = mx.concatenate(
+                    [old_lm_head.biases, new_rows_biases], axis=0
+                )
+            else:
+                new_lm_head.weight = old_lm_head.weight[:new_vocab_size]
+                new_lm_head.scales = old_lm_head.scales[:new_vocab_size]
+                new_lm_head.biases = old_lm_head.biases[:new_vocab_size]
+
+            if "bias" in old_lm_head:
+                if new_vocab_size > old_vocab_size:
+                    new_bias = mx.concatenate(
+                        [old_lm_head.bias, mx.zeros(new_vocab_size - old_vocab_size)]
+                    )
+                else:
+                    new_bias = old_lm_head.bias[:new_vocab_size]
+                new_lm_head.bias = new_bias
+        # nn.Linear case
+        else:
+            new_lm_head = nn.Linear(
+                old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
+            )
+            new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
+            min_vocab_size = min(old_vocab_size, new_vocab_size)
+            new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
+            if new_vocab_size > old_vocab_size:
+                new_weights_lm[old_vocab_size:] = mx.random.normal(
+                    (new_vocab_size - old_vocab_size, old_lm_head.input_dims),
+                    loc=0.0,
+                    scale=0.02,
+                )
+            new_lm_head.weight = new_weights_lm
+            if "bias" in old_lm_head:
+                new_lm_head.bias = mx.zeros((new_vocab_size,))
+                new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[:min_vocab_size]
+
+        model.lm_head = new_lm_head
 
     return model
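The quantized paths added above rely on the fact that MLX quantizes a matrix in fixed-size groups along its last axis, so every row is quantized independently of every other row. Appending freshly quantized rows (or slicing rows off) therefore never requires touching the existing quantized data, which is what the "full dequantization" the commit message mentions previously forced. A minimal sketch of that property; the shapes, group_size, and bits are arbitrary values chosen only for illustration:

import mlx.core as mx

group_size, bits = 64, 4
old_rows = mx.random.normal((8, 128))
new_rows = mx.random.normal((2, 128), scale=0.02)

# quantize the two row blocks separately, as the patch does for new tokens
old_q, old_s, old_b = mx.quantize(old_rows, group_size, bits)
new_q, new_s, new_b = mx.quantize(new_rows, group_size, bits)

# and also quantize the full matrix in one shot
full_q, full_s, full_b = mx.quantize(
    mx.concatenate([old_rows, new_rows], axis=0), group_size, bits
)

# concatenating weight/scales/biases row-wise dequantizes to the same values
# as quantizing the concatenated matrix directly
cat = mx.dequantize(
    mx.concatenate([old_q, new_q], axis=0),
    scales=mx.concatenate([old_s, new_s], axis=0),
    biases=mx.concatenate([old_b, new_b], axis=0),
    group_size=group_size,
    bits=bits,
)
full = mx.dequantize(full_q, scales=full_s, biases=full_b, group_size=group_size, bits=bits)
print(mx.allclose(cat, full))  # True

This is why the patch can grow embed_tokens and lm_head by concatenating weight, scales, and biases along axis 0 instead of materializing the full dequantized matrices in memory.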
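For context, a sketch of how the updated helper might be driven when new tokens are added to a quantized model's tokenizer. Only the resize_embeddings signature and the tokenizer._tokenizer access come from the diff; the helper's import path and the model repo are placeholders, and load is the standard loader from mlx-lm:

from mlx_lm import load

# hypothetical import path; adjust to wherever resize_embeddings lives in your tree
from resize import resize_embeddings

model, tokenizer = load("mlx-community/SomeModel-4bit")  # placeholder repo

# add new special tokens to the underlying Hugging Face tokenizer
tokenizer._tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|tool_call|>", "<|tool_result|>"]}
)

# grow embed_tokens (and lm_head, if present) to the new vocab size
model = resize_embeddings(model, tokenizer)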