diff --git a/llama/convert.py b/llama/convert.py index 2ce247a3..69168493 100644 --- a/llama/convert.py +++ b/llama/convert.py @@ -46,7 +46,7 @@ if __name__ == "__main__": parser.add_argument("output_file") args = parser.parse_args() - state = torch.load(args.torch_weights) + state = torch.load(args.torch_weights, map_location=torch.device('cpu')) np.savez( args.output_file, **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}