switch statement

This commit is contained in:
Awni Hannun 2023-12-22 21:15:11 -08:00
parent 313f6bd9b1
commit 680f18cca5

View File

@ -5,34 +5,33 @@
namespace mlx::core { namespace mlx::core {
std::string dtype_to_safetensor_str(Dtype t) { std::string dtype_to_safetensor_str(Dtype t) {
if (t == float32) { switch (t) {
return ST_F32; case float32:
} else if (t == bfloat16) { return ST_F32;
return ST_BF16; case bfloat16:
} else if (t == float16) { return ST_BF16;
return ST_F16; case float16:
} else if (t == int64) { return ST_F16;
return ST_I64; case int64:
} else if (t == int32) { return ST_I64;
return ST_I32; case int32:
} else if (t == int16) { return ST_I32;
return ST_I16; case int16:
} else if (t == int8) { return ST_I16;
return ST_I8; case int8:
} else if (t == uint64) { return ST_I8;
return ST_U64; case uint64:
} else if (t == uint32) { return ST_U64;
return ST_U32; case uint32:
} else if (t == uint16) { return ST_U32;
return ST_U16; case uint16:
} else if (t == uint8) { return ST_U16;
return ST_U8; case uint8:
} else if (t == bool_) { return ST_U8;
return ST_BOOL; case bool_:
} else if (t == complex64) { return ST_BOOL;
return ST_C64; case complex64:
} else { return ST_C64;
throw std::runtime_error("[safetensor] unsupported dtype");
} }
} }