diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index 10efb88f9..5eaf7c90e 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -81,16 +81,6 @@ constexpr Dtype::Category type_to_category[num_types] = { // clang-format on -inline bool is_big_endian() { - union ByteOrder { - int32_t i; - uint8_t c[4]; - }; - ByteOrder b = {0x01234567}; - - return b.c[0] == 0x01; -} - } // namespace Dtype promote_types(const Dtype& t1, const Dtype& t2) { diff --git a/mlx/fft.h b/mlx/fft.h index dbcc777fe..06298f821 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -6,12 +6,10 @@ #include "array.h" #include "device.h" -#include "stream.h" +#include "utils.h" namespace mlx::core::fft { -using StreamOrDevice = std::variant; - /** Compute the n-dimensional Fourier Transform. */ array fftn( const array& a, diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 3d27ab04e..c9c618fe9 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -26,16 +26,6 @@ constexpr uint8_t MAGIC[] = { 0x59, }; -inline bool is_big_endian_() { - union ByteOrder { - int32_t i; - uint8_t c[4]; - }; - ByteOrder b = {0x01234567}; - - return b.c[0] == 0x01; -} - } // namespace /** Save array to out stream in .npy format */ @@ -94,7 +84,7 @@ void save(std::shared_ptr out_stream, array a) { uint16_t v1_header_len = header.tellp(); const char* len_bytes = reinterpret_cast(&v1_header_len); - if (!is_big_endian_()) { + if (!is_big_endian()) { magic_ver_len.write(len_bytes, 2); } else { magic_ver_len.write(len_bytes + 1, 1); @@ -106,7 +96,7 @@ void save(std::shared_ptr out_stream, array a) { uint32_t v2_header_len = header.tellp(); const char* len_bytes = reinterpret_cast(&v2_header_len); - if (!is_big_endian_()) { + if (!is_big_endian()) { magic_ver_len.write(len_bytes, 4); } else { magic_ver_len.write(len_bytes + 3, 1); @@ -219,7 +209,7 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) { // Build primitive size_t offset = 8 + header_len_size + header.length(); - bool swap_endianness = read_is_big_endian != is_big_endian_(); + bool swap_endianness = read_is_big_endian != is_big_endian(); if (col_contiguous) { std::reverse(shape.begin(), shape.end()); diff --git a/mlx/utils.h b/mlx/utils.h index 1fe7edc05..a86db4009 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -84,6 +84,16 @@ int check_shape_dim(const T dim) { return static_cast(dim); } +inline bool is_big_endian() { + union ByteOrder { + int32_t i; + uint8_t c[4]; + }; + ByteOrder b = {0x01234567}; + + return b.c[0] == 0x01; +} + /** * Returns the axis normalized to be in the range [0, ndim). * Based on numpy's normalize_axis_index. See