feat: add repetition_penalty

This commit is contained in:
Anchen
2023-12-24 22:38:50 +11:00
committed by Awni Hannun
parent bd63a3e5ee
commit a476d1909d

View File

@@ -214,7 +214,23 @@ class DeepseekCoder(nn.Module):
return self.output(x), cache
def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
def apply_repeat_penalty(logits, context, penalty):
if len(context) > 0:
indices = mx.array([token.item() for token in context])
selected_logists = logits[:, indices]
selected_logists = mx.where(
selected_logists < 0, selected_logists * penalty, selected_logists / penalty
)
logits[:, indices] = selected_logists
def generate(
prompt: mx.array,
model: DeepseekCoder,
temp: 0.0,
generated_tokens,
repetition_penalty,
):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
@@ -227,7 +243,10 @@ def generate(prompt: mx.array, model: DeepseekCoder, temp: 0.0):
while True:
logits, cache = model(y[:, None], cache=cache)
y = sample(logits.squeeze(1))
logits = logits.squeeze(1)
if repetition_penalty is not None and repetition_penalty != 1.0:
apply_repeat_penalty(logits, generated_tokens, repetition_penalty)
y = sample(logits)
yield y
@@ -284,6 +303,14 @@ if __name__ == "__main__":
type=float,
default=0.6,
)
parser.add_argument(
"--repetition-penalty",
help="The parameter for repetition penalty.",
type=float,
default=1.2,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
@@ -302,7 +329,10 @@ if __name__ == "__main__":
print(args.prompt, end="", flush=True)
tokens = []
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
for token, _ in zip(
generate(prompt, model, args.temp, tokens, args.repetition_penalty),
range(args.max_tokens),
):
tokens.append(token)
if (len(tokens) % 10) == 0: