mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes * some cleanup * add utils * fix * fix 2 * more cleaning * fix * remove unused mps matmul support * one more nit * revert
This commit is contained in:
@@ -26,6 +26,80 @@ constexpr uint8_t MAGIC[] = {
|
||||
0x59,
|
||||
};
|
||||
|
||||
inline bool is_big_endian() {
|
||||
union ByteOrder {
|
||||
int32_t i;
|
||||
uint8_t c[4];
|
||||
};
|
||||
ByteOrder b = {0x01234567};
|
||||
|
||||
return b.c[0] == 0x01;
|
||||
}
|
||||
|
||||
// Array protocol typestring for Dtype
|
||||
std::string dtype_to_array_protocol(const Dtype& t) {
|
||||
std::ostringstream r;
|
||||
if (size_of(t) > 1) {
|
||||
r << (is_big_endian() ? ">" : "<");
|
||||
} else {
|
||||
r << "|";
|
||||
}
|
||||
r << kindof(t) << (int)size_of(t);
|
||||
return r.str();
|
||||
}
|
||||
|
||||
// Dtype from array protocol type string
|
||||
Dtype dtype_from_array_protocol(std::string_view t) {
|
||||
if (t.length() == 2 || t.length() == 3) {
|
||||
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
|
||||
|
||||
if (r == "V2") {
|
||||
return bfloat16;
|
||||
}
|
||||
|
||||
uint8_t size = r[1] - '0';
|
||||
|
||||
switch (r[0]) {
|
||||
case 'b': {
|
||||
if (size == 1)
|
||||
return bool_;
|
||||
}
|
||||
case 'i': {
|
||||
if (size == 1)
|
||||
return int8;
|
||||
else if (size == 2)
|
||||
return int16;
|
||||
else if (size == 4)
|
||||
return int32;
|
||||
else if (size == 8)
|
||||
return int64;
|
||||
}
|
||||
case 'u': {
|
||||
if (size == 1)
|
||||
return uint8;
|
||||
else if (size == 2)
|
||||
return uint16;
|
||||
else if (size == 4)
|
||||
return uint32;
|
||||
else if (size == 8)
|
||||
return uint64;
|
||||
}
|
||||
case 'f': {
|
||||
if (size == 2)
|
||||
return float16;
|
||||
else if (size == 4)
|
||||
return float32;
|
||||
}
|
||||
case 'c': {
|
||||
return complex64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[from_str] Invalid array protocol type-string: " + std::string(t));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/** Save array to out stream in .npy format */
|
||||
|
||||
Reference in New Issue
Block a user