mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-12 15:25:00 +08:00
feat: add repetition_penalty
This commit is contained in:
@@ -214,7 +214,23 @@ class DeepseekCoder(nn.Module):
|
|||||||
return self.output(x), cache
|
return self.output(x), cache
|
||||||
|
|
||||||
|
|
||||||
def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
|
def apply_repeat_penalty(logits, context, penalty):
|
||||||
|
if len(context) > 0:
|
||||||
|
indices = mx.array([token.item() for token in context])
|
||||||
|
selected_logists = logits[:, indices]
|
||||||
|
selected_logists = mx.where(
|
||||||
|
selected_logists < 0, selected_logists * penalty, selected_logists / penalty
|
||||||
|
)
|
||||||
|
logits[:, indices] = selected_logists
|
||||||
|
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
prompt: mx.array,
|
||||||
|
model: DeepseekCoder,
|
||||||
|
temp: 0.0,
|
||||||
|
generated_tokens,
|
||||||
|
repetition_penalty,
|
||||||
|
):
|
||||||
def sample(logits):
|
def sample(logits):
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
return mx.argmax(logits, axis=-1)
|
return mx.argmax(logits, axis=-1)
|
||||||
@@ -227,7 +243,10 @@ def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
logits, cache = model(y[:, None], cache=cache)
|
logits, cache = model(y[:, None], cache=cache)
|
||||||
y = sample(logits.squeeze(1))
|
logits = logits.squeeze(1)
|
||||||
|
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||||
|
apply_repeat_penalty(logits, generated_tokens, repetition_penalty)
|
||||||
|
y = sample(logits)
|
||||||
yield y
|
yield y
|
||||||
|
|
||||||
|
|
||||||
@@ -284,6 +303,14 @@ if __name__ == "__main__":
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.6,
|
default=0.6,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--repetition-penalty",
|
||||||
|
help="The parameter for repetition penalty.",
|
||||||
|
type=float,
|
||||||
|
default=1.2,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -302,7 +329,10 @@ if __name__ == "__main__":
|
|||||||
print(args.prompt, end="", flush=True)
|
print(args.prompt, end="", flush=True)
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
for token, _ in zip(
|
||||||
|
generate(prompt, model, args.temp, tokens, args.repetition_penalty),
|
||||||
|
range(args.max_tokens),
|
||||||
|
):
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if (len(tokens) % 10) == 0:
|
if (len(tokens) % 10) == 0:
|
||||||
|
Reference in New Issue
Block a user