switch statement

This commit is contained in:
Awni Hannun 2023-12-22 21:15:11 -08:00
parent 313f6bd9b1
commit 680f18cca5

View File

@ -5,34 +5,33 @@
namespace mlx::core { namespace mlx::core {
std::string dtype_to_safetensor_str(Dtype t) { std::string dtype_to_safetensor_str(Dtype t) {
if (t == float32) { switch (t) {
case float32:
return ST_F32; return ST_F32;
} else if (t == bfloat16) { case bfloat16:
return ST_BF16; return ST_BF16;
} else if (t == float16) { case float16:
return ST_F16; return ST_F16;
} else if (t == int64) { case int64:
return ST_I64; return ST_I64;
} else if (t == int32) { case int32:
return ST_I32; return ST_I32;
} else if (t == int16) { case int16:
return ST_I16; return ST_I16;
} else if (t == int8) { case int8:
return ST_I8; return ST_I8;
} else if (t == uint64) { case uint64:
return ST_U64; return ST_U64;
} else if (t == uint32) { case uint32:
return ST_U32; return ST_U32;
} else if (t == uint16) { case uint16:
return ST_U16; return ST_U16;
} else if (t == uint8) { case uint8:
return ST_U8; return ST_U8;
} else if (t == bool_) { case bool_:
return ST_BOOL; return ST_BOOL;
} else if (t == complex64) { case complex64:
return ST_C64; return ST_C64;
} else {
throw std::runtime_error("[safetensor] unsupported dtype");
} }
} }