mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix test tolerance and patch bump (#1315)
This commit is contained in:
parent
eb8819e91e
commit
780c197f95
@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
|||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
if(NOT MLX_VERSION)
|
if(NOT MLX_VERSION)
|
||||||
set(MLX_VERSION 0.16.1)
|
set(MLX_VERSION 0.16.2)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
|
@ -357,9 +357,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
|
gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
|
||||||
gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
|
gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
|
||||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
self.assertLess(mx.abs(gx1 - gx2).max(), 5e-5)
|
||||||
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)
|
||||||
self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
|
self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 5e-5)
|
||||||
|
|
||||||
def gf(f):
|
def gf(f):
|
||||||
def inner(x, w, b, y):
|
def inner(x, w, b, y):
|
||||||
@ -370,8 +370,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)
|
gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)
|
||||||
gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)
|
gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)
|
||||||
self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 1e-5)
|
self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 5e-5)
|
||||||
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)
|
||||||
self.assertLess(mx.abs(gb1).max(), 1e-9)
|
self.assertLess(mx.abs(gb1).max(), 1e-9)
|
||||||
self.assertLess(mx.abs(gb2).max(), 1e-9)
|
self.assertLess(mx.abs(gb2).max(), 1e-9)
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@ -163,7 +163,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx",
|
name="mlx",
|
||||||
version=get_version("0.16.1"),
|
version=get_version("0.16.2"),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
|
Loading…
Reference in New Issue
Block a user