fix deepseek v2

This commit is contained in:
Awni Hannun
2024-08-19 16:09:17 -07:00
parent a3431ccc25
commit 3f8c1aca20

View File

@@ -68,11 +68,11 @@ def yarn_get_mscale(scale=1, mscale=1):
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(mn, mx, dim):
if mn == mx:
mx += 0.001 # Prevent singularity
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - mn) / (mx - mn)
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)