add f8_e4m3 loading (#1859)

This commit is contained in:
Alex Barron
2025-02-13 17:10:03 -08:00
committed by GitHub
parent 428f589364
commit 7f2d1024f3
2 changed files with 118 additions and 0 deletions

View File

@@ -128,6 +128,30 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.array_equal(load_dict["test"], save_dict["test"])
)
def test_load_f8_e4m3(self):
    """Load a hand-built safetensors blob holding ten F8_E4M3 values.

    The blob is assembled in memory (8-byte little-endian header length,
    JSON header, raw tensor bytes) and written to a temporary
    ``.safetensors`` file. Every decoded element is then compared
    against the expected values, with NaN treated as equal to NaN.
    """
    # Kept for parity with the sibling tests in this class; this test itself
    # writes to a NamedTemporaryFile rather than into self.test_dir.
    if not os.path.isdir(self.test_dir):
        os.mkdir(self.test_dir)

    # One expected value per payload byte below, in order.
    expected = mx.array(
        [
            0,  # 0x00
            mx.nan,  # 0x7f
            mx.nan,  # 0xff
            -0.875,  # 0xb6
            0.4375,  # 0x2e (".")
            -0.005859,  # 0x83
            -1.25,  # 0xba
            -1.25,  # 0xba
            -1.5,  # 0xbc
            -0.0039,  # 0x82
        ],
        dtype=mx.bfloat16,
    )

    payload = b"\x00\x7f\xff\xb6.\x83\xba\xba\xbc\x82"
    header = b'{"tensor":{"dtype":"F8_E4M3","shape":[10],"data_offsets":[0,10]}}'
    # Pad the JSON header with spaces up to a multiple of 8 bytes and derive
    # the length prefix from the padded header, so the declared size can
    # never disagree with the bytes actually written.
    padded_len = -(-len(header) // 8) * 8
    header = header.ljust(padded_len, b" ")
    contents = len(header).to_bytes(8, "little") + header + payload

    with tempfile.NamedTemporaryFile(suffix=".safetensors") as f:
        f.write(contents)
        # Rewind (which also flushes the buffered write) before loading.
        f.seek(0)
        out = mx.load(f)["tensor"]

    # allclose broadcasts, so check the shape separately, then compare the
    # whole tensor rather than only its first element.
    self.assertEqual(out.shape, expected.shape)
    self.assertTrue(mx.allclose(out, expected, equal_nan=True))
def test_save_and_load_gguf_metadata_basic(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)