load eval gpu for cuda

This commit is contained in:
Awni Hannun
2025-11-01 06:08:58 -07:00
parent d378567cc6
commit c27a0647a3
11 changed files with 119 additions and 38 deletions

View File

@@ -13,6 +13,7 @@
#include <windows.h>
#endif // _WIN32
+#include "mlx/backend/cuda/cuda.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
@@ -226,10 +227,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
throw std::runtime_error("[load] Failed to open " + in_stream->label());
}
-auto stream = to_stream(s, Device::cpu);
-if (stream.device != Device::cpu) {
-throw std::runtime_error("[load] Must run on a CPU stream.");
-}
+auto stream = to_stream(s, cu::is_available() ? Device::gpu : Device::cpu);
////////////////////////////////////////////////////////
// Read header and prepare array details