Remove duplicate defines of StreamOrDevice and is_big_endian (#892)

This commit is contained in:
Cheng 2024-03-27 07:15:11 +09:00 committed by GitHub
parent 240d10699c
commit a789685c63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 26 deletions

View File

@ -81,16 +81,6 @@ constexpr Dtype::Category type_to_category[num_types] = {
// clang-format on
inline bool is_big_endian() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
} // namespace
Dtype promote_types(const Dtype& t1, const Dtype& t2) {

View File

@ -6,12 +6,10 @@
#include "array.h"
#include "device.h"
#include "stream.h"
#include "utils.h"
namespace mlx::core::fft {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
/** Compute the n-dimensional Fourier Transform. */
array fftn(
const array& a,

View File

@ -26,16 +26,6 @@ constexpr uint8_t MAGIC[] = {
0x59,
};
inline bool is_big_endian_() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
} // namespace
/** Save array to out stream in .npy format */
@ -94,7 +84,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
uint16_t v1_header_len = header.tellp();
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
if (!is_big_endian_()) {
if (!is_big_endian()) {
magic_ver_len.write(len_bytes, 2);
} else {
magic_ver_len.write(len_bytes + 1, 1);
@ -106,7 +96,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
uint32_t v2_header_len = header.tellp();
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
if (!is_big_endian_()) {
if (!is_big_endian()) {
magic_ver_len.write(len_bytes, 4);
} else {
magic_ver_len.write(len_bytes + 3, 1);
@ -219,7 +209,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
// Build primitive
size_t offset = 8 + header_len_size + header.length();
bool swap_endianness = read_is_big_endian != is_big_endian_();
bool swap_endianness = read_is_big_endian != is_big_endian();
if (col_contiguous) {
std::reverse(shape.begin(), shape.end());

View File

@ -84,6 +84,16 @@ int check_shape_dim(const T dim) {
return static_cast<int>(dim);
}
inline bool is_big_endian() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
/**
* Returns the axis normalized to be in the range [0, ndim).
* Based on numpy's normalize_axis_index. See