diff --git a/mlx/safetensor.cpp b/mlx/safetensor.cpp index 7aca0c691..30f5d2b5a 100644 --- a/mlx/safetensor.cpp +++ b/mlx/safetensor.cpp @@ -306,7 +306,7 @@ std::unordered_map load_safetensor( Dtype type = dtype_from_safetensor_str(dtype); auto loaded_array = array( shape_vec, - float32, + type, std::make_unique( to_stream(s), in_stream,