mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 19:51:13 +08:00
Remove duplicate defines of StreamOrDevice and is_big_endian (#892)
This commit is contained in:
parent
240d10699c
commit
a789685c63
@ -81,16 +81,6 @@ constexpr Dtype::Category type_to_category[num_types] = {
|
|||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
inline bool is_big_endian() {
|
|
||||||
union ByteOrder {
|
|
||||||
int32_t i;
|
|
||||||
uint8_t c[4];
|
|
||||||
};
|
|
||||||
ByteOrder b = {0x01234567};
|
|
||||||
|
|
||||||
return b.c[0] == 0x01;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
Dtype promote_types(const Dtype& t1, const Dtype& t2) {
|
||||||
|
@ -6,12 +6,10 @@
|
|||||||
|
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
#include "stream.h"
|
#include "utils.h"
|
||||||
|
|
||||||
namespace mlx::core::fft {
|
namespace mlx::core::fft {
|
||||||
|
|
||||||
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
|
||||||
|
|
||||||
/** Compute the n-dimensional Fourier Transform. */
|
/** Compute the n-dimensional Fourier Transform. */
|
||||||
array fftn(
|
array fftn(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -26,16 +26,6 @@ constexpr uint8_t MAGIC[] = {
|
|||||||
0x59,
|
0x59,
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool is_big_endian_() {
|
|
||||||
union ByteOrder {
|
|
||||||
int32_t i;
|
|
||||||
uint8_t c[4];
|
|
||||||
};
|
|
||||||
ByteOrder b = {0x01234567};
|
|
||||||
|
|
||||||
return b.c[0] == 0x01;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
@ -94,7 +84,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
|||||||
uint16_t v1_header_len = header.tellp();
|
uint16_t v1_header_len = header.tellp();
|
||||||
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
|
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
|
||||||
|
|
||||||
if (!is_big_endian_()) {
|
if (!is_big_endian()) {
|
||||||
magic_ver_len.write(len_bytes, 2);
|
magic_ver_len.write(len_bytes, 2);
|
||||||
} else {
|
} else {
|
||||||
magic_ver_len.write(len_bytes + 1, 1);
|
magic_ver_len.write(len_bytes + 1, 1);
|
||||||
@ -106,7 +96,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
|||||||
uint32_t v2_header_len = header.tellp();
|
uint32_t v2_header_len = header.tellp();
|
||||||
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
|
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
|
||||||
|
|
||||||
if (!is_big_endian_()) {
|
if (!is_big_endian()) {
|
||||||
magic_ver_len.write(len_bytes, 4);
|
magic_ver_len.write(len_bytes, 4);
|
||||||
} else {
|
} else {
|
||||||
magic_ver_len.write(len_bytes + 3, 1);
|
magic_ver_len.write(len_bytes + 3, 1);
|
||||||
@ -219,7 +209,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
|||||||
// Build primitive
|
// Build primitive
|
||||||
|
|
||||||
size_t offset = 8 + header_len_size + header.length();
|
size_t offset = 8 + header_len_size + header.length();
|
||||||
bool swap_endianness = read_is_big_endian != is_big_endian_();
|
bool swap_endianness = read_is_big_endian != is_big_endian();
|
||||||
|
|
||||||
if (col_contiguous) {
|
if (col_contiguous) {
|
||||||
std::reverse(shape.begin(), shape.end());
|
std::reverse(shape.begin(), shape.end());
|
||||||
|
10
mlx/utils.h
10
mlx/utils.h
@ -84,6 +84,16 @@ int check_shape_dim(const T dim) {
|
|||||||
return static_cast<int>(dim);
|
return static_cast<int>(dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_big_endian() {
|
||||||
|
union ByteOrder {
|
||||||
|
int32_t i;
|
||||||
|
uint8_t c[4];
|
||||||
|
};
|
||||||
|
ByteOrder b = {0x01234567};
|
||||||
|
|
||||||
|
return b.c[0] == 0x01;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the axis normalized to be in the range [0, ndim).
|
* Returns the axis normalized to be in the range [0, ndim).
|
||||||
* Based on numpy's normalize_axis_index. See
|
* Based on numpy's normalize_axis_index. See
|
||||||
|
Loading…
Reference in New Issue
Block a user