diff --git a/llms/mlx_lm/tuner/new_tokens.py b/llms/mlx_lm/tuner/new_tokens.py
index 981d4a56..dc735540 100644
--- a/llms/mlx_lm/tuner/new_tokens.py
+++ b/llms/mlx_lm/tuner/new_tokens.py
@@ -5,20 +5,61 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
 def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
     """
-    Resizes model embeddings to accommodate new tokens
+    Resizes model embeddings to accommodate new tokens, minimizing dequantization.
     """
     old_embedding = model.model.embed_tokens
     old_vocab_size = old_embedding.num_embeddings
     new_vocab_size = len(tokenizer._tokenizer)
-    if old_vocab_size != new_vocab_size:
-        if new_vocab_size < old_vocab_size:
-            print(
-                "Warning: New vocab size is smaller than original. Proceeding with trim."
+    if old_vocab_size == new_vocab_size:
+        print("Vocab already sized right.")
+        return model
+
+    if new_vocab_size < old_vocab_size:
+        print("Warning: New vocab size is smaller than original. Proceeding with trim.")
+
+    if (
+        hasattr(old_embedding, "weight")
+        and hasattr(old_embedding, "scales")
+        and hasattr(old_embedding, "biases")
+        and hasattr(old_embedding, "group_size")
+        and hasattr(old_embedding, "bits")
+    ):
+        # quantized embedding case: minimize dequantization
+
+        new_embedding = nn.QuantizedEmbedding(
+            new_vocab_size,
+            old_embedding.dims,
+            group_size=old_embedding.group_size,
+            bits=old_embedding.bits,
+        )
+        if new_vocab_size > old_vocab_size:
+            # Add new rows
+            new_row_count = new_vocab_size - old_vocab_size
+            new_rows = mx.random.normal((new_row_count, old_embedding.dims), scale=0.02)
+            new_rows_q, new_rows_scales, new_rows_biases = mx.quantize(
+                new_rows, old_embedding.group_size, old_embedding.bits
             )
-            # check if QuantizedEmbedding has required attributes for dequantization
+            new_embedding.weight = mx.concatenate(
+                [old_embedding.weight, new_rows_q], axis=0
+            )
+            new_embedding.scales = mx.concatenate(
+                [old_embedding.scales, new_rows_scales], axis=0
+            )
+            new_embedding.biases = mx.concatenate(
+                [old_embedding.biases, new_rows_biases], axis=0
+            )
+
+        else:  # new_vocab_size < old_vocab_size: Slice existing
+            new_embedding.weight = old_embedding.weight[:new_vocab_size]
+            new_embedding.scales = old_embedding.scales[:new_vocab_size]
+            new_embedding.biases = old_embedding.biases[:new_vocab_size]
+
+    else:
+        # non-quantized embedding case (fallback, less efficient)
+        # dequantize ONLY if necessary
+        # should ideally be avoided entirely for quantized models.
         try:
             dequantized_weights = mx.dequantize(
                 old_embedding.weight,
@@ -27,14 +68,13 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Modul
                 group_size=old_embedding.group_size,
                 bits=old_embedding.bits,
             )
-        except AttributeError as e:
-            print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}")
+        # handle missing quantization attributes
+        except (AttributeError, TypeError):
             print("Falling back to random weights for embed_tokens.")
             dequantized_weights = mx.random.normal(
                 (old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
             )
-        # resize embed_tokens
         new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
         new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
         min_vocab_size = min(old_vocab_size, new_vocab_size)
@@ -46,88 +86,81 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Modul
                 scale=0.02,
             )
         new_embedding.weight = new_weights
-        model.model.embed_tokens = new_embedding
-        # attention layers handling
-        if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
-            model.model.embed_tokens.weight = new_weights
-        elif hasattr(model, "lm_head"):
-            old_lm_head = model.lm_head
-            if isinstance(old_lm_head, nn.QuantizedLinear):
-                # resize nn.QuantizedLinear
-                output_dims, compressed_input_dims = old_lm_head.weight.shape
-                bits = old_lm_head.bits
-                input_dims = compressed_input_dims * (32 // bits)
+    model.model.embed_tokens = new_embedding
 
-                # dequantize lm_head weights
-                try:
-                    dequantized_lm_weights = mx.dequantize(
-                        old_lm_head.weight,
-                        scales=old_lm_head.scales,
-                        biases=old_lm_head.biases,
-                        group_size=old_lm_head.group_size,
-                        bits=old_lm_head.bits,
-                    )
-                except AttributeError as e:
-                    print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}")
-                    print("Falling back to random weights for lm_head.")
-                    dequantized_lm_weights = mx.random.normal(
-                        (output_dims, input_dims), loc=0.0, scale=0.02
-                    )
+    # handle lm_head
+    if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
+        if hasattr(new_embedding, "weight") and not isinstance(
+            new_embedding, nn.QuantizedEmbedding
+        ):
+            model.model.embed_tokens.weight = new_embedding.weight
 
-                new_lm_head = nn.QuantizedLinear(
-                    input_dims=input_dims,
-                    output_dims=new_vocab_size,
-                    bias="bias" in old_lm_head,
-                    group_size=old_lm_head.group_size,
-                    bits=old_lm_head.bits,
+    elif hasattr(model, "lm_head"):
+        old_lm_head = model.lm_head
+        if isinstance(old_lm_head, nn.QuantizedLinear):
+            output_dims, compressed_input_dims = old_lm_head.weight.shape
+            bits = old_lm_head.bits
+            input_dims = compressed_input_dims * (32 // bits)
+            group_size = old_lm_head.group_size
+
+            new_lm_head = nn.QuantizedLinear(
+                input_dims=input_dims,
+                output_dims=new_vocab_size,
+                bias="bias" in old_lm_head,
+                group_size=group_size,
+                bits=bits,
+            )
+
+            if new_vocab_size > old_vocab_size:
+                new_row_count = new_vocab_size - old_vocab_size
+                new_rows = mx.random.normal((new_row_count, input_dims), scale=0.02)
+                new_rows_q, new_rows_scales, new_rows_biases = mx.quantize(
+                    new_rows, group_size, bits
                 )
-                new_weights_lm = mx.zeros((new_vocab_size, input_dims))
-                new_weights_lm[:min_vocab_size] = dequantized_lm_weights[
-                    :min_vocab_size
-                ]
-                if new_vocab_size > output_dims:
-                    new_weights_lm[output_dims:] = mx.random.normal(
-                        (new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02
-                    )
-                new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = (
-                    mx.quantize(
-                        new_weights_lm, new_lm_head.group_size, new_lm_head.bits
-                    )
+                new_lm_head.weight = mx.concatenate(
+                    [old_lm_head.weight, new_rows_q], axis=0
+                )
+                new_lm_head.scales = mx.concatenate(
+                    [old_lm_head.scales, new_rows_scales], axis=0
+                )
+                new_lm_head.biases = mx.concatenate(
+                    [old_lm_head.biases, new_rows_biases], axis=0
                 )
-                if "bias" in old_lm_head:
-                    new_lm_head.bias = mx.zeros((new_vocab_size,))
-                    new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
-                        :min_vocab_size
-                    ]
             else:
-                # resize nn.Linear
-                new_lm_head = nn.Linear(
-                    old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
-                )
-                new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
-                min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size)
-                new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
-                if new_vocab_size > old_lm_head.weight.shape[0]:
-                    new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal(
-                        (
-                            new_vocab_size - old_lm_head.weight.shape[0],
-                            old_lm_head.input_dims,
-                        ),
-                        loc=0.0,
-                        scale=0.02,
-                    )
-                new_lm_head.weight = new_weights_lm
-                # todo typechecking
-                if "bias" in old_lm_head:
-                    new_lm_head.bias = mx.zeros((new_vocab_size,))
-                    new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[
-                        :min_vocab_size
-                    ]
+                new_lm_head.weight = old_lm_head.weight[:new_vocab_size]
+                new_lm_head.scales = old_lm_head.scales[:new_vocab_size]
+                new_lm_head.biases = old_lm_head.biases[:new_vocab_size]
+
+            if "bias" in old_lm_head:
+                if new_vocab_size > old_vocab_size:
+                    new_bias = mx.concatenate(
+                        [old_lm_head.bias, mx.zeros(new_vocab_size - old_vocab_size)]
+                    )
+                else:
+                    new_bias = old_lm_head.bias[:new_vocab_size]
+                new_lm_head.bias = new_bias
+        # nn.Linear case
+        else:
+            new_lm_head = nn.Linear(
+                old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
+            )
+            new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
+            min_vocab_size = min(old_vocab_size, new_vocab_size)
+            new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
+            if new_vocab_size > old_vocab_size:
+                new_weights_lm[old_vocab_size:] = mx.random.normal(
+                    (new_vocab_size - old_vocab_size, old_lm_head.input_dims),
+                    loc=0.0,
+                    scale=0.02,
+                )
+            new_lm_head.weight = new_weights_lm
+            if "bias" in old_lm_head:
+                new_lm_head.bias = mx.zeros((new_vocab_size,))
+                new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[:min_vocab_size]
+
+        model.lm_head = new_lm_head
 
-            model.lm_head = new_lm_head
-    else:
-        print("Vocab already sized right.")
     return model
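
Usage sketch (an assumption, not part of this diff): resize_embeddings would typically be called after growing the wrapped Hugging Face tokenizer's vocabulary and before fine-tuning. The checkpoint name and added tokens below are placeholders; load and add_tokens are the standard mlx_lm / Hugging Face APIs, and the module path mlx_lm.tuner.new_tokens mirrors the file added by this diff.

    from mlx_lm import load
    from mlx_lm.tuner.new_tokens import resize_embeddings

    # hypothetical example: any quantized checkpoint follows the same flow
    model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

    # TokenizerWrapper keeps the underlying Hugging Face tokenizer in `_tokenizer`,
    # which is what resize_embeddings measures via len(tokenizer._tokenizer)
    tokenizer._tokenizer.add_tokens(["<|tool_call|>", "<|tool_response|>"])

    # grows embed_tokens (and lm_head, if present) to the new vocab size
    model = resize_embeddings(model, tokenizer)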