mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Merge 8bb8b76ae4
into b8022c578a
This commit is contained in:
commit
6c2c1850a9
@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
|||||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||||
|
option(MLX_BUILD_ROCM "Build ROCm backend" OFF)
|
||||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||||
@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA)
|
|||||||
enable_language(CUDA)
|
enable_language(CUDA)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_ROCM)
|
||||||
|
enable_language(HIP)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
message(STATUS "Metal not found. Unable to build GPU")
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
|
@ -60,7 +60,16 @@ else()
|
|||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
if(MLX_BUILD_ROCM)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
|
||||||
|
else()
|
||||||
|
target_sources(mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_METAL
|
||||||
|
OR MLX_BUILD_CUDA
|
||||||
|
OR MLX_BUILD_ROCM)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||||
|
85
mlx/backend/rocm/CMakeLists.txt
Normal file
85
mlx/backend/rocm/CMakeLists.txt
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# Filename rules in ROCm backend:
|
||||||
|
#
|
||||||
|
# * Use .hip/.hpp if code contains device code, and .cpp/.h if not.
|
||||||
|
# * Device-only code should be put in device/ subdir.
|
||||||
|
# * Files in device/ subdir should not include files outside.
|
||||||
|
target_sources(
|
||||||
|
mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/event.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/random.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/rope.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.hip
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
|
|
||||||
|
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)
|
||||||
|
|
||||||
|
# Embed kernel sources in binary for JIT compilation.
|
||||||
|
file(
|
||||||
|
GLOB MLX_JIT_SOURCES
|
||||||
|
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp")
|
||||||
|
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT gen/rocm_jit_sources.h
|
||||||
|
COMMAND
|
||||||
|
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
|
||||||
|
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
|
||||||
|
add_custom_target(rocm_jit_sources DEPENDS gen/rocm_jit_sources.h)
|
||||||
|
add_dependencies(mlx rocm_jit_sources)
|
||||||
|
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||||
|
|
||||||
|
# Find ROCm installation
|
||||||
|
find_package(hip REQUIRED)
|
||||||
|
find_package(rocblas REQUIRED)
|
||||||
|
|
||||||
|
# Link with ROCm libraries
|
||||||
|
target_link_libraries(mlx PRIVATE hip::device roc::rocblas)
|
||||||
|
|
||||||
|
# Set GPU architectures for ROCm Common ROCm architectures: gfx900, gfx906,
|
||||||
|
# gfx908, gfx90a, gfx1030, gfx1100
|
||||||
|
set(MLX_ROCM_ARCHITECTURES
|
||||||
|
"gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100"
|
||||||
|
CACHE STRING "ROCm GPU architectures")
|
||||||
|
message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}")
|
||||||
|
|
||||||
|
# Set GPU targets for HIP compilation
|
||||||
|
set_property(TARGET mlx PROPERTY HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}")
|
||||||
|
|
||||||
|
# Enable HIP language support
|
||||||
|
enable_language(HIP)
|
||||||
|
|
||||||
|
# Set HIP compiler flags
|
||||||
|
target_compile_options(
|
||||||
|
mlx
|
||||||
|
PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>"
|
||||||
|
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wall>"
|
||||||
|
"$<$<COMPILE_LANGUAGE:HIP>:-Xcompiler=-Wextra>")
|
||||||
|
|
||||||
|
# Add ROCm include directories
|
||||||
|
target_include_directories(mlx PRIVATE ${hip_INCLUDE_DIRS})
|
||||||
|
target_include_directories(mlx PRIVATE ${rocblas_INCLUDE_DIRS})
|
20
mlx/backend/rocm/allocator.cpp
Normal file
20
mlx/backend/rocm/allocator.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/allocator.h"
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void* allocate(size_t size) {
|
||||||
|
void* ptr;
|
||||||
|
check_hip_error("hipMalloc", hipMalloc(&ptr, size));
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void deallocate(void* ptr) {
|
||||||
|
if (ptr) {
|
||||||
|
check_hip_error("hipFree", hipFree(ptr));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
12
mlx/backend/rocm/allocator.h
Normal file
12
mlx/backend/rocm/allocator.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void* allocate(size_t size);
|
||||||
|
void deallocate(void* ptr);
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
28
mlx/backend/rocm/arg_reduce.hip
Normal file
28
mlx/backend/rocm/arg_reduce.hip
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void argmax_kernel(float* input, int* output, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
// Simple argmax placeholder
|
||||||
|
if (idx == 0) {
|
||||||
|
int max_idx = 0;
|
||||||
|
float max_val = input[0];
|
||||||
|
for (int i = 1; i < n; i++) {
|
||||||
|
if (input[i] > max_val) {
|
||||||
|
max_val = input[i];
|
||||||
|
max_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output[0] = max_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_argmax(float* input, int* output, int n, hipStream_t stream) {
|
||||||
|
hipLaunchKernelGGL(argmax_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
47
mlx/backend/rocm/bin2h.cmake
Normal file
47
mlx/backend/rocm/bin2h.cmake
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
# Script to embed kernel source files as header for JIT compilation
|
||||||
|
|
||||||
|
set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
|
||||||
|
set(MLX_KERNEL_HEADER
|
||||||
|
"#pragma once\n\n#include <unordered_map>\n#include <string>\n\nnamespace mlx::core::rocm {\n\n"
|
||||||
|
)
|
||||||
|
set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n")
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY)
|
||||||
|
file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR})
|
||||||
|
|
||||||
|
# Write header
|
||||||
|
file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER})
|
||||||
|
|
||||||
|
# Process JIT sources
|
||||||
|
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
|
||||||
|
|
||||||
|
set(MLX_SOURCE_MAP
|
||||||
|
"const std::unordered_map<std::string, std::string> kernel_sources = {\n")
|
||||||
|
|
||||||
|
foreach(source IN LISTS MLX_JIT_SOURCES_LIST)
|
||||||
|
set(source_file "${MLX_SOURCE_ROOT}/${source}")
|
||||||
|
if(EXISTS ${source_file})
|
||||||
|
# Read source file
|
||||||
|
file(READ ${source_file} source_content)
|
||||||
|
|
||||||
|
# Escape content for C++ string literal
|
||||||
|
string(REPLACE "\\" "\\\\" source_content "${source_content}")
|
||||||
|
string(REPLACE "\"" "\\\"" source_content "${source_content}")
|
||||||
|
string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}")
|
||||||
|
|
||||||
|
# Add to map
|
||||||
|
set(MLX_SOURCE_MAP
|
||||||
|
"${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n")
|
||||||
|
|
||||||
|
# Write source map
|
||||||
|
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP})
|
||||||
|
|
||||||
|
# Write footer
|
||||||
|
file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER})
|
36
mlx/backend/rocm/binary.hip
Normal file
36
mlx/backend/rocm/binary.hip
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
// Basic binary operation kernels will go here
|
||||||
|
__global__ void add_kernel(float* a, float* b, float* c, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
c[idx] = a[idx] + b[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void multiply_kernel(float* a, float* b, float* c, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
c[idx] = a[idx] * b[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_multiply(float* a, float* b, float* c, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(multiply_kernel, dim3(blocks), dim3(threads), 0, stream, a, b, c, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/compiled.cpp
Normal file
9
mlx/backend/rocm/compiled.cpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void compile() {
|
||||||
|
// Placeholder for ROCm compilation
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
20
mlx/backend/rocm/copy.hip
Normal file
20
mlx/backend/rocm/copy.hip
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void copy_kernel(float* src, float* dst, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
dst[idx] = src[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_copy(float* src, float* dst, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(copy_kernel, dim3(blocks), dim3(threads), 0, stream, src, dst, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
104
mlx/backend/rocm/device.cpp
Normal file
104
mlx/backend/rocm/device.cpp
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/device.h"
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
DeviceStream::DeviceStream(Device& device) : device_(device) {
|
||||||
|
check_hip_error("hipStreamCreate", hipStreamCreate(&stream_));
|
||||||
|
encoder_ = std::make_unique<CommandEncoder>(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceStream::synchronize() {
|
||||||
|
check_hip_error("hipStreamSynchronize", hipStreamSynchronize(stream_));
|
||||||
|
}
|
||||||
|
|
||||||
|
hipStream_t DeviceStream::schedule_hip_stream() {
|
||||||
|
return stream_;
|
||||||
|
}
|
||||||
|
|
||||||
|
hipStream_t DeviceStream::last_hip_stream() {
|
||||||
|
return stream_;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandEncoder& DeviceStream::get_encoder() {
|
||||||
|
return *encoder_;
|
||||||
|
}
|
||||||
|
|
||||||
|
Device::Device(int device) : device_(device) {
|
||||||
|
check_hip_error("hipSetDevice", hipSetDevice(device_));
|
||||||
|
|
||||||
|
// Get device properties
|
||||||
|
hipDeviceProp_t prop;
|
||||||
|
check_hip_error(
|
||||||
|
"hipGetDeviceProperties", hipGetDeviceProperties(&prop, device_));
|
||||||
|
compute_capability_major_ = prop.major;
|
||||||
|
compute_capability_minor_ = prop.minor;
|
||||||
|
|
||||||
|
// Create rocBLAS handle
|
||||||
|
check_hip_error(
|
||||||
|
"rocblas_create_handle",
|
||||||
|
static_cast<hipError_t>(rocblas_create_handle(&rocblas_handle_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Device::~Device() {
|
||||||
|
if (rocblas_handle_) {
|
||||||
|
rocblas_destroy_handle(rocblas_handle_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::make_current() {
|
||||||
|
check_hip_error("hipSetDevice", hipSetDevice(device_));
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceStream& Device::get_stream(Stream s) {
|
||||||
|
auto it = streams_.find(s.index);
|
||||||
|
if (it != streams_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [new_it, inserted] = streams_.emplace(s.index, DeviceStream(*this));
|
||||||
|
return new_it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandEncoder::CommandEncoder(DeviceStream& stream)
|
||||||
|
: device_(stream.device()), stream_(stream), worker_() {}
|
||||||
|
|
||||||
|
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||||
|
worker_.enqueue(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::end_encoding() {
|
||||||
|
// Implementation for ending encoding
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::commit() {
|
||||||
|
worker_.commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Global device management
|
||||||
|
static std::unordered_map<int, std::unique_ptr<Device>> devices_;
|
||||||
|
|
||||||
|
Device& device(mlx::core::Device device) {
|
||||||
|
auto it = devices_.find(device.index);
|
||||||
|
if (it != devices_.end()) {
|
||||||
|
return *it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_device = std::make_unique<Device>(device.index);
|
||||||
|
Device& dev_ref = *new_device;
|
||||||
|
devices_[device.index] = std::move(new_device);
|
||||||
|
return dev_ref;
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceStream& get_stream(Stream s) {
|
||||||
|
// Use default device (index 0) for now
|
||||||
|
return device(mlx::core::Device{mlx::core::Device::gpu, 0}).get_stream(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
CommandEncoder& get_command_encoder(Stream s) {
|
||||||
|
return get_stream(s).get_encoder();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
141
mlx/backend/rocm/device.h
Normal file
141
mlx/backend/rocm/device.h
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/rocm/worker.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#include <rocblas/rocblas.h>
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
class Device;
|
||||||
|
class CommandEncoder;
|
||||||
|
|
||||||
|
class DeviceStream {
|
||||||
|
public:
|
||||||
|
explicit DeviceStream(Device& device);
|
||||||
|
|
||||||
|
DeviceStream(const DeviceStream&) = delete;
|
||||||
|
DeviceStream& operator=(const DeviceStream&) = delete;
|
||||||
|
|
||||||
|
// Wait until kernels in the stream complete.
|
||||||
|
void synchronize();
|
||||||
|
|
||||||
|
// Return a HIP stream for launching kernels.
|
||||||
|
hipStream_t schedule_hip_stream();
|
||||||
|
|
||||||
|
// Return the last HIP stream used.
|
||||||
|
hipStream_t last_hip_stream();
|
||||||
|
|
||||||
|
CommandEncoder& get_encoder();
|
||||||
|
|
||||||
|
Device& device() {
|
||||||
|
return device_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Device& device_;
|
||||||
|
HipStream stream_;
|
||||||
|
std::unique_ptr<CommandEncoder> encoder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Device {
|
||||||
|
public:
|
||||||
|
explicit Device(int device);
|
||||||
|
~Device();
|
||||||
|
|
||||||
|
Device(const Device&) = delete;
|
||||||
|
Device& operator=(const Device&) = delete;
|
||||||
|
|
||||||
|
// Make this device the current HIP device, required by some HIP calls.
|
||||||
|
void make_current();
|
||||||
|
|
||||||
|
DeviceStream& get_stream(Stream s);
|
||||||
|
|
||||||
|
int hip_device() const {
|
||||||
|
return device_;
|
||||||
|
}
|
||||||
|
int compute_capability_major() const {
|
||||||
|
return compute_capability_major_;
|
||||||
|
}
|
||||||
|
int compute_capability_minor() const {
|
||||||
|
return compute_capability_minor_;
|
||||||
|
}
|
||||||
|
rocblas_handle rocblas_handle() const {
|
||||||
|
return rocblas_handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int device_;
|
||||||
|
int compute_capability_major_;
|
||||||
|
int compute_capability_minor_;
|
||||||
|
rocblas_handle rocblas_handle_;
|
||||||
|
std::unordered_map<int, DeviceStream> streams_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CommandEncoder {
|
||||||
|
public:
|
||||||
|
explicit CommandEncoder(DeviceStream& stream);
|
||||||
|
|
||||||
|
CommandEncoder(const CommandEncoder&) = delete;
|
||||||
|
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||||
|
|
||||||
|
void set_input_array(const array& arr) {}
|
||||||
|
void set_output_array(const array& arr) {}
|
||||||
|
|
||||||
|
void add_temporary(const array& arr) {
|
||||||
|
temporaries_.push_back(arr.data_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
void add_completed_handler(std::function<void()> task);
|
||||||
|
void end_encoding();
|
||||||
|
void commit();
|
||||||
|
|
||||||
|
// Schedule a HIP stream for |fun| to launch kernels, and check error
|
||||||
|
// afterwards.
|
||||||
|
template <typename F>
|
||||||
|
void launch_kernel(F&& fun) {
|
||||||
|
launch_kernel(stream_.schedule_hip_stream(), std::forward<F>(fun));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void launch_kernel(hipStream_t stream, F&& fun) {
|
||||||
|
device_.make_current();
|
||||||
|
fun(stream);
|
||||||
|
check_hip_error("kernel launch", hipGetLastError());
|
||||||
|
has_gpu_work_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Device& device() {
|
||||||
|
return device_;
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceStream& stream() {
|
||||||
|
return stream_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_gpu_work() const {
|
||||||
|
return has_gpu_work_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Device& device_;
|
||||||
|
DeviceStream& stream_;
|
||||||
|
Worker worker_;
|
||||||
|
bool has_gpu_work_{false};
|
||||||
|
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Device& device(mlx::core::Device device);
|
||||||
|
DeviceStream& get_stream(Stream s);
|
||||||
|
CommandEncoder& get_command_encoder(Stream s);
|
||||||
|
|
||||||
|
// Utility function to check HIP errors
|
||||||
|
void check_hip_error(const char* msg, hipError_t error);
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
11
mlx/backend/rocm/eval.cpp
Normal file
11
mlx/backend/rocm/eval.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void eval() {
|
||||||
|
// Placeholder for ROCm evaluation
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
32
mlx/backend/rocm/event.hip
Normal file
32
mlx/backend/rocm/event.hip
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
class Event {
|
||||||
|
public:
|
||||||
|
Event() {
|
||||||
|
check_hip_error("hipEventCreate", hipEventCreate(&event_));
|
||||||
|
}
|
||||||
|
|
||||||
|
~Event() {
|
||||||
|
hipEventDestroy(event_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void record(hipStream_t stream) {
|
||||||
|
check_hip_error("hipEventRecord", hipEventRecord(event_, stream));
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait() {
|
||||||
|
check_hip_error("hipEventSynchronize", hipEventSynchronize(event_));
|
||||||
|
}
|
||||||
|
|
||||||
|
hipEvent_t event() const { return event_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
hipEvent_t event_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/fence.cpp
Normal file
9
mlx/backend/rocm/fence.cpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void fence() {
|
||||||
|
// Placeholder for ROCm fence operation
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/indexing.cpp
Normal file
9
mlx/backend/rocm/indexing.cpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void index() {
|
||||||
|
// Placeholder for ROCm indexing operation
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
29
mlx/backend/rocm/kernel_utils.hip
Normal file
29
mlx/backend/rocm/kernel_utils.hip
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
// Utility functions for HIP kernels
|
||||||
|
|
||||||
|
__device__ inline int get_global_id() {
|
||||||
|
return blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ inline int get_local_id() {
|
||||||
|
return threadIdx.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ inline int get_group_id() {
|
||||||
|
return blockIdx.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ inline int get_local_size() {
|
||||||
|
return blockDim.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ inline int get_num_groups() {
|
||||||
|
return gridDim.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
37
mlx/backend/rocm/layer_norm.hip
Normal file
37
mlx/backend/rocm/layer_norm.hip
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void layer_norm_kernel(
|
||||||
|
float* input,
|
||||||
|
float* output,
|
||||||
|
float* gamma,
|
||||||
|
float* beta,
|
||||||
|
int n,
|
||||||
|
float eps) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (idx < n) {
|
||||||
|
// Simplified layer norm placeholder
|
||||||
|
// Real implementation would compute mean and variance
|
||||||
|
output[idx] = gamma[idx] * input[idx] + beta[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_layer_norm(
|
||||||
|
float* input,
|
||||||
|
float* output,
|
||||||
|
float* gamma,
|
||||||
|
float* beta,
|
||||||
|
int n,
|
||||||
|
float eps,
|
||||||
|
hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(layer_norm_kernel, dim3(blocks), dim3(threads), 0, stream,
|
||||||
|
input, output, gamma, beta, n, eps);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
13
mlx/backend/rocm/logsumexp.hip
Normal file
13
mlx/backend/rocm/logsumexp.hip
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void logsumexp_kernel(float* input, float* output, int n) {
|
||||||
|
// Placeholder implementation
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
(void)input; (void)output; (void)n; (void)idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
30
mlx/backend/rocm/matmul.cpp
Normal file
30
mlx/backend/rocm/matmul.cpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/device.h"
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void matmul_hip(
|
||||||
|
float* a,
|
||||||
|
float* b,
|
||||||
|
float* c,
|
||||||
|
int m,
|
||||||
|
int n,
|
||||||
|
int k,
|
||||||
|
hipStream_t stream) {
|
||||||
|
// This is a placeholder - in a real implementation, this would use rocBLAS
|
||||||
|
// auto& device = get_current_device();
|
||||||
|
// rocblas_sgemm(device.rocblas_handle(), ...);
|
||||||
|
|
||||||
|
// For now, just a placeholder
|
||||||
|
(void)a;
|
||||||
|
(void)b;
|
||||||
|
(void)c;
|
||||||
|
(void)m;
|
||||||
|
(void)n;
|
||||||
|
(void)k;
|
||||||
|
(void)stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
11
mlx/backend/rocm/no_rocm.cpp
Normal file
11
mlx/backend/rocm/no_rocm.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/rocm.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
21
mlx/backend/rocm/primitives.hip
Normal file
21
mlx/backend/rocm/primitives.hip
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
#include "mlx/backend/common/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
// Basic kernel implementations will go here
|
||||||
|
// This is a placeholder for ROCm-specific primitive operations
|
||||||
|
|
||||||
|
void add_hip() {
|
||||||
|
// Placeholder for HIP add operation
|
||||||
|
}
|
||||||
|
|
||||||
|
void multiply_hip() {
|
||||||
|
// Placeholder for HIP multiply operation
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
23
mlx/backend/rocm/random.hip
Normal file
23
mlx/backend/rocm/random.hip
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void random_uniform_kernel(float* output, int n, unsigned int seed) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
// Simple LCG placeholder - real implementation would use rocRAND
|
||||||
|
unsigned int state = seed + idx;
|
||||||
|
state = state * 1103515245 + 12345;
|
||||||
|
output[idx] = (float)(state & 0x7FFFFFFF) / (float)0x7FFFFFFF;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_random_uniform(float* output, int n, unsigned int seed, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(random_uniform_kernel, dim3(blocks), dim3(threads), 0, stream, output, n, seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
24
mlx/backend/rocm/reduce.hip
Normal file
24
mlx/backend/rocm/reduce.hip
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void sum_reduce_kernel(float* input, float* output, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
// Simple reduction placeholder
|
||||||
|
if (idx == 0) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
sum += input[i];
|
||||||
|
}
|
||||||
|
output[0] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_sum_reduce(float* input, float* output, int n, hipStream_t stream) {
|
||||||
|
hipLaunchKernelGGL(sum_reduce_kernel, dim3(1), dim3(1), 0, stream, input, output, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
13
mlx/backend/rocm/rms_norm.hip
Normal file
13
mlx/backend/rocm/rms_norm.hip
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void rms_norm_kernel(float* input, float* output, int n) {
|
||||||
|
// Placeholder implementation
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
(void)input; (void)output; (void)n; (void)idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
11
mlx/backend/rocm/rocm.cpp
Normal file
11
mlx/backend/rocm/rocm.cpp
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/rocm.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
10
mlx/backend/rocm/rocm.h
Normal file
10
mlx/backend/rocm/rocm.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
/* Check if the ROCm backend is available. */
|
||||||
|
bool is_available();
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
13
mlx/backend/rocm/rope.hip
Normal file
13
mlx/backend/rocm/rope.hip
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void rope_kernel(float* input, float* output, int n) {
|
||||||
|
// Placeholder for RoPE implementation
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
(void)input; (void)output; (void)n; (void)idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
9
mlx/backend/rocm/slicing.cpp
Normal file
9
mlx/backend/rocm/slicing.cpp
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void slice() {
|
||||||
|
// Placeholder for ROCm slicing operation
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
22
mlx/backend/rocm/softmax.hip
Normal file
22
mlx/backend/rocm/softmax.hip
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void softmax_kernel(float* input, float* output, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (idx < n) {
|
||||||
|
// Simplified softmax placeholder - real implementation needs reduction
|
||||||
|
output[idx] = expf(input[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_softmax(float* input, float* output, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(softmax_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
1
mlx/backend/rocm/sort.hip
Normal file
1
mlx/backend/rocm/sort.hip
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
20
mlx/backend/rocm/ternary.hip
Normal file
20
mlx/backend/rocm/ternary.hip
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void select_kernel(float* condition, float* a, float* b, float* output, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
output[idx] = (condition[idx] != 0.0f) ? a[idx] : b[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_select(float* condition, float* a, float* b, float* output, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(select_kernel, dim3(blocks), dim3(threads), 0, stream, condition, a, b, output, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
33
mlx/backend/rocm/unary.hip
Normal file
33
mlx/backend/rocm/unary.hip
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
__global__ void relu_kernel(float* input, float* output, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
output[idx] = fmaxf(0.0f, input[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void sigmoid_kernel(float* input, float* output, int n) {
|
||||||
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx < n) {
|
||||||
|
output[idx] = 1.0f / (1.0f + expf(-input[idx]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_relu(float* input, float* output, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(relu_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void launch_sigmoid(float* input, float* output, int n, hipStream_t stream) {
|
||||||
|
int threads = 256;
|
||||||
|
int blocks = (n + threads - 1) / threads;
|
||||||
|
hipLaunchKernelGGL(sigmoid_kernel, dim3(blocks), dim3(threads), 0, stream, input, output, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
17
mlx/backend/rocm/utils.cpp
Normal file
17
mlx/backend/rocm/utils.cpp
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/utils.h"
|
||||||
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
void check_hip_error(const char* msg, hipError_t error) {
|
||||||
|
if (error != hipSuccess) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "[ROCm] " << msg << ": " << hipGetErrorString(error);
|
||||||
|
throw std::runtime_error(oss.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
12
mlx/backend/rocm/utils.h
Normal file
12
mlx/backend/rocm/utils.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
// Utility function to check HIP errors
|
||||||
|
void check_hip_error(const char* msg, hipError_t error);
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
61
mlx/backend/rocm/worker.cpp
Normal file
61
mlx/backend/rocm/worker.cpp
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/rocm/worker.h"
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
Worker::Worker() : worker_thread_(&Worker::worker_loop, this) {}
|
||||||
|
|
||||||
|
Worker::~Worker() {
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
stop_ = true;
|
||||||
|
}
|
||||||
|
cv_.notify_all();
|
||||||
|
if (worker_thread_.joinable()) {
|
||||||
|
worker_thread_.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Worker::enqueue(std::function<void()> task) {
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
tasks_.push(task);
|
||||||
|
}
|
||||||
|
cv_.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Worker::commit() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
committed_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Worker::join() {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
cv_.wait(lock, [this] { return tasks_.empty() && committed_; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Worker::worker_loop() {
|
||||||
|
while (true) {
|
||||||
|
std::function<void()> task;
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
cv_.wait(lock, [this] { return stop_ || !tasks_.empty(); });
|
||||||
|
|
||||||
|
if (stop_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tasks_.empty()) {
|
||||||
|
task = tasks_.front();
|
||||||
|
tasks_.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (task) {
|
||||||
|
task();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
38
mlx/backend/rocm/worker.h
Normal file
38
mlx/backend/rocm/worker.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <future>
|
||||||
|
#include <queue>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
namespace mlx::core::rocm {
|
||||||
|
|
||||||
|
using HipStream = hipStream_t;
|
||||||
|
|
||||||
|
class Worker {
|
||||||
|
public:
|
||||||
|
Worker();
|
||||||
|
~Worker();
|
||||||
|
|
||||||
|
Worker(const Worker&) = delete;
|
||||||
|
Worker& operator=(const Worker&) = delete;
|
||||||
|
|
||||||
|
void enqueue(std::function<void()> task);
|
||||||
|
void commit();
|
||||||
|
void join();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void worker_loop();
|
||||||
|
|
||||||
|
std::thread worker_thread_;
|
||||||
|
std::queue<std::function<void()>> tasks_;
|
||||||
|
std::mutex mutex_;
|
||||||
|
std::condition_variable cv_;
|
||||||
|
bool stop_{false};
|
||||||
|
bool committed_{false};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::rocm
|
@ -6,10 +6,23 @@
|
|||||||
#include "mlx/backend/gpu/available.h"
|
#include "mlx/backend/gpu/available.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
|
|
||||||
|
#ifdef MLX_USE_ROCM
|
||||||
|
#include "mlx/backend/rocm/rocm.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
Device& mutable_default_device() {
|
Device& mutable_default_device() {
|
||||||
static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
|
Device::DeviceType default_type = Device::cpu;
|
||||||
|
if (gpu::is_available()) {
|
||||||
|
default_type = Device::gpu;
|
||||||
|
}
|
||||||
|
#ifdef MLX_USE_ROCM
|
||||||
|
else if (rocm::is_available()) {
|
||||||
|
default_type = Device::gpu; // ROCm devices use the generic gpu type
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
static Device default_device{default_type};
|
||||||
return default_device;
|
return default_device;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +51,11 @@ bool is_available(const Device& d) {
|
|||||||
case Device::cpu:
|
case Device::cpu:
|
||||||
return cpu::is_available();
|
return cpu::is_available();
|
||||||
case Device::gpu:
|
case Device::gpu:
|
||||||
|
#ifdef MLX_USE_ROCM
|
||||||
|
return gpu::is_available() || rocm::is_available();
|
||||||
|
#else
|
||||||
return gpu::is_available();
|
return gpu::is_available();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
// appease compiler
|
// appease compiler
|
||||||
return false;
|
return false;
|
||||||
|
Loading…
Reference in New Issue
Block a user