mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
switch statement
This commit is contained in:
parent
313f6bd9b1
commit
680f18cca5
@ -5,34 +5,33 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string dtype_to_safetensor_str(Dtype t) {
|
std::string dtype_to_safetensor_str(Dtype t) {
|
||||||
if (t == float32) {
|
switch (t) {
|
||||||
return ST_F32;
|
case float32:
|
||||||
} else if (t == bfloat16) {
|
return ST_F32;
|
||||||
return ST_BF16;
|
case bfloat16:
|
||||||
} else if (t == float16) {
|
return ST_BF16;
|
||||||
return ST_F16;
|
case float16:
|
||||||
} else if (t == int64) {
|
return ST_F16;
|
||||||
return ST_I64;
|
case int64:
|
||||||
} else if (t == int32) {
|
return ST_I64;
|
||||||
return ST_I32;
|
case int32:
|
||||||
} else if (t == int16) {
|
return ST_I32;
|
||||||
return ST_I16;
|
case int16:
|
||||||
} else if (t == int8) {
|
return ST_I16;
|
||||||
return ST_I8;
|
case int8:
|
||||||
} else if (t == uint64) {
|
return ST_I8;
|
||||||
return ST_U64;
|
case uint64:
|
||||||
} else if (t == uint32) {
|
return ST_U64;
|
||||||
return ST_U32;
|
case uint32:
|
||||||
} else if (t == uint16) {
|
return ST_U32;
|
||||||
return ST_U16;
|
case uint16:
|
||||||
} else if (t == uint8) {
|
return ST_U16;
|
||||||
return ST_U8;
|
case uint8:
|
||||||
} else if (t == bool_) {
|
return ST_U8;
|
||||||
return ST_BOOL;
|
case bool_:
|
||||||
} else if (t == complex64) {
|
return ST_BOOL;
|
||||||
return ST_C64;
|
case complex64:
|
||||||
} else {
|
return ST_C64;
|
||||||
throw std::runtime_error("[safetensor] unsupported dtype");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user