mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-02 16:56:46 +08:00
add f8_e4m3 loading (#1859)
This commit is contained in:
parent
428f589364
commit
7f2d1024f3
@ -4,6 +4,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <stack>
|
#include <stack>
|
||||||
|
|
||||||
|
#include "mlx/fast.h"
|
||||||
#include "mlx/io.h"
|
#include "mlx/io.h"
|
||||||
#include "mlx/io/load.h"
|
#include "mlx/io/load.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
@ -25,6 +26,7 @@ using json = nlohmann::json;
|
|||||||
#define ST_U16 "U16"
|
#define ST_U16 "U16"
|
||||||
#define ST_U32 "U32"
|
#define ST_U32 "U32"
|
||||||
#define ST_U64 "U64"
|
#define ST_U64 "U64"
|
||||||
|
#define ST_F8_E4M3 "F8_E4M3"
|
||||||
|
|
||||||
// Note: Complex numbers aren't in the spec yet so this could change -
|
// Note: Complex numbers aren't in the spec yet so this could change -
|
||||||
// https://github.com/huggingface/safetensors/issues/389
|
// https://github.com/huggingface/safetensors/issues/389
|
||||||
@ -92,12 +94,101 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
|
|||||||
return bool_;
|
return bool_;
|
||||||
} else if (str == ST_C64) {
|
} else if (str == ST_C64) {
|
||||||
return complex64;
|
return complex64;
|
||||||
|
} else if (str == ST_F8_E4M3) {
|
||||||
|
// We convert this manually later
|
||||||
|
return uint8;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[safetensor] unsupported dtype " + std::string(str));
|
"[safetensor] unsupported dtype " + std::string(str));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Decode an array of raw F8_E4M3 encodings into `dtype`.
 *
 * The safetensors loader maps F8_E4M3 tensors to uint8 storage, so `x` is
 * expected to hold one 8-bit encoding per element. Each encoding is expanded
 * to the bit pattern of a float32 and the result is then cast to `dtype`.
 *
 * The bit manipulation follows PyTorch's Float8_e4m3fn decoder:
 * https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46
 *
 * @param x      Array of raw 8-bit encodings.
 * @param dtype  Element type of the returned array.
 * @param s      Stream or device the conversion is scheduled on. A GPU stream
 *               uses a custom Metal kernel; anything else uses regular ops.
 * @return Array with the shape of `x` (after broadcasting against the
 *         one-element constants on the non-GPU path) and type `dtype`.
 */
array f8_e4m3_to_float(array x, Dtype dtype, StreamOrDevice s) {
  if (to_stream(s).device == Device::gpu) {
    // One thread per element; `T` is bound to `dtype` through the template
    // arguments passed below.
    const std::string kernel_body = R"(
        uint elem = thread_position_in_grid.x;
        uint8_t val = x[elem];

        const uint32_t w = (uint32_t)val << 24;
        const uint32_t sign = w & 0x80000000;
        const uint32_t nonsign = w & 0x7FFFFFFF;

        uint32_t renorm_shift = metal::clz(nonsign);
        renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

        const int32_t inf_nan_mask =
            ((int32_t)(nonsign + 0x01000000) >> 8) & 0x7F800000;
        const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
        uint32_t result = sign |
            ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
              inf_nan_mask) &
             ~zero_mask);

        float out = *(reinterpret_cast<thread float*>(&result));
        y[elem] = static_cast<T>(out);
    )";
    auto decode_kernel =
        fast::metal_kernel("f8_e4m3", {"x"}, {"y"}, kernel_body);
    auto decoded = decode_kernel(
        {x},
        {x.shape()},
        {dtype},
        {x.size(), 1, 1},
        {256, 1, 1},
        {{"T", dtype}},
        std::nullopt,
        false,
        s);
    return decoded[0];
  }

  // Non-GPU path: the same arithmetic as the kernel above, expressed with
  // array ops.

  // Move the 8-bit encoding into the top byte of a 32-bit word and split off
  // the sign bit.
  const auto word = left_shift(astype(x, uint32, s), array({24}, uint32), s);
  const auto sign_bit = bitwise_and(word, array({0x80000000}, uint32), s);
  const auto magnitude = bitwise_and(word, array({0x7FFFFFFF}, uint32), s);

  // There is no clz op, so the renormalization shift is looked up from the
  // low nibble of the encoding ...
  const auto shift_table =
      array({28, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, uint32);
  const auto low_nibble = bitwise_and(x, array({0xf}), s);
  // ... and forced to zero whenever any of bits 4-6 is set.
  const auto upper_bits_set = greater(
      bitwise_and(x, array({0x70}, uint32), s), array({0}, uint32), s);
  const auto shift = where(
      upper_bits_set, array({0}, uint32), take(shift_table, low_nibble, s), s);

  // Mask that turns the all-ones magnitude into a float32 exponent of all
  // ones. The shift is performed on int32 values.
  const auto bumped =
      astype(add(magnitude, array(0x01000000, int32), s), int32, s);
  const auto inf_nan_bits = astype(
      bitwise_and(
          right_shift(bumped, array({8}, int32), s),
          array({0x7F800000}, int32),
          s),
      uint32,
      s);

  // All ones when the magnitude is zero, all zeros otherwise.
  const auto magnitude_minus_one =
      astype(subtract(magnitude, array({1}, uint32), s), int32, s);
  const auto zero_bits = astype(
      right_shift(magnitude_minus_one, array({31}, int32), s), uint32, s);

  // Renormalized exponent/mantissa fields plus the exponent adjustment.
  const auto shifted_fields = right_shift(
      left_shift(magnitude, shift, s), array({4}, uint32), s);
  const auto exponent_adjust = left_shift(
      subtract(array({0x78}, uint32), shift, s), array({23}, uint32), s);
  const auto finite_bits = add(shifted_fields, exponent_adjust, s);

  // Apply the inf/nan mask, clear everything for zero inputs, restore the
  // sign.
  const auto unsigned_bits = bitwise_and(
      bitwise_or(finite_bits, inf_nan_bits, s),
      bitwise_invert(zero_bits, s),
      s);
  const auto float_bits = bitwise_or(sign_bit, unsigned_bits, s);

  // Reinterpret the 32-bit pattern as float32, then cast to the requested
  // type.
  return astype(view(float_bits, float32, s), dtype, s);
}
|
||||||
|
|
||||||
/** Load array from reader in safetensor format */
|
/** Load array from reader in safetensor format */
|
||||||
SafetensorsLoad load_safetensors(
|
SafetensorsLoad load_safetensors(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
@ -147,6 +238,9 @@ SafetensorsLoad load_safetensors(
|
|||||||
std::make_shared<Load>(
|
std::make_shared<Load>(
|
||||||
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
to_stream(s), in_stream, offset + data_offsets.at(0), false),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
|
if (dtype == ST_F8_E4M3) {
|
||||||
|
loaded_array = f8_e4m3_to_float(loaded_array, bfloat16, s);
|
||||||
|
}
|
||||||
res.insert({item.key(), loaded_array});
|
res.insert({item.key(), loaded_array});
|
||||||
}
|
}
|
||||||
return {res, metadata_map};
|
return {res, metadata_map};
|
||||||
|
@ -128,6 +128,30 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_load_f8_e4m3(self):
    """Load a hand-built safetensors file holding ten F8_E4M3 values.

    The loader is expected to return the tensor as bfloat16; every decoded
    element (zero, both NaN encodings, normal and subnormal values) is
    compared against its reference value.
    """
    if not os.path.isdir(self.test_dir):
        os.mkdir(self.test_dir)

    expected = [
        0,
        mx.nan,
        mx.nan,
        -0.875,
        0.4375,
        -0.005859,
        -1.25,
        -1.25,
        -1.5,
        -0.0039,
    ]
    expected = mx.array(expected, dtype=mx.bfloat16)

    # safetensors layout: 8-byte little-endian header length, JSON header,
    # then the raw tensor bytes. The header is space-padded to a multiple of
    # 8 bytes; computing the padding and the length prefix here keeps the
    # fixture from depending on an exact run of spaces inside a literal.
    header = b'{"tensor":{"dtype":"F8_E4M3","shape":[10],"data_offsets":[0,10]}}'
    header += b" " * (-len(header) % 8)
    payload = b"\x00\x7f\xff\xb6.\x83\xba\xba\xbc\x82"
    contents = len(header).to_bytes(8, "little") + header + payload

    with tempfile.NamedTemporaryFile(suffix=".safetensors") as f:
        f.write(contents)
        f.seek(0)
        out = mx.load(f)["tensor"]

        # Checked while the file is still open, in case the load is lazy.
        self.assertEqual(out.dtype, mx.bfloat16)
        self.assertEqual(out.shape, expected.shape)
        # Compare the whole tensor, not just its first element.
        self.assertTrue(mx.allclose(out, expected, equal_nan=True))
|
||||||
|
|
||||||
def test_save_and_load_gguf_metadata_basic(self):
|
def test_save_and_load_gguf_metadata_basic(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
Loading…
Reference in New Issue
Block a user