Compare commits

...

2 Commits

Author SHA1 Message Date
Awni Hannun
ad16f41a7f Fix version tag (#2790)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-19 08:55:57 -08:00
Awni Hannun
f46877bc08 more accurate rope fallback (#2792) 2025-11-19 06:07:21 -08:00
3 changed files with 44 additions and 21 deletions

View File

@@ -416,23 +416,25 @@ array rope(
if (offset.size() > 1) { if (offset.size() > 1) {
offset = expand_dims(offset, {-1, -2}, s); offset = expand_dims(offset, {-1, -2}, s);
} }
auto positions = auto positions = multiply(
multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s); add(arange(x.shape(2), float32, s), offset, s),
array(scale, float32),
s);
auto default_inv_freqs = [&s, &t, base, half_dims]() { auto default_inv_freqs = [&s, base, half_dims]() {
return exp( return exp(
multiply( multiply(
arange(0, -half_dims, -1, t, s), arange(0, -half_dims, -1, float32, s),
array(std::log(base) / half_dims, t), array(std::log(base) / half_dims, float32),
s), s),
s); s);
}; };
auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s) auto inv_freqs =
: default_inv_freqs(); inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs();
auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s); auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s);
auto coss = cos(theta, s); auto coss = astype(cos(theta, s), t, s);
auto sins = sin(theta, s); auto sins = astype(sin(theta, s), t, s);
auto apply_rope = [forward, s]( auto apply_rope = [forward, s](
const array& x1, const array& x1,

View File

@@ -332,6 +332,26 @@ class TestFast(mlx_tests.MLXTestCase):
rx = rope_orig(x, dims, traditional, base, scale, offset) rx = rope_orig(x, dims, traditional, base, scale, offset)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5) self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5)
def test_rope_with_large_offset(self):
    """Compare float32 and bfloat16 ``mx.fast.rope`` outputs at offset 4000.

    The same random input of shape (1, 1, 1024, 32) is rotated once as
    generated and once after casting to bfloat16; the largest absolute
    difference between the two results must be below 1e-1.
    """
    # Identical settings for both calls so only the input dtype differs.
    rope_kwargs = dict(
        traditional=False,
        scale=1.0,
        base=10000,
        offset=4000,
    )
    rope_dims = 32

    x = mx.random.normal(shape=(1, 1, 1024, rope_dims))
    rx_full = mx.fast.rope(x, rope_dims, **rope_kwargs)
    rx_low = mx.fast.rope(x.astype(mx.bfloat16), rope_dims, **rope_kwargs)

    worst_error = (rx_full - rx_low).abs().max()
    self.assertLess(worst_error, 1e-1)
def test_rms_norm(self): def test_rms_norm(self):
# Per dtype absolute tolerance # Per dtype absolute tolerance
tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2} tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

View File

@@ -24,21 +24,22 @@ def get_version():
if "#define MLX_VERSION_PATCH" in l: if "#define MLX_VERSION_PATCH" in l:
patch = l.split()[-1] patch = l.split()[-1]
version = f"{major}.{minor}.{patch}" version = f"{major}.{minor}.{patch}"
if os.environ.get("PYPI_RELEASE", False): pypi_release = os.environ.get("PYPI_RELEASE", False)
dev_release = os.environ.get("DEV_RELEASE", False)
if not pypi_release or dev_release:
today = datetime.date.today() today = datetime.date.today()
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release:
if os.environ.get("DEV_RELEASE", False): git_hash = (
git_hash = ( run(
run( "git rev-parse --short HEAD".split(),
"git rev-parse --short HEAD".split(), capture_output=True,
capture_output=True, check=True,
check=True,
)
.stdout.strip()
.decode()
) )
version = f"{version}+{git_hash}" .stdout.strip()
.decode()
)
version = f"{version}+{git_hash}"
return version return version