mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
switch statement
This commit is contained in:
parent
313f6bd9b1
commit
680f18cca5
@ -5,34 +5,33 @@
|
||||
namespace mlx::core {
|
||||
|
||||
std::string dtype_to_safetensor_str(Dtype t) {
|
||||
if (t == float32) {
|
||||
switch (t) {
|
||||
case float32:
|
||||
return ST_F32;
|
||||
} else if (t == bfloat16) {
|
||||
case bfloat16:
|
||||
return ST_BF16;
|
||||
} else if (t == float16) {
|
||||
case float16:
|
||||
return ST_F16;
|
||||
} else if (t == int64) {
|
||||
case int64:
|
||||
return ST_I64;
|
||||
} else if (t == int32) {
|
||||
case int32:
|
||||
return ST_I32;
|
||||
} else if (t == int16) {
|
||||
case int16:
|
||||
return ST_I16;
|
||||
} else if (t == int8) {
|
||||
case int8:
|
||||
return ST_I8;
|
||||
} else if (t == uint64) {
|
||||
case uint64:
|
||||
return ST_U64;
|
||||
} else if (t == uint32) {
|
||||
case uint32:
|
||||
return ST_U32;
|
||||
} else if (t == uint16) {
|
||||
case uint16:
|
||||
return ST_U16;
|
||||
} else if (t == uint8) {
|
||||
case uint8:
|
||||
return ST_U8;
|
||||
} else if (t == bool_) {
|
||||
case bool_:
|
||||
return ST_BOOL;
|
||||
} else if (t == complex64) {
|
||||
case complex64:
|
||||
return ST_C64;
|
||||
} else {
|
||||
throw std::runtime_error("[safetensor] unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user