mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
feat: add repetition_penalty
This commit is contained in:
@@ -214,7 +214,23 @@ class DeepseekCoder(nn.Module):
|
||||
return self.output(x), cache
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
|
||||
def apply_repeat_penalty(logits, context, penalty):
|
||||
if len(context) > 0:
|
||||
indices = mx.array([token.item() for token in context])
|
||||
selected_logists = logits[:, indices]
|
||||
selected_logists = mx.where(
|
||||
selected_logists < 0, selected_logists * penalty, selected_logists / penalty
|
||||
)
|
||||
logits[:, indices] = selected_logists
|
||||
|
||||
|
||||
def generate(
|
||||
prompt: mx.array,
|
||||
model: DeepseekCoder,
|
||||
temp: 0.0,
|
||||
generated_tokens,
|
||||
repetition_penalty,
|
||||
):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
@@ -227,7 +243,10 @@ def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache=cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
logits = logits.squeeze(1)
|
||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||
apply_repeat_penalty(logits, generated_tokens, repetition_penalty)
|
||||
y = sample(logits)
|
||||
yield y
|
||||
|
||||
|
||||
@@ -284,6 +303,14 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=0.6,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repetition-penalty",
|
||||
help="The parameter for repetition penalty.",
|
||||
type=float,
|
||||
default=1.2,
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -302,7 +329,10 @@ if __name__ == "__main__":
|
||||
print(args.prompt, end="", flush=True)
|
||||
|
||||
tokens = []
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
for token, _ in zip(
|
||||
generate(prompt, model, args.temp, tokens, args.repetition_penalty),
|
||||
range(args.max_tokens),
|
||||
):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
|
Reference in New Issue
Block a user