From 680f18cca5f31c46148e17b0862a069caeebecba Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 22 Dec 2023 21:15:11 -0800 Subject: [PATCH] switch statement --- mlx/io/safetensor.cpp | 55 +++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/mlx/io/safetensor.cpp b/mlx/io/safetensor.cpp index c17f713e8..a690e6420 100644 --- a/mlx/io/safetensor.cpp +++ b/mlx/io/safetensor.cpp @@ -5,34 +5,33 @@ namespace mlx::core { std::string dtype_to_safetensor_str(Dtype t) { - if (t == float32) { - return ST_F32; - } else if (t == bfloat16) { - return ST_BF16; - } else if (t == float16) { - return ST_F16; - } else if (t == int64) { - return ST_I64; - } else if (t == int32) { - return ST_I32; - } else if (t == int16) { - return ST_I16; - } else if (t == int8) { - return ST_I8; - } else if (t == uint64) { - return ST_U64; - } else if (t == uint32) { - return ST_U32; - } else if (t == uint16) { - return ST_U16; - } else if (t == uint8) { - return ST_U8; - } else if (t == bool_) { - return ST_BOOL; - } else if (t == complex64) { - return ST_C64; - } else { - throw std::runtime_error("[safetensor] unsupported dtype"); + switch (t) { + case float32: + return ST_F32; + case bfloat16: + return ST_BF16; + case float16: + return ST_F16; + case int64: + return ST_I64; + case int32: + return ST_I32; + case int16: + return ST_I16; + case int8: + return ST_I8; + case uint64: + return ST_U64; + case uint32: + return ST_U32; + case uint16: + return ST_U16; + case uint8: + return ST_U8; + case bool_: + return ST_BOOL; + case complex64: + return ST_C64; } }