mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
Merge ef88907c63
into f5f65ef48c
This commit is contained in:
commit
03d26fabeb
@ -1,13 +1,14 @@
|
|||||||
# Filename rules in cuda backend:
|
# Filename rules in cuda backend:
|
||||||
#
|
#
|
||||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||||
# * Device-only kernel code should be put in kernels/ subdir.
|
# * Device-only code should be put in device/ subdir.
|
||||||
# * Files in kernels/ subdir should not include files outside.
|
# * Files in device/ subdir should not include files outside.
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||||
@ -18,6 +19,8 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||||
@ -37,6 +40,24 @@ target_sources(
|
|||||||
|
|
||||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||||
|
|
||||||
|
# Embed kernel sources in binary for JIT compilation.
|
||||||
|
file(
|
||||||
|
GLOB MLX_JIT_SOURCES
|
||||||
|
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
|
||||||
|
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT gen/cuda_jit_sources.h
|
||||||
|
COMMAND
|
||||||
|
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
|
||||||
|
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
|
||||||
|
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
|
||||||
|
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
|
||||||
|
add_dependencies(mlx cuda_jit_sources)
|
||||||
|
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||||
|
|
||||||
# Enable defining device lambda functions.
|
# Enable defining device lambda functions.
|
||||||
target_compile_options(mlx
|
target_compile_options(mlx
|
||||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||||
@ -87,6 +108,9 @@ target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
|||||||
# Use cublasLt.
|
# Use cublasLt.
|
||||||
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||||
|
|
||||||
|
# Use NVRTC and driver APIs.
|
||||||
|
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||||
|
|
||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
150
mlx/backend/cuda/bin2h.cmake
Normal file
150
mlx/backend/cuda/bin2h.cmake
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# Based on: https://github.com/sivachandran/cmake-bin2h
|
||||||
|
#
|
||||||
|
# Copyright 2020 Sivachandran Paramasivam
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
include(CMakeParseArguments)
|
||||||
|
|
||||||
|
# Function to wrap a given string into multiple lines at the given column
|
||||||
|
# position.
|
||||||
|
#
|
||||||
|
# Parameters:
|
||||||
|
#
|
||||||
|
# * VARIABLE - The name of the CMake variable holding the string.
|
||||||
|
# * AT_COLUMN - The column position at which string will be wrapped.
|
||||||
|
function(WRAP_STRING)
|
||||||
|
set(oneValueArgs VARIABLE AT_COLUMN)
|
||||||
|
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
|
||||||
|
|
||||||
|
string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
|
||||||
|
math(EXPR offset "0")
|
||||||
|
|
||||||
|
while(stringLength GREATER 0)
|
||||||
|
if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
|
||||||
|
math(EXPR length "${WRAP_STRING_AT_COLUMN}")
|
||||||
|
else()
|
||||||
|
math(EXPR length "${stringLength}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
|
||||||
|
set(lines "${lines}\n ${line}")
|
||||||
|
|
||||||
|
math(EXPR stringLength "${stringLength} - ${length}")
|
||||||
|
math(EXPR offset "${offset} + ${length}")
|
||||||
|
endwhile()
|
||||||
|
|
||||||
|
set(${WRAP_STRING_VARIABLE}
|
||||||
|
"${lines}"
|
||||||
|
PARENT_SCOPE)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
# Function to embed contents of a file as byte array in C/C++ header file(.h).
|
||||||
|
# The header file will contain a byte array and integer variable holding the
|
||||||
|
# size of the array.
|
||||||
|
#
|
||||||
|
# Parameters:
|
||||||
|
#
|
||||||
|
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
|
||||||
|
# the header file.
|
||||||
|
# * VARIABLE_NAME - The name of the variable for the byte array. The string
|
||||||
|
# "_SIZE" will be append to this name and will be used a variable name for
|
||||||
|
# size variable.
|
||||||
|
# * HEADER_FILE - The path of header file.
|
||||||
|
# * APPEND - If specified appends to the header file instead of overwriting it
|
||||||
|
# * HEADER_NAMESPACE - The namespace, where the array should be located in.
|
||||||
|
# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
|
||||||
|
# array.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
#
|
||||||
|
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
|
||||||
|
function(BIN2H)
|
||||||
|
set(options APPEND NULL_TERMINATE)
|
||||||
|
set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
|
||||||
|
set(multiValueArgs SOURCE_FILES)
|
||||||
|
cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
|
||||||
|
"${multiValueArgs}" ${ARGN})
|
||||||
|
|
||||||
|
set(arrayDefinition "")
|
||||||
|
foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
|
||||||
|
# get filename without extension
|
||||||
|
get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
|
||||||
|
# convert the filename to a valid C identifier
|
||||||
|
string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
|
||||||
|
|
||||||
|
# reads source file contents as hex string
|
||||||
|
file(READ ${SOURCE_FILE} hexString HEX)
|
||||||
|
|
||||||
|
# append null
|
||||||
|
if(BIN2H_NULL_TERMINATE)
|
||||||
|
string(APPEND hexString "00")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# wraps the hex string into multiple lines
|
||||||
|
wrap_string(VARIABLE hexString AT_COLUMN 24)
|
||||||
|
|
||||||
|
# strip the © in source code
|
||||||
|
string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
|
||||||
|
|
||||||
|
string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
|
||||||
|
${arrayValues})
|
||||||
|
|
||||||
|
# make a full variable name for the array
|
||||||
|
set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
|
||||||
|
|
||||||
|
# declares byte array and the length variables
|
||||||
|
string(APPEND arrayDefinition
|
||||||
|
"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
# add namespace wrapper if defined
|
||||||
|
if(DEFINED BIN2H_HEADER_NAMESPACE)
|
||||||
|
set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
|
||||||
|
set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
|
||||||
|
set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(arrayIncludes "#pragma once")
|
||||||
|
string(PREPEND declarations "${arrayIncludes}\n\n")
|
||||||
|
|
||||||
|
if(BIN2H_APPEND)
|
||||||
|
file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
|
||||||
|
else()
|
||||||
|
file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
# ----------------------------- CLI args -----------------------------
|
||||||
|
|
||||||
|
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
|
||||||
|
foreach(source ${MLX_JIT_SOURCES_LIST})
|
||||||
|
list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
bin2h(
|
||||||
|
SOURCE_FILES
|
||||||
|
${MLX_JIT_SOURCES_ABS}
|
||||||
|
NULL_TERMINATE
|
||||||
|
VARIABLE_NAME
|
||||||
|
"jit_source"
|
||||||
|
HEADER_NAMESPACE
|
||||||
|
"mlx::core"
|
||||||
|
HEADER_FILE
|
||||||
|
"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
|
@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
228
mlx/backend/cuda/compiled.cpp
Normal file
228
mlx/backend/cuda/compiled.cpp
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/compiled.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
struct FusedKernelBuilder {
|
||||||
|
std::string os;
|
||||||
|
const std::string& kernel_name;
|
||||||
|
const std::vector<array>& inputs;
|
||||||
|
const std::vector<array>& outputs;
|
||||||
|
const std::vector<array>& tape;
|
||||||
|
const std::function<bool(size_t)>& is_constant;
|
||||||
|
|
||||||
|
void build(const char* name, bool contiguous) {
|
||||||
|
NodeNamer namer;
|
||||||
|
|
||||||
|
// Function parameters.
|
||||||
|
std::vector<std::string> params;
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
params.push_back(
|
||||||
|
fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
|
||||||
|
if (!is_scalar(x) && !contiguous) {
|
||||||
|
params.push_back(fmt::format(
|
||||||
|
"const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
|
||||||
|
xname));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
params.push_back(fmt::format(
|
||||||
|
"{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
|
||||||
|
}
|
||||||
|
if (!contiguous) {
|
||||||
|
params.push_back(
|
||||||
|
"const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
|
||||||
|
}
|
||||||
|
params.push_back("IdxT size");
|
||||||
|
|
||||||
|
// Build function signature.
|
||||||
|
if (contiguous) {
|
||||||
|
os += "template <typename IdxT = uint32_t>\n";
|
||||||
|
} else {
|
||||||
|
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
||||||
|
}
|
||||||
|
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||||
|
for (size_t i = 0; i < params.size(); ++i) {
|
||||||
|
os += " ";
|
||||||
|
os += params[i];
|
||||||
|
if (i != params.size() - 1) {
|
||||||
|
os += ",\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os += ") {\n";
|
||||||
|
|
||||||
|
// Index.
|
||||||
|
os +=
|
||||||
|
" IdxT index = cg::this_grid().thread_rank();\n"
|
||||||
|
" if (index >= size) {\n"
|
||||||
|
" return;\n"
|
||||||
|
" }\n";
|
||||||
|
|
||||||
|
// Read inputs.
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
std::string type = dtype_to_cuda_type(x.dtype());
|
||||||
|
std::string value;
|
||||||
|
if (is_constant(i)) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
print_constant(ss, x);
|
||||||
|
value = fmt::format("static_cast<{}>({})", type, ss.str());
|
||||||
|
} else if (is_scalar(x)) {
|
||||||
|
value = fmt::format("{}[0]", xname);
|
||||||
|
} else if (contiguous) {
|
||||||
|
value = fmt::format("{}[index]", xname);
|
||||||
|
} else {
|
||||||
|
std::string index = fmt::format(
|
||||||
|
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
||||||
|
xname);
|
||||||
|
value = fmt::format("{}[{}]", xname, index);
|
||||||
|
}
|
||||||
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write tape.
|
||||||
|
for (const auto& x : tape) {
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
std::string type = dtype_to_cuda_type(x.dtype());
|
||||||
|
std::string value;
|
||||||
|
if (is_static_cast(x.primitive())) {
|
||||||
|
value = fmt::format(
|
||||||
|
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
||||||
|
} else {
|
||||||
|
std::ostringstream ss;
|
||||||
|
x.primitive().print(ss);
|
||||||
|
value = ss.str();
|
||||||
|
value += "{}(";
|
||||||
|
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
||||||
|
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
||||||
|
}
|
||||||
|
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||||
|
}
|
||||||
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output.
|
||||||
|
for (const auto& x : outputs) {
|
||||||
|
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
os += "}\n";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
constexpr const char* g_jit_includes = R"(
|
||||||
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
)";
|
||||||
|
|
||||||
|
void Compiled::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
|
||||||
|
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||||
|
// Build source code.
|
||||||
|
cu::FusedKernelBuilder builder{
|
||||||
|
g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
|
||||||
|
builder.os +=
|
||||||
|
"namespace mlx::core::cu {\n\n"
|
||||||
|
"namespace cg = cooperative_groups;\n\n";
|
||||||
|
builder.build("_contiguous", true);
|
||||||
|
builder.os += "\n";
|
||||||
|
builder.build("_strided", false);
|
||||||
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
|
// Build kernel names.
|
||||||
|
std::vector<std::string> kernel_names = {
|
||||||
|
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
||||||
|
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
||||||
|
};
|
||||||
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
||||||
|
kernel_names.push_back(
|
||||||
|
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
||||||
|
}
|
||||||
|
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||||
|
});
|
||||||
|
|
||||||
|
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||||
|
// handle all broadcasting.
|
||||||
|
auto [contiguous, shape, strides_vec] =
|
||||||
|
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||||
|
|
||||||
|
// Whether to use large index.
|
||||||
|
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||||
|
|
||||||
|
// Put inputs.
|
||||||
|
int strides_index = 1;
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (is_constant_(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
mod.append_arg(x);
|
||||||
|
if (!contiguous && !is_scalar(x)) {
|
||||||
|
mod.append_arg(strides_vec[strides_index++]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put outputs.
|
||||||
|
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||||
|
for (auto& x : outputs) {
|
||||||
|
mod.append_arg(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put shape and size.
|
||||||
|
if (!contiguous) {
|
||||||
|
mod.append_arg(shape);
|
||||||
|
}
|
||||||
|
if (large) {
|
||||||
|
mod.append_arg<int64_t>(outputs[0].data_size());
|
||||||
|
} else {
|
||||||
|
mod.append_arg<uint32_t>(outputs[0].data_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch kernel.
|
||||||
|
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||||
|
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||||
|
if (contiguous) {
|
||||||
|
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
||||||
|
} else {
|
||||||
|
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
||||||
|
}
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
}
|
||||||
|
for (const auto& out : outputs) {
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
}
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
mod.launch_kernel(stream, kernel_name, outputs[0], large);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -3,8 +3,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
72
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
72
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
|
|
||||||
|
#include <cuda/atomic>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ void atomic_add(T* out, T val) {
|
||||||
|
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||||
|
ref += val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ void atomic_prod(T* out, T val) {
|
||||||
|
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||||
|
T old = ref.load();
|
||||||
|
while (!ref.compare_exchange_strong(old, old * val)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ void atomic_max(T* out, T val) {
|
||||||
|
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||||
|
ref.fetch_max(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ void atomic_min(T* out, T val) {
|
||||||
|
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||||
|
ref.fetch_min(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Somehow cuda::atomic_ref does not provide atomic add for following types.
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ void atomic_add_general(T* out, T val) {
|
||||||
|
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||||
|
T old = ref.load();
|
||||||
|
while (!ref.compare_exchange_strong(old, old + val)) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ void atomic_add(__half* out, __half val) {
|
||||||
|
atomicAdd(out, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
|
||||||
|
#if __CUDA_ARCH__ < 900
|
||||||
|
atomic_add_general(out, val);
|
||||||
|
#else
|
||||||
|
atomicAdd(out, val);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||||
|
#if __CUDA_ARCH__ < 800
|
||||||
|
#if CCCL_VERSION >= 2008000
|
||||||
|
atomic_add_general(out, val);
|
||||||
|
#else
|
||||||
|
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
|
||||||
|
assert(cccl_version_too_old_for_bfloat16_atomic_add);
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
atomicAdd(out, val);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -1,6 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
12
mlx/backend/cuda/device/config.h
Normal file
12
mlx/backend/cuda/device/config.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
// This file is used by both CUDA kernel code and host-only C++ code.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||||
|
#define MAX_NDIM 8
|
||||||
|
|
||||||
|
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||||
|
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||||
|
#define WARP_SIZE 32
|
53
mlx/backend/cuda/device/gather.cuh
Normal file
53
mlx/backend/cuda/device/gather.cuh
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||||
|
__global__ void gather(
|
||||||
|
const T* src,
|
||||||
|
T* out,
|
||||||
|
LocT size,
|
||||||
|
const __grid_constant__ Shape src_shape,
|
||||||
|
const __grid_constant__ Strides src_strides,
|
||||||
|
int32_t src_ndim,
|
||||||
|
const __grid_constant__ Shape slice_sizes,
|
||||||
|
uint32_t slice_size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||||
|
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||||
|
indices_shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||||
|
indices_strides) {
|
||||||
|
LocT out_idx = cg::this_grid().thread_rank();
|
||||||
|
if (out_idx >= size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LocT src_elem = out_idx % slice_size;
|
||||||
|
LocT idx_elem = out_idx / slice_size;
|
||||||
|
|
||||||
|
LocT src_loc =
|
||||||
|
elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NIDX; ++i) {
|
||||||
|
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||||
|
idx_elem,
|
||||||
|
indices_shape.data() + i * IDX_NDIM,
|
||||||
|
indices_strides.data() + i * IDX_NDIM);
|
||||||
|
int32_t axis = axes[i];
|
||||||
|
LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
|
||||||
|
src_loc += idx_val * src_strides[axis];
|
||||||
|
}
|
||||||
|
|
||||||
|
out[out_idx] = src[src_loc];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename IdxT,
|
||||||
|
int NDIM,
|
||||||
|
bool SrcC,
|
||||||
|
bool IdxC,
|
||||||
|
typename LocT>
|
||||||
|
__global__ void gather_axis(
|
||||||
|
const T* src,
|
||||||
|
const IdxT* indices,
|
||||||
|
T* out,
|
||||||
|
LocT idx_size_pre,
|
||||||
|
LocT idx_size_axis,
|
||||||
|
LocT idx_size_post,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||||
|
int32_t axis,
|
||||||
|
int32_t axis_size,
|
||||||
|
int64_t src_stride_axis,
|
||||||
|
int64_t idx_stride_axis) {
|
||||||
|
LocT index = cg::this_grid().thread_rank();
|
||||||
|
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||||
|
|
||||||
|
LocT elem_idx = z * idx_size_post;
|
||||||
|
|
||||||
|
LocT idx_loc = y * idx_stride_axis;
|
||||||
|
if constexpr (IdxC) {
|
||||||
|
idx_loc += elem_idx * idx_size_axis + x;
|
||||||
|
} else {
|
||||||
|
idx_loc +=
|
||||||
|
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||||
|
|
||||||
|
LocT src_loc = idx_val * src_stride_axis;
|
||||||
|
if constexpr (SrcC) {
|
||||||
|
src_loc += elem_idx * axis_size + x;
|
||||||
|
} else {
|
||||||
|
src_loc +=
|
||||||
|
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;
|
||||||
|
|
||||||
|
out[out_idx] = src[src_loc];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
30
mlx/backend/cuda/device/indexing.cuh
Normal file
30
mlx/backend/cuda/device/indexing.cuh
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <cuda/std/tuple>
|
||||||
|
#include <cuda/std/type_traits>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// Convert an absolute index to positions in a 3d grid, assuming the index is
|
||||||
|
// calculated with:
|
||||||
|
// index = x * dim1 * dim2 + y * dim2 + z
|
||||||
|
template <typename T>
|
||||||
|
inline __host__ __device__ cuda::std::tuple<T, T, T>
|
||||||
|
index_to_dims(T index, T dim1, T dim2) {
|
||||||
|
T x = index / (dim1 * dim2);
|
||||||
|
T y = (index % (dim1 * dim2)) / dim2;
|
||||||
|
T z = index % dim2;
|
||||||
|
return cuda::std::make_tuple(x, y, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get absolute index from possible negative index.
|
||||||
|
template <typename IdxT>
|
||||||
|
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
|
||||||
|
if constexpr (cuda::std::is_unsigned_v<IdxT>) {
|
||||||
|
return idx;
|
||||||
|
} else {
|
||||||
|
return static_cast<int32_t>(idx < 0 ? idx + size : idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/device/scatter.cuh
Normal file
68
mlx/backend/cuda/device/scatter.cuh
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename IdxT,
|
||||||
|
typename Op,
|
||||||
|
int NIDX,
|
||||||
|
int IDX_NDIM,
|
||||||
|
typename LocT>
|
||||||
|
__global__ void scatter(
|
||||||
|
const T* upd,
|
||||||
|
T* out,
|
||||||
|
LocT size,
|
||||||
|
const __grid_constant__ Shape upd_shape,
|
||||||
|
const __grid_constant__ Strides upd_strides,
|
||||||
|
int32_t upd_ndim,
|
||||||
|
LocT upd_post_idx_size,
|
||||||
|
const __grid_constant__ Shape out_shape,
|
||||||
|
const __grid_constant__ Strides out_strides,
|
||||||
|
int32_t out_ndim,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||||
|
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||||
|
indices_shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||||
|
indices_strides) {
|
||||||
|
LocT upd_idx = cg::this_grid().thread_rank();
|
||||||
|
if (upd_idx >= size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LocT out_elem = upd_idx % upd_post_idx_size;
|
||||||
|
LocT idx_elem = upd_idx / upd_post_idx_size;
|
||||||
|
|
||||||
|
LocT out_idx = elem_to_loc(
|
||||||
|
out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NIDX; ++i) {
|
||||||
|
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||||
|
idx_elem,
|
||||||
|
indices_shape.data() + i * IDX_NDIM,
|
||||||
|
indices_strides.data() + i * IDX_NDIM);
|
||||||
|
int32_t axis = axes[i];
|
||||||
|
LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);
|
||||||
|
out_idx += idx_val * out_strides[axis];
|
||||||
|
}
|
||||||
|
|
||||||
|
LocT upd_loc = elem_to_loc(
|
||||||
|
out_elem + idx_elem * upd_post_idx_size,
|
||||||
|
upd_shape.data(),
|
||||||
|
upd_strides.data(),
|
||||||
|
upd_ndim);
|
||||||
|
|
||||||
|
Op{}(out + out_idx, upd[upd_loc]);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename IdxT,
|
||||||
|
typename Op,
|
||||||
|
int NDIM,
|
||||||
|
bool UpdC,
|
||||||
|
bool IdxC,
|
||||||
|
typename LocT>
|
||||||
|
__global__ void scatter_axis(
|
||||||
|
const T* upd,
|
||||||
|
const IdxT* indices,
|
||||||
|
T* out,
|
||||||
|
LocT idx_size_pre,
|
||||||
|
LocT idx_size_axis,
|
||||||
|
LocT idx_size_post,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||||
|
int32_t axis,
|
||||||
|
int32_t axis_size,
|
||||||
|
int64_t upd_stride_axis,
|
||||||
|
int64_t idx_stride_axis) {
|
||||||
|
LocT index = cg::this_grid().thread_rank();
|
||||||
|
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||||
|
|
||||||
|
LocT elem_idx = z * idx_size_post;
|
||||||
|
|
||||||
|
LocT idx_loc = y * idx_stride_axis;
|
||||||
|
if constexpr (IdxC) {
|
||||||
|
idx_loc += elem_idx * idx_size_axis + x;
|
||||||
|
} else {
|
||||||
|
idx_loc +=
|
||||||
|
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||||
|
|
||||||
|
LocT upd_loc = y * upd_stride_axis;
|
||||||
|
if constexpr (UpdC) {
|
||||||
|
upd_loc += elem_idx * idx_size_axis + x;
|
||||||
|
} else {
|
||||||
|
upd_loc +=
|
||||||
|
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;
|
||||||
|
|
||||||
|
Op{}(out + out_idx, upd[upd_loc]);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/atomic_ops.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
struct ScatterAssign {
|
||||||
|
template <typename T>
|
||||||
|
__device__ void operator()(T* out, T val) const {
|
||||||
|
*out = val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ScatterSum {
|
||||||
|
template <typename T>
|
||||||
|
__device__ void operator()(T* out, T val) const {
|
||||||
|
atomic_add(out, val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ScatterProd {
|
||||||
|
template <typename T>
|
||||||
|
__device__ void operator()(T* out, T val) const {
|
||||||
|
atomic_prod(out, val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ScatterMax {
|
||||||
|
template <typename T>
|
||||||
|
__device__ void operator()(T* out, T val) const {
|
||||||
|
atomic_max(out, val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ScatterMin {
|
||||||
|
template <typename T>
|
||||||
|
__device__ void operator()(T* out, T val) const {
|
||||||
|
atomic_min(out, val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
@ -8,6 +8,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
@ -21,14 +23,8 @@ namespace mlx::core::cu {
|
|||||||
// CUDA kernel utils
|
// CUDA kernel utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
|
||||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
|
||||||
#define WARP_SIZE 32
|
|
||||||
|
|
||||||
// To pass shape/strides to kernels via constant memory, their size must be
|
// To pass shape/strides to kernels via constant memory, their size must be
|
||||||
// known at compile time.
|
// known at compile time.
|
||||||
#define MAX_NDIM 8
|
|
||||||
|
|
||||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||||
|
|
420
mlx/backend/cuda/indexing.cpp
Normal file
420
mlx/backend/cuda/indexing.cpp
Normal file
@ -0,0 +1,420 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/compiled.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include "cuda_jit_sources.h"
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
|
||||||
|
|
||||||
|
void append_indices_arg(
|
||||||
|
cu::JitModule& mod,
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
int nidx,
|
||||||
|
int idx_ndim) {
|
||||||
|
std::vector<const void*> indices(nidx);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
indices[i] = inputs[i + 1].data<void>();
|
||||||
|
}
|
||||||
|
mod.append_arg(std::move(indices));
|
||||||
|
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
std::copy_n(
|
||||||
|
inputs[i + 1].shape().begin(),
|
||||||
|
idx_ndim,
|
||||||
|
indices_shape.data() + i * idx_ndim);
|
||||||
|
}
|
||||||
|
mod.append_arg(std::move(indices_shape));
|
||||||
|
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
||||||
|
for (int i = 0; i < nidx; ++i) {
|
||||||
|
std::copy_n(
|
||||||
|
inputs[i + 1].strides().begin(),
|
||||||
|
idx_ndim,
|
||||||
|
indices_strides.data() + i * idx_ndim);
|
||||||
|
}
|
||||||
|
mod.append_arg(std::move(indices_strides));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Gather::eval_gpu");
|
||||||
|
assert(inputs.size() > 0);
|
||||||
|
const auto& src = inputs[0];
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int nidx = inputs.size() - 1;
|
||||||
|
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||||
|
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||||
|
|
||||||
|
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||||
|
(src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||||
|
|
||||||
|
uint32_t slice_size = std::accumulate(
|
||||||
|
slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());
|
||||||
|
|
||||||
|
std::string module_name = fmt::format(
|
||||||
|
"gather_{}_{}_{}",
|
||||||
|
dtype_to_string(out.dtype()),
|
||||||
|
dtype_to_string(idx_dtype),
|
||||||
|
nidx);
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||||
|
std::vector<std::string> kernel_names;
|
||||||
|
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||||
|
for (int large = 0; large <= 1; ++large) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx_dtype),
|
||||||
|
nidx,
|
||||||
|
ndim,
|
||||||
|
large ? "int64_t" : "uint32_t"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||||
|
});
|
||||||
|
|
||||||
|
mod.append_arg(src);
|
||||||
|
mod.append_arg(out);
|
||||||
|
if (large) {
|
||||||
|
mod.append_arg<int64_t>(out.size());
|
||||||
|
} else {
|
||||||
|
mod.append_arg<uint32_t>(out.size());
|
||||||
|
}
|
||||||
|
mod.append_ndim_arg(src.shape());
|
||||||
|
mod.append_ndim_arg(src.strides());
|
||||||
|
mod.append_arg<int32_t>(src.ndim());
|
||||||
|
mod.append_ndim_arg(slice_sizes_);
|
||||||
|
mod.append_arg(slice_size);
|
||||||
|
mod.append_arg(axes_);
|
||||||
|
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||||
|
|
||||||
|
std::string kernel_name = fmt::format(
|
||||||
|
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx_dtype),
|
||||||
|
nidx,
|
||||||
|
idx_ndim,
|
||||||
|
large ? "int64_t" : "uint32_t");
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
}
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
mod.launch_kernel(stream, kernel_name, out, large);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Gather::eval_gpu");
|
||||||
|
assert(inputs.size() > 1);
|
||||||
|
auto& upd = inputs.back();
|
||||||
|
|
||||||
|
// Copy src into out.
|
||||||
|
CopyType copy_type;
|
||||||
|
if (inputs[0].data_size() == 1) {
|
||||||
|
copy_type = CopyType::Scalar;
|
||||||
|
} else if (inputs[0].flags().row_contiguous) {
|
||||||
|
copy_type = CopyType::Vector;
|
||||||
|
} else {
|
||||||
|
copy_type = CopyType::General;
|
||||||
|
}
|
||||||
|
copy_gpu(inputs[0], out, copy_type);
|
||||||
|
|
||||||
|
// Empty update.
|
||||||
|
if (upd.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int nidx = axes_.size();
|
||||||
|
Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
|
||||||
|
int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
|
||||||
|
|
||||||
|
bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
|
||||||
|
(upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
|
||||||
|
|
||||||
|
uint32_t upd_post_idx_size = std::accumulate(
|
||||||
|
upd.shape().begin() + idx_ndim,
|
||||||
|
upd.shape().end(),
|
||||||
|
1,
|
||||||
|
std::multiplies<uint32_t>());
|
||||||
|
|
||||||
|
const char* op = g_scatter_ops[reduce_type_];
|
||||||
|
std::string module_name = fmt::format(
|
||||||
|
"scatter_{}_{}_{}_{}",
|
||||||
|
dtype_to_string(out.dtype()),
|
||||||
|
dtype_to_string(idx_dtype),
|
||||||
|
op,
|
||||||
|
nidx);
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||||
|
std::vector<std::string> kernel_names;
|
||||||
|
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||||
|
for (int large = 0; large <= 1; ++large) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx_dtype),
|
||||||
|
op,
|
||||||
|
nidx,
|
||||||
|
ndim,
|
||||||
|
large ? "int64_t" : "uint32_t"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||||
|
});
|
||||||
|
|
||||||
|
mod.append_arg(upd);
|
||||||
|
mod.append_arg(out);
|
||||||
|
if (large) {
|
||||||
|
mod.append_arg<int64_t>(upd.size());
|
||||||
|
} else {
|
||||||
|
mod.append_arg<uint32_t>(upd.size());
|
||||||
|
}
|
||||||
|
mod.append_ndim_arg(upd.shape());
|
||||||
|
mod.append_ndim_arg(upd.strides());
|
||||||
|
mod.append_arg<int32_t>(upd.ndim());
|
||||||
|
if (large) {
|
||||||
|
mod.append_arg<int64_t>(upd_post_idx_size);
|
||||||
|
} else {
|
||||||
|
mod.append_arg<uint32_t>(upd_post_idx_size);
|
||||||
|
}
|
||||||
|
mod.append_ndim_arg(out.shape());
|
||||||
|
mod.append_ndim_arg(out.strides());
|
||||||
|
mod.append_arg<int32_t>(out.ndim());
|
||||||
|
mod.append_arg(axes_);
|
||||||
|
append_indices_arg(mod, inputs, nidx, idx_ndim);
|
||||||
|
|
||||||
|
std::string kernel_name = fmt::format(
|
||||||
|
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx_dtype),
|
||||||
|
op,
|
||||||
|
nidx,
|
||||||
|
idx_ndim,
|
||||||
|
large ? "int64_t" : "uint32_t");
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
}
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
mod.launch_kernel(stream, kernel_name, upd, large);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("GatherAxis::eval_gpu");
|
||||||
|
assert(inputs.size() > 1);
|
||||||
|
const auto& src = inputs[0];
|
||||||
|
const auto& idx = inputs[1];
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||||
|
|
||||||
|
std::string module_name = fmt::format(
|
||||||
|
"gather_axis_{}_{}",
|
||||||
|
dtype_to_string(out.dtype()),
|
||||||
|
dtype_to_string(idx.dtype()));
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||||
|
std::vector<std::string> kernel_names;
|
||||||
|
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||||
|
for (int contiguous = 0; contiguous < 4; ++contiguous) {
|
||||||
|
for (int large = 0; large <= 1; ++large) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx.dtype()),
|
||||||
|
ndim,
|
||||||
|
contiguous & 1 ? true : false,
|
||||||
|
contiguous & 2 ? true : false,
|
||||||
|
large ? "int64_t" : "uint32_t"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
|
||||||
|
});
|
||||||
|
|
||||||
|
size_t idx_size_pre = 1;
|
||||||
|
size_t idx_size_post = 1;
|
||||||
|
for (int i = 0; i < axis_; ++i) {
|
||||||
|
idx_size_pre *= idx.shape(i);
|
||||||
|
}
|
||||||
|
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||||
|
idx_size_post *= idx.shape(i);
|
||||||
|
}
|
||||||
|
size_t idx_size_axis = idx.shape(axis_);
|
||||||
|
|
||||||
|
mod.append_arg(src);
|
||||||
|
mod.append_arg(idx);
|
||||||
|
mod.append_arg(out);
|
||||||
|
if (large) {
|
||||||
|
mod.append_arg<int64_t>(idx_size_pre);
|
||||||
|
mod.append_arg<int64_t>(idx_size_axis);
|
||||||
|
mod.append_arg<int64_t>(idx_size_post);
|
||||||
|
} else {
|
||||||
|
mod.append_arg<uint32_t>(idx_size_pre);
|
||||||
|
mod.append_arg<uint32_t>(idx_size_axis);
|
||||||
|
mod.append_arg<uint32_t>(idx_size_post);
|
||||||
|
}
|
||||||
|
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||||
|
mod.append_arg(remove_index(src.strides(), axis_));
|
||||||
|
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||||
|
mod.append_arg<int32_t>(axis_);
|
||||||
|
mod.append_arg(src.shape(axis_));
|
||||||
|
mod.append_arg(src.strides(axis_));
|
||||||
|
mod.append_arg(idx.strides(axis_));
|
||||||
|
|
||||||
|
std::string kernel_name = fmt::format(
|
||||||
|
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx.dtype()),
|
||||||
|
src.ndim() - 1,
|
||||||
|
src.flags().row_contiguous,
|
||||||
|
idx.flags().row_contiguous,
|
||||||
|
large ? "int64_t" : "uint32_t");
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
}
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("ScatterAxis::eval_gpu");
|
||||||
|
assert(inputs.size() > 2);
|
||||||
|
const auto& src = inputs[0];
|
||||||
|
const auto& idx = inputs[1];
|
||||||
|
const auto& upd = inputs[2];
|
||||||
|
|
||||||
|
// Copy src into out.
|
||||||
|
CopyType copy_type;
|
||||||
|
if (src.data_size() == 1) {
|
||||||
|
copy_type = CopyType::Scalar;
|
||||||
|
} else if (src.flags().row_contiguous) {
|
||||||
|
copy_type = CopyType::Vector;
|
||||||
|
} else {
|
||||||
|
copy_type = CopyType::General;
|
||||||
|
}
|
||||||
|
copy_gpu(src, out, copy_type);
|
||||||
|
|
||||||
|
// Empty update.
|
||||||
|
if (upd.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
|
||||||
|
|
||||||
|
const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
|
||||||
|
std::string module_name = fmt::format(
|
||||||
|
"scatter_axis_{}_{}_{}",
|
||||||
|
dtype_to_string(out.dtype()),
|
||||||
|
dtype_to_string(idx.dtype()),
|
||||||
|
op);
|
||||||
|
|
||||||
|
auto& s = stream();
|
||||||
|
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||||
|
std::vector<std::string> kernel_names;
|
||||||
|
for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
|
||||||
|
for (int contiguous = 0; contiguous < 4; ++contiguous) {
|
||||||
|
for (int large = 0; large <= 1; ++large) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx.dtype()),
|
||||||
|
op,
|
||||||
|
ndim,
|
||||||
|
contiguous & 1 ? true : false,
|
||||||
|
contiguous & 2 ? true : false,
|
||||||
|
large ? "int64_t" : "uint32_t"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
|
||||||
|
});
|
||||||
|
|
||||||
|
size_t idx_size_pre = 1;
|
||||||
|
size_t idx_size_post = 1;
|
||||||
|
for (int i = 0; i < axis_; ++i) {
|
||||||
|
idx_size_pre *= idx.shape(i);
|
||||||
|
}
|
||||||
|
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||||
|
idx_size_post *= idx.shape(i);
|
||||||
|
}
|
||||||
|
size_t idx_size_axis = idx.shape(axis_);
|
||||||
|
|
||||||
|
mod.append_arg(upd);
|
||||||
|
mod.append_arg(idx);
|
||||||
|
mod.append_arg(out);
|
||||||
|
if (large) {
|
||||||
|
mod.append_arg<int64_t>(idx_size_pre);
|
||||||
|
mod.append_arg<int64_t>(idx_size_axis);
|
||||||
|
mod.append_arg<int64_t>(idx_size_post);
|
||||||
|
} else {
|
||||||
|
mod.append_arg<uint32_t>(idx_size_pre);
|
||||||
|
mod.append_arg<uint32_t>(idx_size_axis);
|
||||||
|
mod.append_arg<uint32_t>(idx_size_post);
|
||||||
|
}
|
||||||
|
mod.append_arg(remove_index(idx.shape(), axis_));
|
||||||
|
mod.append_arg(remove_index(upd.strides(), axis_));
|
||||||
|
mod.append_arg(remove_index(idx.strides(), axis_));
|
||||||
|
mod.append_arg<int32_t>(axis_);
|
||||||
|
mod.append_arg(out.shape(axis_));
|
||||||
|
mod.append_arg(upd.strides(axis_));
|
||||||
|
mod.append_arg(idx.strides(axis_));
|
||||||
|
|
||||||
|
std::string kernel_name = fmt::format(
|
||||||
|
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
|
||||||
|
dtype_to_cuda_type(out.dtype()),
|
||||||
|
dtype_to_cuda_type(idx.dtype()),
|
||||||
|
op,
|
||||||
|
idx.ndim() - 1,
|
||||||
|
upd.flags().row_contiguous,
|
||||||
|
idx.flags().row_contiguous,
|
||||||
|
large ? "int64_t" : "uint32_t");
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
}
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
mod.launch_kernel(stream, kernel_name, idx, large);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
346
mlx/backend/cuda/jit_module.cpp
Normal file
346
mlx/backend/cuda/jit_module.cpp
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
|
#include "cuda_jit_sources.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <nvrtc.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
void check_nvrtc_error(const char* name, nvrtcResult err) {
|
||||||
|
if (err != NVRTC_SUCCESS) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed: {}", name, nvrtcGetErrorString(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
void check_cu_error(const char* name, CUresult err) {
|
||||||
|
if (err != CUDA_SUCCESS) {
|
||||||
|
const char* err_str = "Unknown error";
|
||||||
|
cuGetErrorString(err, &err_str);
|
||||||
|
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the location of the CUDA toolkit.
|
||||||
|
const char* cuda_home() {
|
||||||
|
const char* home = std::getenv("CUDA_HOME");
|
||||||
|
if (home) {
|
||||||
|
return home;
|
||||||
|
}
|
||||||
|
home = std::getenv("CUDA_PATH");
|
||||||
|
if (home) {
|
||||||
|
return home;
|
||||||
|
}
|
||||||
|
#if defined(__linux__)
|
||||||
|
home = "/usr/local/cuda";
|
||||||
|
if (std::filesystem::exists(home)) {
|
||||||
|
return home;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the cache directory for storing compiled results.
|
||||||
|
bool get_ptx_cache_dir(std::filesystem::path* result) {
|
||||||
|
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
|
||||||
|
if (!std::filesystem::is_directory(path)) {
|
||||||
|
std::error_code error;
|
||||||
|
if (!std::filesystem::create_directories(path, error)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*result = path;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||||
|
bool read_cached_ptx(
|
||||||
|
const std::filesystem::path& cache_dir,
|
||||||
|
const std::string& module_name,
|
||||||
|
std::vector<char>* ptx,
|
||||||
|
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||||
|
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||||
|
std::error_code error;
|
||||||
|
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||||
|
if (error) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::ifstream ptx_file(ptx_path, std::ios::binary);
|
||||||
|
if (!ptx_file.good()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ptx->resize(ptx_size);
|
||||||
|
ptx_file.read(ptx->data(), ptx_size);
|
||||||
|
|
||||||
|
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||||
|
std::string line;
|
||||||
|
while (std::getline(txt_file, line)) {
|
||||||
|
auto tab = line.find('\t');
|
||||||
|
if (tab != std::string::npos) {
|
||||||
|
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|.
|
||||||
|
void write_cached_ptx(
|
||||||
|
const std::filesystem::path& cache_dir,
|
||||||
|
const std::string& module_name,
|
||||||
|
const std::vector<char>& ptx,
|
||||||
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||||
|
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||||
|
if (!ptx.empty()) {
|
||||||
|
ptx_file.write(&ptx.front(), ptx.size());
|
||||||
|
}
|
||||||
|
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||||
|
for (const auto& [name, mangled] : ptx_kernels) {
|
||||||
|
txt_file << name << "\t" << mangled << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||||
|
inline bool version_lower_equal(Device& device, int major, int minor) {
|
||||||
|
if (device.compute_capability_major() < major) {
|
||||||
|
return true;
|
||||||
|
} else if (device.compute_capability_major() == major) {
|
||||||
|
return device.compute_capability_minor() <= minor;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return whether NVRTC supports compiling to |device|'s SASS code.
|
||||||
|
bool compiler_supports_device_sass(Device& device) {
|
||||||
|
int nvrtc_major, nvrtc_minor;
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
|
||||||
|
if (nvrtc_major < 9) {
|
||||||
|
return false;
|
||||||
|
} else if (nvrtc_major == 9) {
|
||||||
|
return version_lower_equal(device, 7, 2);
|
||||||
|
} else if (nvrtc_major == 10) {
|
||||||
|
return version_lower_equal(device, 7, 5);
|
||||||
|
} else if (nvrtc_major == 11 && nvrtc_minor == 0) {
|
||||||
|
return version_lower_equal(device, 8, 0);
|
||||||
|
} else if (nvrtc_major == 11 && nvrtc_minor < 8) {
|
||||||
|
return version_lower_equal(device, 8, 6);
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
|
||||||
|
|
||||||
|
constexpr const char* g_include_names[] = {
|
||||||
|
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||||
|
INCLUDE_PREFIX "binary_ops.cuh",
|
||||||
|
INCLUDE_PREFIX "cast_op.cuh",
|
||||||
|
INCLUDE_PREFIX "config.h",
|
||||||
|
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||||
|
INCLUDE_PREFIX "fp16_math.cuh",
|
||||||
|
INCLUDE_PREFIX "indexing.cuh",
|
||||||
|
INCLUDE_PREFIX "scatter_ops.cuh",
|
||||||
|
INCLUDE_PREFIX "unary_ops.cuh",
|
||||||
|
INCLUDE_PREFIX "utils.cuh",
|
||||||
|
};
|
||||||
|
|
||||||
|
#undef INCLUDE_PREFIX
|
||||||
|
|
||||||
|
constexpr const char* g_headers[] = {
|
||||||
|
jit_source_atomic_ops,
|
||||||
|
jit_source_binary_ops,
|
||||||
|
jit_source_cast_op,
|
||||||
|
jit_source_config,
|
||||||
|
jit_source_cucomplex_math,
|
||||||
|
jit_source_fp16_math,
|
||||||
|
jit_source_indexing,
|
||||||
|
jit_source_scatter_ops,
|
||||||
|
jit_source_unary_ops,
|
||||||
|
jit_source_utils,
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
JitModule::JitModule(
|
||||||
|
Device& device,
|
||||||
|
const std::string& module_name,
|
||||||
|
const KernelBuilder& builder) {
|
||||||
|
// Check cache.
|
||||||
|
std::filesystem::path cache_dir;
|
||||||
|
std::vector<char> ptx;
|
||||||
|
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||||
|
if (!get_ptx_cache_dir(&cache_dir) ||
|
||||||
|
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
|
||||||
|
// Create program.
|
||||||
|
auto [source_code, kernel_names] = builder();
|
||||||
|
nvrtcProgram prog;
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||||
|
&prog,
|
||||||
|
source_code.c_str(),
|
||||||
|
(module_name + ".cu").c_str(),
|
||||||
|
std::size(g_headers),
|
||||||
|
g_headers,
|
||||||
|
g_include_names));
|
||||||
|
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||||
|
&prog,
|
||||||
|
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||||
|
for (const auto& name : kernel_names) {
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile program.
|
||||||
|
bool use_sass = compiler_supports_device_sass(device);
|
||||||
|
std::string compute = fmt::format(
|
||||||
|
"--gpu-architecture={}_{}{}",
|
||||||
|
use_sass ? "sm" : "compute",
|
||||||
|
device.compute_capability_major(),
|
||||||
|
device.compute_capability_minor());
|
||||||
|
std::string include = fmt::format("--include-path={}/include", cuda_home());
|
||||||
|
const char* args[] = {compute.c_str(), include.c_str()};
|
||||||
|
nvrtcResult compile_result =
|
||||||
|
nvrtcCompileProgram(prog, std::size(args), args);
|
||||||
|
if (compile_result != NVRTC_SUCCESS) {
|
||||||
|
size_t log_size;
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||||
|
std::vector<char> log(log_size + 1, 0);
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get mangled names of kernel names.
|
||||||
|
for (const auto& name : kernel_names) {
|
||||||
|
const char* mangled;
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||||
|
ptx_kernels.emplace_back(name, mangled);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get ptx data.
|
||||||
|
size_t ptx_size;
|
||||||
|
if (use_sass) {
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||||
|
} else {
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||||
|
}
|
||||||
|
ptx.resize(ptx_size, 0);
|
||||||
|
if (use_sass) {
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||||
|
} else {
|
||||||
|
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||||
|
}
|
||||||
|
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load module.
|
||||||
|
char jit_log[4089] = {};
|
||||||
|
CUjit_option options[] = {
|
||||||
|
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
|
||||||
|
void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
|
||||||
|
CUresult jit_result = cuModuleLoadDataEx(
|
||||||
|
&module_, ptx.data(), std::size(options), options, values);
|
||||||
|
if (jit_result != CUDA_SUCCESS) {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Failed to load compiled {} kernel: {}.", module_name, jit_log));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load kernels.
|
||||||
|
for (const auto& [name, mangled] : ptx_kernels) {
|
||||||
|
CUfunction kernel;
|
||||||
|
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||||
|
kernels_[name] = kernel;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
JitModule::~JitModule() {
|
||||||
|
CHECK_CU_ERROR(cuModuleUnload(module_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void JitModule::launch_kernel(
|
||||||
|
CUstream stream,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& arr,
|
||||||
|
bool large,
|
||||||
|
int work_per_thread) {
|
||||||
|
CUfunction kernel = get_kernel(kernel_name);
|
||||||
|
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
||||||
|
int _, block_dim;
|
||||||
|
CHECK_CU_ERROR(
|
||||||
|
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||||
|
if (block_dim > nthreads) {
|
||||||
|
block_dim = nthreads;
|
||||||
|
}
|
||||||
|
Dims num_blocks{1, 1, 1};
|
||||||
|
if (large) {
|
||||||
|
num_blocks =
|
||||||
|
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
|
||||||
|
std::get<0>(num_blocks) =
|
||||||
|
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
|
||||||
|
} else {
|
||||||
|
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
|
||||||
|
}
|
||||||
|
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
void JitModule::launch_kernel(
|
||||||
|
CUstream stream,
|
||||||
|
CUfunction kernel,
|
||||||
|
Dims num_blocks,
|
||||||
|
Dims block_dims) {
|
||||||
|
CHECK_CU_ERROR(cuLaunchKernel(
|
||||||
|
kernel,
|
||||||
|
std::get<0>(num_blocks),
|
||||||
|
std::get<1>(num_blocks),
|
||||||
|
std::get<2>(num_blocks),
|
||||||
|
std::get<0>(block_dims),
|
||||||
|
std::get<1>(block_dims),
|
||||||
|
std::get<2>(block_dims),
|
||||||
|
0,
|
||||||
|
stream,
|
||||||
|
args_.data(),
|
||||||
|
nullptr));
|
||||||
|
args_.clear();
|
||||||
|
storage_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||||
|
auto it = kernels_.find(kernel_name);
|
||||||
|
if (it == kernels_.end()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("There is no kernel named {}.", kernel_name));
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void JitModule::append_ptr_arg(const void* v) {
|
||||||
|
args_.push_back(const_cast<void*>(v));
|
||||||
|
}
|
||||||
|
|
||||||
|
JitModule& get_jit_module(
|
||||||
|
const mlx::core::Device& device,
|
||||||
|
const std::string& name,
|
||||||
|
const KernelBuilder& builder) {
|
||||||
|
static std::unordered_map<std::string, JitModule> map;
|
||||||
|
auto it = map.find(name);
|
||||||
|
if (it == map.end()) {
|
||||||
|
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
113
mlx/backend/cuda/jit_module.h
Normal file
113
mlx/backend/cuda/jit_module.h
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <fmt/format.h>
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
class Device;
|
||||||
|
|
||||||
|
using KernelBuilderResult = std::pair<
|
||||||
|
/* source code */ std::string,
|
||||||
|
/* kernel names */ std::vector<std::string>>;
|
||||||
|
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||||
|
|
||||||
|
class JitModule {
|
||||||
|
public:
|
||||||
|
JitModule(
|
||||||
|
Device& device,
|
||||||
|
const std::string& module_name,
|
||||||
|
const KernelBuilder& builder);
|
||||||
|
~JitModule();
|
||||||
|
|
||||||
|
JitModule(const JitModule&) = delete;
|
||||||
|
JitModule& operator=(const JitModule&) = delete;
|
||||||
|
|
||||||
|
void append_arg(const array& a) {
|
||||||
|
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void append_arg(T val) {
|
||||||
|
storage_.emplace_back(val);
|
||||||
|
append_ptr_arg(&storage_.back());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void append_arg(std::vector<T> vec) {
|
||||||
|
if (vec.empty()) {
|
||||||
|
// The nullptr can not be used as arg, pass something not null.
|
||||||
|
append_arg(std::monostate{});
|
||||||
|
} else {
|
||||||
|
append_ptr_arg(vec.data());
|
||||||
|
storage_.emplace_back(std::move(vec));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the arg is copied to an array with size of NDIM.
|
||||||
|
template <size_t NDIM = MAX_NDIM, typename T>
|
||||||
|
void append_ndim_arg(const std::vector<T>& vec) {
|
||||||
|
if (vec.size() > NDIM) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||||
|
}
|
||||||
|
std::vector<T> copied(NDIM);
|
||||||
|
std::copy(vec.begin(), vec.end(), copied.data());
|
||||||
|
append_arg(std::move(copied));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch kernel with |kernel_name| that each thread works on
|
||||||
|
// |work_per_thread| elements of |arr|.
|
||||||
|
void launch_kernel(
|
||||||
|
CUstream stream,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& arr,
|
||||||
|
bool large,
|
||||||
|
int work_per_thread = 1);
|
||||||
|
|
||||||
|
void launch_kernel(
|
||||||
|
CUstream stream,
|
||||||
|
CUfunction kernel,
|
||||||
|
Dims num_blocks,
|
||||||
|
Dims block_dims);
|
||||||
|
|
||||||
|
CUfunction get_kernel(const std::string& kernel_name);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void append_ptr_arg(const void* v);
|
||||||
|
|
||||||
|
CUmodule module_{nullptr};
|
||||||
|
std::unordered_map<std::string, CUfunction> kernels_;
|
||||||
|
std::vector<void*> args_;
|
||||||
|
|
||||||
|
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||||
|
// temporary values untill kernel is launched.
|
||||||
|
using Arg = std::variant<
|
||||||
|
std::monostate,
|
||||||
|
CUdeviceptr,
|
||||||
|
int32_t,
|
||||||
|
uint32_t,
|
||||||
|
int64_t,
|
||||||
|
std::vector<const void*>,
|
||||||
|
std::vector<int32_t>,
|
||||||
|
std::vector<int64_t>>;
|
||||||
|
std::deque<Arg> storage_;
|
||||||
|
};
|
||||||
|
|
||||||
|
JitModule& get_jit_module(
|
||||||
|
const mlx::core::Device& device,
|
||||||
|
const std::string& name,
|
||||||
|
const KernelBuilder& builder);
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -1,13 +1,13 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
// This file includes host-only utilies for writing CUDA kernels, the difference
|
// This file includes host-only utilies for writing CUDA kernels, the difference
|
||||||
// from backend/cuda/kernels/utils.cuh is that the latter file only include
|
// from backend/cuda/device/utils.cuh is that the latter file only include
|
||||||
// device-only code.
|
// device-only code.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/arange.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
@ -73,14 +73,11 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
|
|
||||||
NO_GPU(ArgPartition)
|
NO_GPU(ArgPartition)
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU_MULTI(Compiled)
|
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
NO_GPU_MULTI(DivMod)
|
NO_GPU_MULTI(DivMod)
|
||||||
NO_GPU(DynamicSlice)
|
NO_GPU(DynamicSlice)
|
||||||
NO_GPU(DynamicSliceUpdate)
|
NO_GPU(DynamicSliceUpdate)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
NO_GPU(Gather)
|
|
||||||
NO_GPU(GatherAxis)
|
|
||||||
NO_GPU(GatherMM)
|
NO_GPU(GatherMM)
|
||||||
NO_GPU(GatherQMM)
|
NO_GPU(GatherQMM)
|
||||||
NO_GPU(Hadamard)
|
NO_GPU(Hadamard)
|
||||||
@ -90,8 +87,6 @@ NO_GPU(Partition)
|
|||||||
NO_GPU_MULTI(QRF)
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
NO_GPU(QuantizedMatmul)
|
||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
NO_GPU(Scatter)
|
|
||||||
NO_GPU(ScatterAxis)
|
|
||||||
NO_GPU(Select)
|
NO_GPU(Select)
|
||||||
NO_GPU_MULTI(SVD)
|
NO_GPU_MULTI(SVD)
|
||||||
NO_GPU(Inverse)
|
NO_GPU(Inverse)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/reduce.h"
|
#include "mlx/backend/common/reduce.h"
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
|
|
||||||
#include <thrust/device_ptr.h>
|
#include <thrust/device_ptr.h>
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/cast_op.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/unary.h"
|
#include "mlx/backend/common/unary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
|
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
@ -23,4 +24,20 @@ void check_cuda_error(const char* name, cudaError_t err) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||||
|
if (dtype == float16) {
|
||||||
|
return "__half";
|
||||||
|
}
|
||||||
|
if (dtype == bfloat16) {
|
||||||
|
return "__nv_bfloat16";
|
||||||
|
}
|
||||||
|
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||||
|
if (dtype == DTYPE) { \
|
||||||
|
return #CPP_TYPE; \
|
||||||
|
}
|
||||||
|
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
|
||||||
|
#undef SPECIALIZE_DtypeToString
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -12,6 +12,8 @@ namespace cu {
|
|||||||
class Device;
|
class Device;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Dtype;
|
||||||
|
|
||||||
// Cuda stream managed with RAII.
|
// Cuda stream managed with RAII.
|
||||||
class CudaStream {
|
class CudaStream {
|
||||||
public:
|
public:
|
||||||
@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err);
|
|||||||
// The macro version that prints the command that failed.
|
// The macro version that prints the command that failed.
|
||||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
// Convert Dtype to CUDA C++ types.
|
||||||
|
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user