From 7f2d1024f3077871918037204c07273b8783e6a2 Mon Sep 17 00:00:00 2001
From: Alex Barron <abarron22@apple.com>
Date: Thu, 13 Feb 2025 17:10:03 -0800
Subject: [PATCH] add f8_e4m3 loading (#1859)

---
 mlx/io/safetensors.cpp    | 94 +++++++++++++++++++++++++++++++++++++++
 python/tests/test_load.py | 24 ++++++++++
 2 files changed, 118 insertions(+)

diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index 825e033043..f7c9103729 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -4,6 +4,7 @@
 #include <json.hpp>
 #include <stack>
 
+#include "mlx/fast.h"
 #include "mlx/io.h"
 #include "mlx/io/load.h"
 #include "mlx/ops.h"
@@ -25,6 +26,7 @@ using json = nlohmann::json;
 #define ST_U16 "U16"
 #define ST_U32 "U32"
 #define ST_U64 "U64"
+#define ST_F8_E4M3 "F8_E4M3"
 
 // Note: Complex numbers aren't in the spec yet so this could change -
 // https://github.com/huggingface/safetensors/issues/389
@@ -92,12 +94,101 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
     return bool_;
   } else if (str == ST_C64) {
     return complex64;
+  } else if (str == ST_F8_E4M3) {
+    // We convert this manually later
+    return uint8;
   } else {
     throw std::runtime_error(
         "[safetensor] unsupported dtype " + std::string(str));
   }
 }
 
+array f8_e4m3_to_float(array x, Dtype dtype, StreamOrDevice s) {
+  if (to_stream(s).device == Device::gpu) {
+    // From PyTorch:
+    // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
+    std::string source = R"(
+      uint elem = thread_position_in_grid.x;
+      uint8_t val = x[elem];
+
+      const uint32_t w = (uint32_t)val << 24;
+      const uint32_t sign = w & 0x80000000;
+      const uint32_t nonsign = w & 0x7FFFFFFF;
+
+      uint32_t renorm_shift = metal::clz(nonsign);
+      renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
+
+      const int32_t inf_nan_mask =
+          ((int32_t)(nonsign + 0x01000000) >> 8) & 0x7F800000;
+      const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
+      uint32_t result = sign |
+          ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
+            inf_nan_mask) &
+           ~zero_mask);
+
+      float out = *(reinterpret_cast<thread float*>(&result));
+      y[elem] = static_cast<T>(out);
+    )";
+    auto kernel = fast::metal_kernel("f8_e4m3", {"x"}, {"y"}, source);
+    auto outputs = kernel(
+        {x},
+        {x.shape()},
+        {dtype},
+        {x.size(), 1, 1},
+        {256, 1, 1},
+        {{"T", dtype}},
+        std::nullopt,
+        false,
+        s);
+    return outputs[0];
+  } else {
+    auto w = left_shift(astype(x, uint32, s), array({24}, uint32), s);
+    auto sign = bitwise_and(w, array({0x80000000}, uint32), s);
+    auto nonsign = bitwise_and(w, array({0x7FFFFFFF}, uint32), s);
+
+    // Emulate a clz op with a lookup table
+    auto clz_table =
+        array({28, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, uint32);
+    auto renorm_shift = take(clz_table, bitwise_and(x, array({0xf}), s), s);
+    renorm_shift = where(
+        greater(
+            bitwise_and(x, array({0x70}, uint32), s), array({0}, uint32), s),
+        array({0}, uint32),
+        renorm_shift,
+        s);
+    auto inf_nan_mask = bitwise_and(
+        right_shift(
+            astype(add(nonsign, array(0x01000000, int32), s), int32, s),
+            array({8}, int32),
+            s),
+        array({0x7F800000}, int32),
+        s);
+    auto zero_mask = right_shift(
+        astype(subtract(nonsign, array({1}, uint32), s), int32, s),
+        array({31}, int32),
+        s);
+    zero_mask = astype(zero_mask, uint32, s);
+    inf_nan_mask = astype(inf_nan_mask, uint32, s);
+    auto result =
+        add(right_shift(
+                left_shift(nonsign, renorm_shift, s), array({4}, uint32), s),
+            left_shift(
+                subtract(array({0x78}, uint32), renorm_shift, s),
+                array({23}, uint32),
+                s),
+            s);
+    result = bitwise_or(
+        sign,
+        bitwise_and(
+            bitwise_or(result, inf_nan_mask, s),
+            bitwise_invert(zero_mask, s),
+            s),
+        s);
+    result = astype(view(result, float32, s), dtype, s);
+    return result;
+  }
+}
+
 /** Load array from reader in safetensor format */
 SafetensorsLoad load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
@@ -147,6 +238,9 @@ SafetensorsLoad load_safetensors(
         std::make_shared<Load>(
             to_stream(s), in_stream, offset + data_offsets.at(0), false),
         std::vector<array>{});
+    if (dtype == ST_F8_E4M3) {
+      loaded_array = f8_e4m3_to_float(loaded_array, bfloat16, s);
+    }
     res.insert({item.key(), loaded_array});
   }
   return {res, metadata_map};
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index 0b42e4a6b2..335d6ea944 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -128,6 +128,30 @@ class TestLoad(mlx_tests.MLXTestCase):
             mx.array_equal(load_dict["test"], save_dict["test"])
         )
 
+    def test_load_f8_e4m3(self):
+        if not os.path.isdir(self.test_dir):
+            os.mkdir(self.test_dir)
+
+        expected = [
+            0,
+            mx.nan,
+            mx.nan,
+            -0.875,
+            0.4375,
+            -0.005859,
+            -1.25,
+            -1.25,
+            -1.5,
+            -0.0039,
+        ]
+        expected = mx.array(expected, dtype=mx.bfloat16)
+        contents = b'H\x00\x00\x00\x00\x00\x00\x00{"tensor":{"dtype":"F8_E4M3","shape":[10],"data_offsets":[0,10]}} \x00\x7f\xff\xb6.\x83\xba\xba\xbc\x82'
+        with tempfile.NamedTemporaryFile(suffix=".safetensors") as f:
+            f.write(contents)
+            f.seek(0)
+            out = mx.load(f)["tensor"]
+            self.assertTrue(mx.allclose(out[0], expected[0], equal_nan=True))
+
     def test_save_and_load_gguf_metadata_basic(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)