mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 03:31:13 +08:00
Remove duplicate defines of StreamOrDevice and is_big_endian (#892)
This commit is contained in:
parent
240d10699c
commit
a789685c63
@ -81,16 +81,6 @@ constexpr Dtype::Category type_to_category[num_types] = {
|
||||
|
||||
// clang-format on
|
||||
|
||||
inline bool is_big_endian() {
|
||||
union ByteOrder {
|
||||
int32_t i;
|
||||
uint8_t c[4];
|
||||
};
|
||||
ByteOrder b = {0x01234567};
|
||||
|
||||
return b.c[0] == 0x01;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||
|
@ -6,12 +6,10 @@
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "stream.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mlx::core::fft {
|
||||
|
||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||
|
||||
/** Compute the n-dimensional Fourier Transform. */
|
||||
array fftn(
|
||||
const array& a,
|
||||
|
@ -26,16 +26,6 @@ constexpr uint8_t MAGIC[] = {
|
||||
0x59,
|
||||
};
|
||||
|
||||
inline bool is_big_endian_() {
|
||||
union ByteOrder {
|
||||
int32_t i;
|
||||
uint8_t c[4];
|
||||
};
|
||||
ByteOrder b = {0x01234567};
|
||||
|
||||
return b.c[0] == 0x01;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
@ -94,7 +84,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
uint16_t v1_header_len = header.tellp();
|
||||
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
|
||||
|
||||
if (!is_big_endian_()) {
|
||||
if (!is_big_endian()) {
|
||||
magic_ver_len.write(len_bytes, 2);
|
||||
} else {
|
||||
magic_ver_len.write(len_bytes + 1, 1);
|
||||
@ -106,7 +96,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
uint32_t v2_header_len = header.tellp();
|
||||
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
|
||||
|
||||
if (!is_big_endian_()) {
|
||||
if (!is_big_endian()) {
|
||||
magic_ver_len.write(len_bytes, 4);
|
||||
} else {
|
||||
magic_ver_len.write(len_bytes + 3, 1);
|
||||
@ -219,7 +209,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
// Build primitive
|
||||
|
||||
size_t offset = 8 + header_len_size + header.length();
|
||||
bool swap_endianness = read_is_big_endian != is_big_endian_();
|
||||
bool swap_endianness = read_is_big_endian != is_big_endian();
|
||||
|
||||
if (col_contiguous) {
|
||||
std::reverse(shape.begin(), shape.end());
|
||||
|
10
mlx/utils.h
10
mlx/utils.h
@ -84,6 +84,16 @@ int check_shape_dim(const T dim) {
|
||||
return static_cast<int>(dim);
|
||||
}
|
||||
|
||||
inline bool is_big_endian() {
|
||||
union ByteOrder {
|
||||
int32_t i;
|
||||
uint8_t c[4];
|
||||
};
|
||||
ByteOrder b = {0x01234567};
|
||||
|
||||
return b.c[0] == 0x01;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the axis normalized to be in the range [0, ndim).
|
||||
* Based on numpy's normalize_axis_index. See
|
||||
|
Loading…
Reference in New Issue
Block a user