mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Merge 8bb8b76ae4
into cad5c0241c
This commit is contained in:
commit
5b8da0c92d
@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_BUILD_ROCM "Build ROCm backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_ROCM)
|
||||
enable_language(HIP)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
|
@ -60,7 +60,16 @@ else()
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
if(MLX_BUILD_ROCM)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL
|
||||
OR MLX_BUILD_CUDA
|
||||
OR MLX_BUILD_ROCM)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
|
85
mlx/backend/rocm/CMakeLists.txt
Normal file
85
mlx/backend/rocm/CMakeLists.txt
Normal file
@ -0,0 +1,85 @@
|
||||
# Filename rules in ROCm backend:
#
# * Use .hip/.hpp if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/random.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/unary.hip
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)

# Embed kernel sources in binary for JIT compilation.
#
# CONFIGURE_DEPENDS re-runs the glob at build time so headers added to device/
# are picked up without a manual re-configure.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp")

# bin2h.cmake receives the relative paths as one ':'-separated argument.
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})

# Absolute paths for DEPENDS so the dependency edges do not rely on how
# relative names are resolved.
list(TRANSFORM MLX_JIT_SOURCES PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/"
                                       OUTPUT_VARIABLE MLX_JIT_SOURCE_PATHS)

# bin2h.cmake writes to gen/rocm_jit_sources.h below its CMAKE_CURRENT_BINARY_DIR
# variable; WORKING_DIRECTORY is set explicitly so the script is invoked from
# the directory that matches OUTPUT.
set(MLX_ROCM_JIT_HEADER "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
add_custom_command(
  OUTPUT "${MLX_ROCM_JIT_HEADER}"
  COMMAND
    ${CMAKE_COMMAND} "-DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}"
    "-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG}" -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" ${MLX_JIT_SOURCE_PATHS}
  WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
  COMMENT "Embedding ROCm JIT kernel sources"
  VERBATIM)
add_custom_target(rocm_jit_sources DEPENDS "${MLX_ROCM_JIT_HEADER}")
add_dependencies(mlx rocm_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")

# Find ROCm installation
find_package(hip REQUIRED)
find_package(rocblas REQUIRED)

# Link with ROCm libraries
target_link_libraries(mlx PRIVATE hip::device roc::rocblas)

# Set GPU architectures for ROCm. Common ROCm architectures: gfx900, gfx906,
# gfx908, gfx90a, gfx1030, gfx1100
set(MLX_ROCM_ARCHITECTURES
    "gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100"
    CACHE STRING "ROCm GPU architectures")
message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}")

# Set GPU targets for HIP compilation
set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}")

# The HIP language is enabled by the top-level CMakeLists.txt when
# MLX_BUILD_ROCM is ON. enable_language() has to be called in the highest
# directory common to the targets using the language, so it is not repeated
# in this subdirectory.

# Set HIP compiler flags. Warning flags are passed directly: the HIP compiler
# is clang-based, and "-Xcompiler=" is an nvcc-style wrapper option.
target_compile_options(
  mlx
  PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>"
          "$<$<COMPILE_LANGUAGE:HIP>:-Wall>"
          "$<$<COMPILE_LANGUAGE:HIP>:-Wextra>")

# Add ROCm include directories
target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS})
|
20
mlx/backend/rocm/allocator.cpp
Normal file
20
mlx/backend/rocm/allocator.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/allocator.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void* allocate(size_t size) {
|
||||
void* ptr;
|
||||
check_hip_error("hipMalloc", hipMalloc(&ptr, size));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void deallocate(void* ptr) {
|
||||
if (ptr) {
|
||||
check_hip_error("hipFree", hipFree(ptr));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
12
mlx/backend/rocm/allocator.h
Normal file
12
mlx/backend/rocm/allocator.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlx::core::rocm {

// Allocate |size| bytes and return a pointer to the allocation.
void* allocate(size_t size);

// Free memory previously returned by allocate().
void deallocate(void* ptr);

} // namespace mlx::core::rocm
|
28
mlx/backend/rocm/arg_reduce.hip
Normal file
28
mlx/backend/rocm/arg_reduce.hip
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void argmax_kernel(float* input, int* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Simple argmax placeholder
|
||||
if (idx == 0) {
|
||||
int max_idx = 0;
|
||||
float max_val = input[0];
|
||||
for (int i = 1; i < n; i++) {
|
||||
if (input[i] > max_val) {
|
||||
max_val = input[i];
|
||||
max_idx = i;
|
||||
}
|
||||
}
|
||||
output[0] = max_idx;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_argmax(float* input, int* output, int n, hipStream_t stream) {
|
||||
hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
47
mlx/backend/rocm/bin2h.cmake
Normal file
47
mlx/backend/rocm/bin2h.cmake
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
# Script to embed kernel source files as header for JIT compilation
|
||||
|
||||
# Inputs (passed with -D before -P):
#
# * MLX_SOURCE_ROOT: directory the source paths are relative to.
# * MLX_JIT_SOURCES: ':'-separated list of source paths relative to that root.
#
# Output: gen/rocm_jit_sources.h below CMAKE_CURRENT_BINARY_DIR, defining a
# map from relative source path to file content.
set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
set(MLX_KERNEL_HEADER
    "#pragma once\n\n#include <unordered_map>\n#include <string>\n\nnamespace mlx::core::rocm {\n\n"
)
set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n")

# Create output directory
get_filename_component(MLX_OUTPUT_DIR "${MLX_OUTPUT_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${MLX_OUTPUT_DIR}")

# Process JIT sources: turn the ':'-separated argument back into a CMake list.
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST "${MLX_JIT_SOURCES}")

set(MLX_SOURCE_MAP
    "const std::unordered_map<std::string, std::string> kernel_sources = {\n")

foreach(source IN LISTS MLX_JIT_SOURCES_LIST)
  set(source_file "${MLX_SOURCE_ROOT}/${source}")
  if(NOT EXISTS "${source_file}")
    # Missing files are still skipped, but no longer silently.
    message(WARNING "bin2h: JIT source not found, skipping: ${source_file}")
    continue()
  endif()

  # Read source file
  file(READ "${source_file}" source_content)

  # Escape content for C++ string literal. Backslashes must be escaped first
  # so the escapes added for quotes and newlines are not escaped again.
  string(REPLACE "\\" "\\\\" source_content "${source_content}")
  string(REPLACE "\"" "\\\"" source_content "${source_content}")
  string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}")

  # Add to map
  string(APPEND MLX_SOURCE_MAP
         "  {\"${source}\", \"${source_content}\"},\n")
endforeach()

string(APPEND MLX_SOURCE_MAP "};\n")

# Write header, source map and footer. Every content argument is quoted: an
# unquoted expansion is split on ';' and the pieces are concatenated, which
# would strip every semicolon from the generated C++.
file(WRITE "${MLX_OUTPUT_FILE}" "${MLX_KERNEL_HEADER}")
file(APPEND "${MLX_OUTPUT_FILE}" "${MLX_SOURCE_MAP}")
file(APPEND "${MLX_OUTPUT_FILE}" "${MLX_KERNEL_FOOTER}")
|
36
mlx/backend/rocm/binary.hip
Normal file
36
mlx/backend/rocm/binary.hip
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Basic binary operation kernels will go here
|
||||
__global__ void add_kernel(float* a, float* b, float* c, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
c[idx] = a[idx] + b[idx];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void multiply_kernel(float* a, float* b, float* c, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
c[idx] = a[idx] * b[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
|
||||
}
|
||||
|
||||
void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/compiled.cpp
Normal file
9
mlx/backend/rocm/compiled.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {

// Stub for ROCm compilation; the body is intentionally empty for now.
void compile() {}

} // namespace mlx::core::rocm
|
20
mlx/backend/rocm/copy.hip
Normal file
20
mlx/backend/rocm/copy.hip
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void copy_kernel(float* src, float* dst, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
dst[idx] = src[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void launch_copy(float* src, float* dst, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
104
mlx/backend/rocm/device.cpp
Normal file
104
mlx/backend/rocm/device.cpp
Normal file
@ -0,0 +1,104 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
#include "mlx/backend/rocm/utils.h"

#include <stdexcept>
#include <string>
#include <utility>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
DeviceStream::DeviceStream(Device& device) : device_(device) {
|
||||
check_hip_error("hipStreamCreate", hipStreamCreate(&stream_));
|
||||
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||
}
|
||||
|
||||
void DeviceStream::synchronize() {
|
||||
check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_));
|
||||
}
|
||||
|
||||
hipStream_t DeviceStream::schedule_hip_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
hipStream_t DeviceStream::last_hip_stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
CommandEncoder& DeviceStream::get_encoder() {
|
||||
return *encoder_;
|
||||
}
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
check_hip_error("hipSetDevice", hipSetDevice(device_));
|
||||
|
||||
// Get device properties
|
||||
hipDeviceProp_t prop;
|
||||
check_hip_error(
|
||||
"hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_));
|
||||
compute_capability_major_ = prop.major;
|
||||
compute_capability_minor_ = prop.minor;
|
||||
|
||||
// Create rocBLAS handle
|
||||
check_hip_error(
|
||||
"rocblas_create_handle",
|
||||
static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
if (rocblas_handle_) {
|
||||
rocblas_destroy_handle(rocblas_handle_);
|
||||
}
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
check_hip_error("hipSetDevice", hipSetDevice(device_));
|
||||
}
|
||||
|
||||
DeviceStream& Device::get_stream(Stream s) {
|
||||
auto it = streams_.find(s.index);
|
||||
if (it != streams_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this));
|
||||
return new_it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(DeviceStream& stream)
|
||||
: device_(stream.device()), stream_(stream), worker_() {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.enqueue(task);
|
||||
}
|
||||
|
||||
void CommandEncoder::end_encoding() {
|
||||
// Implementation for ending encoding
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
worker_.commit();
|
||||
}
|
||||
|
||||
// Global device management
|
||||
static std::unordered_map<int, std::unique_ptr<Device>> devices_;
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
auto it = devices_.find(device.index);
|
||||
if (it != devices_.end()) {
|
||||
return *it->second;
|
||||
}
|
||||
|
||||
auto new_device = std::make_unique<Device>(device.index);
|
||||
Device& dev_ref = *new_device;
|
||||
devices_[device.index] = std::move(new_device);
|
||||
return dev_ref;
|
||||
}
|
||||
|
||||
DeviceStream& get_stream(Stream s) {
|
||||
// Use default device (index 0) for now
|
||||
return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s);
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return get_stream(s).get_encoder();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
141
mlx/backend/rocm/device.h
Normal file
141
mlx/backend/rocm/device.h
Normal file
@ -0,0 +1,141 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/rocm/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
class Device;
|
||||
class CommandEncoder;
|
||||
|
||||
class DeviceStream {
|
||||
public:
|
||||
explicit DeviceStream(Device& device);
|
||||
|
||||
DeviceStream(const DeviceStream&) = delete;
|
||||
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||
|
||||
// Wait until kernels in the stream complete.
|
||||
void synchronize();
|
||||
|
||||
// Return a HIP stream for launching kernels.
|
||||
hipStream_t schedule_hip_stream();
|
||||
|
||||
// Return the last HIP stream used.
|
||||
hipStream_t last_hip_stream();
|
||||
|
||||
CommandEncoder& get_encoder();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
HipStream stream_;
|
||||
std::unique_ptr<CommandEncoder> encoder_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
~Device();
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current HIP device, required by some HIP calls.
|
||||
void make_current();
|
||||
|
||||
DeviceStream& get_stream(Stream s);
|
||||
|
||||
int hip_device() const {
|
||||
return device_;
|
||||
}
|
||||
int compute_capability_major() const {
|
||||
return compute_capability_major_;
|
||||
}
|
||||
int compute_capability_minor() const {
|
||||
return compute_capability_minor_;
|
||||
}
|
||||
rocblas_handle rocblas_handle() const {
|
||||
return rocblas_handle_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
rocblas_handle rocblas_handle_;
|
||||
std::unordered_map<int, DeviceStream> streams_;
|
||||
};
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
explicit CommandEncoder(DeviceStream& stream);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
void set_input_array(const array& arr) {}
|
||||
void set_output_array(const array& arr) {}
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void end_encoding();
|
||||
void commit();
|
||||
|
||||
// Schedule a HIP stream for |fun| to launch kernels, and check error
|
||||
// afterwards.
|
||||
template <typename F>
|
||||
void launch_kernel(F&& fun) {
|
||||
launch_kernel(stream_.schedule_hip_stream(), std::forward<F>(fun));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void launch_kernel(hipStream_t stream, F&& fun) {
|
||||
device_.make_current();
|
||||
fun(stream);
|
||||
check_hip_error("kernel launch", hipGetLastError());
|
||||
has_gpu_work_ = true;
|
||||
}
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
DeviceStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
bool has_gpu_work() const {
|
||||
return has_gpu_work_;
|
||||
}
|
||||
|
||||
private:
|
||||
Device& device_;
|
||||
DeviceStream& stream_;
|
||||
Worker worker_;
|
||||
bool has_gpu_work_{false};
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
DeviceStream& get_stream(Stream s);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Utility function to check HIP errors
|
||||
void check_hip_error(const char* msg, hipError_t error);
|
||||
|
||||
} // namespace mlx::core::rocm
|
11
mlx/backend/rocm/eval.cpp
Normal file
11
mlx/backend/rocm/eval.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {

// Stub for ROCm evaluation; the body is intentionally empty for now.
void eval() {}

} // namespace mlx::core::rocm
|
32
mlx/backend/rocm/event.hip
Normal file
32
mlx/backend/rocm/event.hip
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
class Event {
|
||||
public:
|
||||
Event() {
|
||||
check_hip_error("hipEventCreate", hipEventCreate(&event_));
|
||||
}
|
||||
|
||||
~Event() {
|
||||
hipEventDestroy(event_);
|
||||
}
|
||||
|
||||
void record(hipStream_t stream) {
|
||||
check_hip_error("hipEventRecord", hipEventRecord(event_, stream));
|
||||
}
|
||||
|
||||
void wait() {
|
||||
check_hip_error("hipEventSynchronize", hipEventSynchronize(event_));
|
||||
}
|
||||
|
||||
hipEvent_t event() const { return event_; }
|
||||
|
||||
private:
|
||||
hipEvent_t event_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/fence.cpp
Normal file
9
mlx/backend/rocm/fence.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {

// Stub for the ROCm fence operation; the body is intentionally empty for now.
void fence() {}

} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/indexing.cpp
Normal file
9
mlx/backend/rocm/indexing.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {

// Stub for the ROCm indexing operation; the body is intentionally empty for
// now.
void index() {}

} // namespace mlx::core::rocm
|
29
mlx/backend/rocm/kernel_utils.hip
Normal file
29
mlx/backend/rocm/kernel_utils.hip
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Utility functions for HIP kernels
|
||||
|
||||
__device__ inline int get_global_id() {
|
||||
return blockIdx.x * blockDim.x + threadIdx.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_local_id() {
|
||||
return threadIdx.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_group_id() {
|
||||
return blockIdx.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_local_size() {
|
||||
return blockDim.x;
|
||||
}
|
||||
|
||||
__device__ inline int get_num_groups() {
|
||||
return gridDim.x;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
37
mlx/backend/rocm/layer_norm.hip
Normal file
37
mlx/backend/rocm/layer_norm.hip
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void layer_norm_kernel(
|
||||
float* input,
|
||||
float* output,
|
||||
float* gamma,
|
||||
float* beta,
|
||||
int n,
|
||||
float eps) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (idx < n) {
|
||||
// Simplified layer norm placeholder
|
||||
// Real implementation would compute mean and variance
|
||||
output[idx] = gamma[idx] * input[idx] + beta[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void launch_layer_norm(
|
||||
float* input,
|
||||
float* output,
|
||||
float* gamma,
|
||||
float* beta,
|
||||
int n,
|
||||
float eps,
|
||||
hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream,
|
||||
input, output, gamma, beta, n, eps);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
13
mlx/backend/rocm/logsumexp.hip
Normal file
13
mlx/backend/rocm/logsumexp.hip
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void logsumexp_kernel(float* input, float* output, int n) {
|
||||
// Placeholder implementation
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
(void)input; (void)output; (void)n; (void)idx;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
30
mlx/backend/rocm/matmul.cpp
Normal file
30
mlx/backend/rocm/matmul.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/device.h"
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void matmul_hip(
|
||||
float* a,
|
||||
float* b,
|
||||
float* c,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
hipStream_t stream) {
|
||||
// This is a placeholder - in a real implementation, this would use rocBLAS
|
||||
// auto& device = get_current_device();
|
||||
// rocblas_sgemm(device.rocblas_handle(), ...);
|
||||
|
||||
// For now, just a placeholder
|
||||
(void)a;
|
||||
(void)b;
|
||||
(void)c;
|
||||
(void)m;
|
||||
(void)n;
|
||||
(void)k;
|
||||
(void)stream;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
11
mlx/backend/rocm/no_rocm.cpp
Normal file
11
mlx/backend/rocm/no_rocm.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/rocm.h"
|
||||
|
||||
namespace mlx::core::rocm {

// Fallback definition for builds that compile this file in place of the ROCm
// backend: the backend is always reported as unavailable.
bool is_available() {
  constexpr bool kAvailable = false;
  return kAvailable;
}

} // namespace mlx::core::rocm
|
21
mlx/backend/rocm/primitives.hip
Normal file
21
mlx/backend/rocm/primitives.hip
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
#include "mlx/backend/common/primitives.h"
|
||||
|
||||
namespace mlx::core::rocm {

// ROCm-specific primitive operations are expected to live in this file; the
// two entry points below are still empty stubs.

// Stub for the HIP add operation.
void add_hip() {}

// Stub for the HIP multiply operation.
void multiply_hip() {}

} // namespace mlx::core::rocm
|
23
mlx/backend/rocm/random.hip
Normal file
23
mlx/backend/rocm/random.hip
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
// Simple LCG placeholder - real implementation would use rocRAND
|
||||
unsigned int state = seed + idx;
|
||||
state = state * 1103515245 + 12345;
|
||||
output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
24
mlx/backend/rocm/reduce.hip
Normal file
24
mlx/backend/rocm/reduce.hip
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void sum_reduce_kernel(float* input, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Simple reduction placeholder
|
||||
if (idx == 0) {
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < n; i++) {
|
||||
sum += input[i];
|
||||
}
|
||||
output[0] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) {
|
||||
hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
13
mlx/backend/rocm/rms_norm.hip
Normal file
13
mlx/backend/rocm/rms_norm.hip
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void rms_norm_kernel(float* input, float* output, int n) {
|
||||
// Placeholder implementation
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
(void)input; (void)output; (void)n; (void)idx;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
11
mlx/backend/rocm/rocm.cpp
Normal file
11
mlx/backend/rocm/rocm.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/rocm.h"
|
||||
|
||||
namespace mlx::core::rocm {

// Definition used when the ROCm backend is compiled in: always reports the
// backend as available.
// NOTE(review): no device query is performed here.
bool is_available() {
  constexpr bool kAvailable = true;
  return kAvailable;
}

} // namespace mlx::core::rocm
|
10
mlx/backend/rocm/rocm.h
Normal file
10
mlx/backend/rocm/rocm.h
Normal file
@ -0,0 +1,10 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::rocm {

// Returns whether the ROCm backend is available.
bool is_available();

} // namespace mlx::core::rocm
|
13
mlx/backend/rocm/rope.hip
Normal file
13
mlx/backend/rocm/rope.hip
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void rope_kernel(float* input, float* output, int n) {
|
||||
// Placeholder for RoPE implementation
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
(void)input; (void)output; (void)n; (void)idx;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/slicing.cpp
Normal file
9
mlx/backend/rocm/slicing.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core::rocm {

// Stub for the ROCm slicing operation; the body is intentionally empty for
// now.
void slice() {}

} // namespace mlx::core::rocm
|
22
mlx/backend/rocm/softmax.hip
Normal file
22
mlx/backend/rocm/softmax.hip
Normal file
@ -0,0 +1,22 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void softmax_kernel(float* input, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (idx < n) {
|
||||
// Simplified softmax placeholder - real implementation needs reduction
|
||||
output[idx] = expf(input[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
void launch_softmax(float* input, float* output, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(softmax_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
1
mlx/backend/rocm/sort.hip
Normal file
1
mlx/backend/rocm/sort.hip
Normal file
@ -0,0 +1 @@
|
||||
|
20
mlx/backend/rocm/ternary.hip
Normal file
20
mlx/backend/rocm/ternary.hip
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
output[idx] = (condition[idx] != 0.0f) ? a[idx] : b[idx];
|
||||
}
|
||||
}
|
||||
|
||||
void launch_select(float* condition, float* a, float* b, float* output, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(select_kernel, dim3(blocks), dim3(threads), 0, stream, condition, a, b, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
33
mlx/backend/rocm/unary.hip
Normal file
33
mlx/backend/rocm/unary.hip
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
__global__ void relu_kernel(float* input, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
output[idx] = fmaxf(0.0f, input[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void sigmoid_kernel(float* input, float* output, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
output[idx] = 1.0f / (1.0f + expf(-input[idx]));
|
||||
}
|
||||
}
|
||||
|
||||
void launch_relu(float* input, float* output, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(relu_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
void launch_sigmoid(float* input, float* output, int n, hipStream_t stream) {
|
||||
int threads = 256;
|
||||
int blocks = (n + threads - 1) / threads;
|
||||
hipLaunchKernelGGL(sigmoid_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
17
mlx/backend/rocm/utils.cpp
Normal file
17
mlx/backend/rocm/utils.cpp
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/utils.h"
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
void check_hip_error(const char* msg, hipError_t error) {
|
||||
if (error != hipSuccess) {
|
||||
std::ostringstream oss;
|
||||
oss << "[ROCm] " << msg << ": " << hipGetErrorString(error);
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
12
mlx/backend/rocm/utils.h
Normal file
12
mlx/backend/rocm/utils.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Checks a HIP status code: returns normally when `error` is hipSuccess,
// otherwise throws std::runtime_error whose message combines `msg` with the
// HIP error string for `error`.
|
||||
void check_hip_error(const char* msg, hipError_t error);
|
||||
|
||||
} // namespace mlx::core::rocm
|
61
mlx/backend/rocm/worker.cpp
Normal file
61
mlx/backend/rocm/worker.cpp
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/rocm/worker.h"
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Starts the background thread that runs worker_loop().
//
// The thread is launched from the constructor body, not from the member
// initializer list, so that every other data member (mutex_, cv_, tasks_ and
// the flags) is fully constructed before worker_loop() can touch it,
// independent of the order in which the members are declared.
Worker::Worker() {
  worker_thread_ = std::thread(&Worker::worker_loop, this);
}
|
||||
|
||||
// Asks the worker thread to stop, wakes every waiter on cv_, and waits for
// the thread to exit. worker_loop() checks stop_ before taking a task, so
// tasks still queued at this point are not run.
Worker::~Worker() {
  std::unique_lock<std::mutex> guard(mutex_);
  stop_ = true;
  guard.unlock();

  cv_.notify_all();

  if (!worker_thread_.joinable()) {
    return;
  }
  worker_thread_.join();
}
|
||||
|
||||
void Worker::enqueue(std::function<void()> task) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
tasks_.push(task);
|
||||
}
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
void Worker::commit() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
committed_ = true;
|
||||
}
|
||||
|
||||
void Worker::join() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.wait(lock, [this] { return tasks_.empty() && committed_; });
|
||||
}
|
||||
|
||||
void Worker::worker_loop() {
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
|
||||
|
||||
if (stop_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!tasks_.empty()) {
|
||||
task = tasks_.front();
|
||||
tasks_.pop();
|
||||
}
|
||||
}
|
||||
|
||||
if (task) {
|
||||
task();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::rocm
|
38
mlx/backend/rocm/worker.h
Normal file
38
mlx/backend/rocm/worker.h
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>

#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
|
||||
|
||||
namespace mlx::core::rocm {
|
||||
|
||||
// Alias for the HIP runtime's stream handle type.
using HipStream = hipStream_t;
|
||||
|
||||
// Runs queued std::function<void()> tasks one at a time on a dedicated
// background thread. Not copyable.
class Worker {
 public:
  // Starts the background thread.
  Worker();
  // Signals the background thread to stop and waits for it to exit.
  ~Worker();

  Worker(const Worker&) = delete;
  Worker& operator=(const Worker&) = delete;

  // Appends `task` to the queue for the background thread to run.
  void enqueue(std::function<void()> task);
  // Sets the committed flag that join() waits for.
  void commit();
  // Blocks until the queue is empty and the committed flag is set.
  void join();

 private:
  // Body of the background thread: pops and runs tasks until stop_ is set.
  void worker_loop();

  std::queue<std::function<void()>> tasks_;
  std::mutex mutex_; // guards tasks_, stop_ and committed_
  std::condition_variable cv_;
  bool stop_{false};
  bool committed_{false};

  // Declared last on purpose: members are initialized in declaration order,
  // and the thread starts running worker_loop() as soon as it is constructed.
  // Keeping it last guarantees that tasks_, mutex_, cv_ and the flags already
  // exist when the thread first touches them.
  std::thread worker_thread_;
};
|
||||
|
||||
} // namespace mlx::core::rocm
|
@ -6,10 +6,23 @@
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/device.h"
|
||||
|
||||
#ifdef MLX_USE_ROCM
|
||||
#include "mlx/backend/rocm/rocm.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Returns a mutable reference to the process-wide default device.
//
// The default is chosen once, the first time this function is called: the gpu
// device type when gpu::is_available() reports true (or, in builds with
// MLX_USE_ROCM, when rocm::is_available() does), and the cpu device type
// otherwise.
Device& mutable_default_device() {
  // The backend probes live inside the initializer of the function-local
  // static, so they run only during the one-time initialization instead of on
  // every call.
  static Device default_device{[]() -> Device::DeviceType {
    if (gpu::is_available()) {
      return Device::gpu;
    }
#ifdef MLX_USE_ROCM
    if (rocm::is_available()) {
      return Device::gpu; // ROCm devices use the generic gpu type
    }
#endif
    return Device::cpu;
  }()};
  return default_device;
}
|
||||
|
||||
@ -38,7 +51,11 @@ bool is_available(const Device& d) {
|
||||
case Device::cpu:
|
||||
return cpu::is_available();
|
||||
case Device::gpu:
|
||||
#ifdef MLX_USE_ROCM
|
||||
return gpu::is_available() || rocm::is_available();
|
||||
#else
|
||||
return gpu::is_available();
|
||||
#endif
|
||||
}
|
||||
// appease compiler
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user