mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00

Compare commits: c2dd81a8aa...918761a25a (3 commits: 918761a25a, a4fc671d3e, f5f65ef48c)
@@ -8,6 +8,7 @@ target_sources(
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
@@ -18,6 +19,7 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
@@ -28,6 +30,7 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -37,6 +40,24 @@ target_sources(

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# Embed kernel sources in binary for JIT compilation.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.cuh")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
  OUTPUT gen/cuda_jit_sources.h
  COMMAND
    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")

# Enable defining device lambda functions.
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
@@ -87,6 +108,9 @@ target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)

# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)

# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                   --diag_suppress=997>)
mlx/backend/cuda/bin2h.cmake (new file, 150 lines)
@@ -0,0 +1,150 @@
# Based on: https://github.com/sivachandran/cmake-bin2h
#
# Copyright 2020 Sivachandran Paramasivam
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

include(CMakeParseArguments)

# Function to wrap a given string into multiple lines at the given column
# position.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string.
# * AT_COLUMN - The column position at which the string will be wrapped.
function(WRAP_STRING)
  set(oneValueArgs VARIABLE AT_COLUMN)
  cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})

  string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
  math(EXPR offset "0")

  while(stringLength GREATER 0)
    if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
      math(EXPR length "${WRAP_STRING_AT_COLUMN}")
    else()
      math(EXPR length "${stringLength}")
    endif()

    string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
    set(lines "${lines}\n ${line}")

    math(EXPR stringLength "${stringLength} - ${length}")
    math(EXPR offset "${offset} + ${length}")
  endwhile()

  set(${WRAP_STRING_VARIABLE}
      "${lines}"
      PARENT_SCOPE)
endfunction()

# Function to embed the contents of files as byte arrays in a C/C++ header
# file (.h). The header file will contain a byte array and an integer variable
# holding the size of the array.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
#   the header file.
# * VARIABLE_NAME - The name of the variable for the byte array. The string
#   "_SIZE" will be appended to this name and used as the variable name for the
#   size variable.
# * HEADER_FILE - The path of the header file.
# * APPEND - If specified, appends to the header file instead of overwriting it.
# * HEADER_NAMESPACE - The namespace in which the array should be located.
# * NULL_TERMINATE - If specified, a null byte (zero) will be appended to the
#   byte array.
#
# Usage:
#
# bin2h(SOURCE_FILES "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
function(BIN2H)
  set(options APPEND NULL_TERMINATE)
  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
  set(multiValueArgs SOURCE_FILES)
  cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  set(arrayDefinition "")
  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
    # get the filename without extension
    get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
    # convert the filename to a valid C identifier
    string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)

    # read source file contents as a hex string
    file(READ ${SOURCE_FILE} hexString HEX)

    # append a null byte
    if(BIN2H_NULL_TERMINATE)
      string(APPEND hexString "00")
    endif()

    # wrap the hex string into multiple lines
    wrap_string(VARIABLE hexString AT_COLUMN 24)

    # strip the © character (hex c2a9) from the source, replacing it with two
    # spaces (hex 2020)
    string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})

    string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
                         ${arrayValues})

    # make a full variable name for the array
    set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")

    # declare the byte array
    string(APPEND arrayDefinition
           "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
  endforeach()

  # add a namespace wrapper if defined
  if(DEFINED BIN2H_HEADER_NAMESPACE)
    set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
    set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
    set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
  endif()

  set(arrayIncludes "#pragma once")
  string(PREPEND declarations "${arrayIncludes}\n\n")

  if(BIN2H_APPEND)
    file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
  else()
    file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
  endif()
endfunction()

# ----------------------------- CLI args -----------------------------

string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
foreach(source ${MLX_JIT_SOURCES_LIST})
  list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
endforeach()

bin2h(
  SOURCE_FILES
  ${MLX_JIT_SOURCES_ABS}
  NULL_TERMINATE
  VARIABLE_NAME
  "jit_source"
  HEADER_NAMESPACE
  "mlx::core"
  HEADER_FILE
  "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
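For orientation, the generated gen/cuda_jit_sources.h then looks roughly like the sketch below. This is a hypothetical excerpt, not actual generated output: the real header contains one jit_source_<stem> array per globbed kernel file, with the file's bytes emitted as hex values wrapped at 24 hex digits per line.

#pragma once

namespace mlx::core {

constexpr char jit_source_binary_ops[] = {
  0x2f, 0x2f, 0x20, 0x43, // first bytes of the embedded file ("// C...")
  // ... one 0xNN entry per byte of the source file ...
  0x00, // trailing zero added by NULL_TERMINATE
};

} // namespace mlx::core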
@@ -2,9 +2,9 @@

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/binary_ops.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
mlx/backend/cuda/compiled.cpp (new file, 228 lines)
@@ -0,0 +1,228 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"

#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

struct FusedKernelBuilder {
  std::string os;
  const std::string& kernel_name;
  const std::vector<array>& inputs;
  const std::vector<array>& outputs;
  const std::vector<array>& tape;
  const std::function<bool(size_t)>& is_constant;

  void build(const char* name, bool contiguous) {
    NodeNamer namer;

    // Function parameters.
    std::vector<std::string> params;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (is_constant(i)) {
        continue;
      }
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      params.push_back(
          fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
      if (!is_scalar(x) && !contiguous) {
        params.push_back(fmt::format(
            "const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
            xname));
      }
    }
    for (const auto& x : outputs) {
      params.push_back(fmt::format(
          "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
    }
    if (!contiguous) {
      params.push_back(
          "const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
    }
    params.push_back("IdxT size");

    // Build function signature.
    if (contiguous) {
      os += "template <typename IdxT = uint32_t>\n";
    } else {
      os += "template <int NDIM, typename IdxT = uint32_t>\n";
    }
    os += fmt::format("__global__ void {}(\n", kernel_name + name);
    for (size_t i = 0; i < params.size(); ++i) {
      os += " ";
      os += params[i];
      if (i != params.size() - 1) {
        os += ",\n";
      }
    }
    os += ") {\n";

    // Index.
    os +=
        " IdxT index = cg::this_grid().thread_rank();\n"
        " if (index >= size) {\n"
        " return;\n"
        " }\n";

    // Read inputs.
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_constant(i)) {
        std::ostringstream ss;
        print_constant(ss, x);
        value = fmt::format("static_cast<{}>({})", type, ss.str());
      } else if (is_scalar(x)) {
        value = fmt::format("{}[0]", xname);
      } else if (contiguous) {
        value = fmt::format("{}[index]", xname);
      } else {
        std::string index = fmt::format(
            "elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
            xname);
        value = fmt::format("{}[{}]", xname, index);
      }
      os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
    }

    // Write tape.
    for (const auto& x : tape) {
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_static_cast(x.primitive())) {
        value = fmt::format(
            "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
      } else {
        std::ostringstream ss;
        x.primitive().print(ss);
        value = ss.str();
        value += "{}(";
        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
          value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
        }
        value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
      }
      os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
    }

    // Write output.
    for (const auto& x : outputs) {
      os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
    }

    os += "}\n";
  }
};

} // namespace cu

constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>
)";

void Compiled::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("Compiled::eval_gpu");
  auto& s = stream();

  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
    // Build source code.
    cu::FusedKernelBuilder builder{
        g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
    builder.os +=
        "namespace mlx::core::cu {\n\n"
        "namespace cg = cooperative_groups;\n\n";
    builder.build("_contiguous", true);
    builder.os += "\n";
    builder.build("_strided", false);
    builder.os += "\n} // namespace mlx::core::cu\n";
    // Build kernel names.
    std::vector<std::string> kernel_names = {
        fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
        fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
    };
    for (int i = 1; i <= MAX_NDIM; ++i) {
      kernel_names.push_back(fmt::format(
          "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
      kernel_names.push_back(
          fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
    }
    return std::make_pair(std::move(builder.os), std::move(kernel_names));
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides_vec] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Whether to use large index.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  // Put inputs.
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_constant_(i)) {
      continue;
    }
    const auto& x = inputs[i];
    mod.append_arg(x);
    if (!contiguous && !is_scalar(x)) {
      mod.append_arg(strides_vec[strides_index++]);
    }
  }

  // Put outputs.
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
  for (auto& x : outputs) {
    mod.append_arg(x);
  }

  // Put shape and size.
  if (!contiguous) {
    mod.append_arg(shape);
  }
  if (large) {
    mod.append_arg<int64_t>(outputs[0].data_size());
  } else {
    mod.append_arg<uint32_t>(outputs[0].data_size());
  }

  // Launch kernel.
  const char* index_type = large ? "int64_t" : "uint32_t";
  std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
  if (contiguous) {
    kernel_name += fmt::format("_contiguous<{}>", index_type);
  } else {
    kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
  }
  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  for (const auto& out : outputs) {
    encoder.set_output_array(out);
  }
  encoder.launch_kernel([&](cudaStream_t stream) {
    mod.launch_kernel(stream, kernel_name, outputs[0], large);
  });
}

} // namespace mlx::core
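To make the generator concrete, here is a hedged sketch of the source FusedKernelBuilder might emit for a hypothetical fused library "lib" computing out = exp(a) + b over contiguous float arrays. The names a, b, out, and tmp_* are illustrative (real names come from NodeNamer), and Exp{}/Add{} stand in for the op functors provided by the embedded device headers:

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <typename IdxT = uint32_t>
__global__ void lib_contiguous(
    const float* a,
    const float* b,
    float* out,
    IdxT size) {
  // One thread per element; the grid-level rank is the flat index.
  IdxT index = cg::this_grid().thread_rank();
  if (index >= size) {
    return;
  }
  // Read inputs.
  float tmp_a = a[index];
  float tmp_b = b[index];
  // Replay the tape.
  float tmp_c = Exp{}(tmp_a);
  float tmp_out = Add{}(tmp_c, tmp_b);
  // Write outputs.
  out[index] = tmp_out;
}

} // namespace mlx::core::cu

The _strided variant differs only in taking a shape plus per-input strides and computing each input's location with elem_to_loc_nd<NDIM>.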
@@ -3,8 +3,8 @@

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"

#include <cuComplex.h>
#include <cuda/std/array>
mlx/backend/cuda/device/config.h (new file, 12 lines)
@@ -0,0 +1,12 @@
// Copyright © 2025 Apple Inc.

// This file is used by both CUDA kernel code and host-only C++ code.

#pragma once

// The maximum dimensions of shape/strides passed as kernel parameters.
#define MAX_NDIM 8

// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32
@@ -2,8 +2,8 @@

#pragma once

#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

namespace mlx::core::cu {
@@ -8,6 +8,8 @@

#pragma once

#include "mlx/backend/cuda/device/config.h"

#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
@@ -21,14 +23,8 @@ namespace mlx::core::cu {
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////

// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.
#define WARP_SIZE 32

// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
#define MAX_NDIM 8

using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
mlx/backend/cuda/jit_module.cpp (new file, 340 lines)
@@ -0,0 +1,340 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/device.h"

#include "cuda_jit_sources.h"

#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>

#include <fmt/format.h>
#include <nvrtc.h>

namespace mlx::core::cu {

namespace {

#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd))

void check_nvrtc_error(const char* name, nvrtcResult err) {
  if (err != NVRTC_SUCCESS) {
    throw std::runtime_error(
        fmt::format("{} failed: {}", name, nvrtcGetErrorString(err)));
  }
}

#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))

void check_cu_error(const char* name, CUresult err) {
  if (err != CUDA_SUCCESS) {
    const char* err_str = "Unknown error";
    cuGetErrorString(err, &err_str);
    throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
  }
}

// Return the location of the CUDA toolkit.
const char* cuda_home() {
  const char* home = std::getenv("CUDA_HOME");
  if (home) {
    return home;
  }
  home = std::getenv("CUDA_PATH");
  if (home) {
    return home;
  }
#if defined(__linux__)
  home = "/usr/local/cuda";
  if (std::filesystem::exists(home)) {
    return home;
  }
#endif
  throw std::runtime_error(
      "Environment variable CUDA_HOME or CUDA_PATH is not set.");
}

// Get the cache directory for storing compiled results.
bool get_ptx_cache_dir(std::filesystem::path* result) {
  auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
  if (!std::filesystem::is_directory(path)) {
    std::error_code error;
    if (!std::filesystem::create_directories(path, error)) {
      return false;
    }
  }
  *result = path;
  return true;
}

// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
bool read_cached_ptx(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    std::vector<char>* ptx,
    std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
  auto ptx_path = cache_dir / (module_name + ".ptx");
  std::error_code error;
  auto ptx_size = std::filesystem::file_size(ptx_path, error);
  if (error) {
    return false;
  }
  std::ifstream ptx_file(ptx_path, std::ios::binary);
  if (!ptx_file.good()) {
    return false;
  }
  ptx->resize(ptx_size);
  ptx_file.read(ptx->data(), ptx_size);

  std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
  std::string line;
  while (std::getline(txt_file, line)) {
    auto tab = line.find('\t');
    if (tab != std::string::npos) {
      ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
    }
  }
  return true;
}

// Write the |ptx| and |ptx_kernels| to |cache_dir| with |module_name|.
void write_cached_ptx(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    const std::vector<char>& ptx,
    const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
  std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
  if (!ptx.empty()) {
    ptx_file.write(&ptx.front(), ptx.size());
  }
  std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
  for (const auto& [name, mangled] : ptx_kernels) {
    txt_file << name << "\t" << mangled << std::endl;
  }
}

// Return whether |device|'s compute capability is no newer than
// |major|.|minor|.
inline bool version_lower_equal(Device& device, int major, int minor) {
  if (device.compute_capability_major() < major) {
    return true;
  } else if (device.compute_capability_major() == major) {
    return device.compute_capability_minor() <= minor;
  } else {
    return false;
  }
}

// Return whether NVRTC supports compiling to |device|'s SASS code.
bool compiler_supports_device_sass(Device& device) {
  int nvrtc_major, nvrtc_minor;
  CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
  if (nvrtc_major < 9) {
    return false;
  } else if (nvrtc_major == 9) {
    return version_lower_equal(device, 7, 2);
  } else if (nvrtc_major == 10) {
    return version_lower_equal(device, 7, 5);
  } else if (nvrtc_major == 11 && nvrtc_minor == 0) {
    return version_lower_equal(device, 8, 0);
  } else if (nvrtc_major == 11 && nvrtc_minor < 8) {
    return version_lower_equal(device, 8, 6);
  } else {
    return true;
  }
}

#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"

constexpr const char* g_include_names[] = {
    INCLUDE_PREFIX "binary_ops.cuh",
    INCLUDE_PREFIX "cast_op.cuh",
    INCLUDE_PREFIX "config.h",
    INCLUDE_PREFIX "cucomplex_math.cuh",
    INCLUDE_PREFIX "fp16_math.cuh",
    INCLUDE_PREFIX "unary_ops.cuh",
    INCLUDE_PREFIX "utils.cuh",
};

#undef INCLUDE_PREFIX

constexpr const char* g_headers[] = {
    jit_source_binary_ops,
    jit_source_cast_op,
    jit_source_config,
    jit_source_cucomplex_math,
    jit_source_fp16_math,
    jit_source_unary_ops,
    jit_source_utils,
};

} // namespace

JitModule::JitModule(
    Device& device,
    const std::string& module_name,
    const KernelBuilder& builder) {
  // Check cache.
  std::filesystem::path cache_dir;
  std::vector<char> ptx;
  std::vector<std::pair<std::string, std::string>> ptx_kernels;
  if (!get_ptx_cache_dir(&cache_dir) ||
      !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
    // Create program.
    auto [source_code, kernel_names] = builder();
    nvrtcProgram prog;
    CHECK_NVRTC_ERROR(nvrtcCreateProgram(
        &prog,
        source_code.c_str(),
        (module_name + ".cu").c_str(),
        std::size(g_headers),
        g_headers,
        g_include_names));
    std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
        &prog,
        [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
    for (const auto& name : kernel_names) {
      CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
    }

    // Compile program.
    bool use_sass = compiler_supports_device_sass(device);
    std::string compute = fmt::format(
        "--gpu-architecture={}_{}{}",
        use_sass ? "sm" : "compute",
        device.compute_capability_major(),
        device.compute_capability_minor());
    std::string include = fmt::format("--include-path={}/include", cuda_home());
    const char* args[] = {compute.c_str(), include.c_str()};
    nvrtcResult compile_result =
        nvrtcCompileProgram(prog, std::size(args), args);
    if (compile_result != NVRTC_SUCCESS) {
      size_t log_size;
      CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
      std::vector<char> log(log_size + 1, 0);
      CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
      throw std::runtime_error(
          fmt::format("Failed to compile kernel: {}.", log.data()));
    }

    // Get mangled names of kernel names.
    for (const auto& name : kernel_names) {
      const char* mangled;
      CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
      ptx_kernels.emplace_back(name, mangled);
    }

    // Get ptx data.
    size_t ptx_size;
    if (use_sass) {
      CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
    } else {
      CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
    }
    ptx.resize(ptx_size, 0);
    if (use_sass) {
      CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
    } else {
      CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
    }
    write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
  }

  // Load module.
  char jit_log[4089] = {};
  CUjit_option options[] = {
      CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
  void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
  CUresult jit_result = cuModuleLoadDataEx(
      &module_, ptx.data(), std::size(options), options, values);
  if (jit_result != CUDA_SUCCESS) {
    throw std::runtime_error(fmt::format(
        "Failed to load compiled {} kernel: {}.", module_name, jit_log));
  }

  // Load kernels.
  for (const auto& [name, mangled] : ptx_kernels) {
    CUfunction kernel;
    CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
    kernels_[name] = kernel;
  }
}

JitModule::~JitModule() {
  CHECK_CU_ERROR(cuModuleUnload(module_));
}

void JitModule::launch_kernel(
    CUstream stream,
    const std::string& kernel_name,
    const array& arr,
    bool large,
    int work_per_thread) {
  CUfunction kernel = get_kernel(kernel_name);
  size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
  int _, block_dim;
  CHECK_CU_ERROR(
      cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
  if (block_dim > nthreads) {
    block_dim = nthreads;
  }
  Dims num_blocks{1, 1, 1};
  if (large) {
    num_blocks =
        get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
    std::get<0>(num_blocks) =
        (std::get<0>(num_blocks) + block_dim - 1) / block_dim;
  } else {
    std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
  }
  launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
}

void JitModule::launch_kernel(
    CUstream stream,
    CUfunction kernel,
    Dims num_blocks,
    Dims block_dims) {
  CHECK_CU_ERROR(cuLaunchKernel(
      kernel,
      std::get<0>(num_blocks),
      std::get<1>(num_blocks),
      std::get<2>(num_blocks),
      std::get<0>(block_dims),
      std::get<1>(block_dims),
      std::get<2>(block_dims),
      0,
      stream,
      args_.data(),
      nullptr));
  args_.clear();
  storage_.clear();
}

CUfunction JitModule::get_kernel(const std::string& kernel_name) {
  auto it = kernels_.find(kernel_name);
  if (it == kernels_.end()) {
    throw std::runtime_error(
        fmt::format("There is no kernel named {}.", kernel_name));
  }
  return it->second;
}

void JitModule::append_ptr_arg(const void* v) {
  args_.push_back(const_cast<void*>(v));
}

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder) {
  static std::unordered_map<std::string, JitModule> map;
  auto it = map.find(name);
  if (it == map.end()) {
    it = map.try_emplace(name, cu::device(device), name, builder).first;
  }
  return it->second;
}

} // namespace mlx::core::cu
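Taken together, a caller drives this module roughly as follows — a minimal hypothetical sketch modeled on Compiled::eval_gpu above. The my_fill kernel, its source, and the variable names are invented for illustration; out is an mlx::core::array and s a Stream:

// Compile on first use; later calls reuse the cached JitModule.
cu::JitModule& mod = cu::get_jit_module(s.device, "my_fill", []() {
  std::string source = R"(
    namespace mlx::core::cu {
    template <typename T>
    __global__ void my_fill(T* out, T value, uint32_t size) {
      uint32_t index = blockIdx.x * blockDim.x + threadIdx.x;
      if (index < size) {
        out[index] = value;
      }
    }
    } // namespace mlx::core::cu
  )";
  std::vector<std::string> kernel_names = {"mlx::core::cu::my_fill<float>"};
  return std::make_pair(std::move(source), std::move(kernel_names));
});
mod.append_arg(out);                        // device pointer of the output
mod.append_arg<float>(0.0f);                // scalar kept alive in storage_
mod.append_arg<uint32_t>(out.data_size());  // element count
mod.launch_kernel(stream, "mlx::core::cu::my_fill<float>", out, /* large */ false);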
mlx/backend/cuda/jit_module.h (new file, 113 lines)
@@ -0,0 +1,113 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device/config.h"

#include <deque>
#include <unordered_map>
#include <utility>
#include <variant>

#include <cuda.h>
#include <fmt/format.h>

namespace mlx::core::cu {

class Device;

using KernelBuilderResult = std::pair<
    /* source code */ std::string,
    /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;

class JitModule {
 public:
  JitModule(
      Device& device,
      const std::string& module_name,
      const KernelBuilder& builder);
  ~JitModule();

  JitModule(const JitModule&) = delete;
  JitModule& operator=(const JitModule&) = delete;

  void append_arg(const array& a) {
    append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
  }

  template <typename T>
  void append_arg(T val) {
    storage_.emplace_back(val);
    append_ptr_arg(&storage_.back());
  }

  template <typename T>
  void append_arg(std::vector<T> vec) {
    if (vec.empty()) {
      // A nullptr can not be used as an arg, so pass something non-null.
      append_arg(std::monostate{});
    } else {
      append_ptr_arg(vec.data());
      storage_.emplace_back(std::move(vec));
    }
  }

  // Make sure the arg is copied to an array of size NDIM.
  template <size_t NDIM = MAX_NDIM, typename T>
  void append_ndim_arg(const std::vector<T>& vec) {
    if (vec.size() > NDIM) {
      throw std::runtime_error(
          fmt::format("ndim can not be larger than {}.", NDIM));
    }
    std::vector<T> copied(NDIM);
    std::copy(vec.begin(), vec.end(), copied.data());
    append_arg(std::move(copied));
  }

  // Launch the kernel named |kernel_name|, where each thread works on
  // |work_per_thread| elements of |arr|.
  void launch_kernel(
      CUstream stream,
      const std::string& kernel_name,
      const array& arr,
      bool large,
      int work_per_thread = 1);

  void launch_kernel(
      CUstream stream,
      CUfunction kernel,
      Dims num_blocks,
      Dims block_dims);

  CUfunction get_kernel(const std::string& kernel_name);

 private:
  void append_ptr_arg(const void* v);

  CUmodule module_{nullptr};
  std::unordered_map<std::string, CUfunction> kernels_;
  std::vector<void*> args_;

  // The cuLaunchKernel API requires passing pointers to arguments, so store
  // temporary values until the kernel is launched.
  using Arg = std::variant<
      std::monostate,
      CUdeviceptr,
      int32_t,
      uint32_t,
      int64_t,
      std::vector<const void*>,
      std::vector<int32_t>,
      std::vector<int64_t>>;
  std::deque<Arg> storage_;
};

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder);

} // namespace mlx::core::cu
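One detail worth calling out: storage_ is a std::deque rather than a std::vector because append_arg takes the address of each element right after inserting it, and deque growth never relocates existing elements while vector growth may. A self-contained sketch of the distinction (standard C++ semantics, not MLX code):

#include <deque>
#include <iostream>
#include <vector>

int main() {
  // deque: push_back never invalidates pointers to existing elements.
  std::deque<int> dq;
  dq.push_back(1);
  int* p = &dq.back();
  for (int i = 0; i < 1000; ++i) {
    dq.push_back(i);
  }
  std::cout << *p << "\n"; // still valid, prints 1

  // vector: push_back may reallocate and move the elements, so a pointer
  // taken earlier can dangle (dereferencing it is undefined behavior).
  std::vector<int> v;
  v.push_back(1);
  int* q = &v.back();
  for (int i = 0; i < 1000; ++i) {
    v.push_back(i); // may reallocate
  }
  // *q must not be dereferenced here.
  (void)q;
}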
@@ -1,13 +1,13 @@
// Copyright © 2025 Apple Inc.

// This file includes host-only utilities for writing CUDA kernels; the difference
// from backend/cuda/kernels/utils.cuh is that the latter file only includes
// from backend/cuda/device/utils.cuh is that the latter file only includes
// device-only code.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cuComplex.h>
#include <cuda_bf16.h>
@@ -244,8 +244,7 @@ void LayerNorm::eval_gpu(
    }
  };

  array o = set_output(inputs[0]);
  const array& x = o.data_shared_ptr() ? o : out;
  const array x = set_output(inputs[0]);
  const array& w = inputs[1];
  const array& b = inputs[2];
@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -1,9 +1,9 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/arange.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
@@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback(

NO_GPU(ArgPartition)
NO_GPU(BlockMaskedMM)
NO_GPU_MULTI(Compiled)
NO_GPU(Convolution)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
@@ -93,7 +92,6 @@ NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
@@ -101,8 +99,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)

namespace fast {
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
@@ -1,7 +1,7 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -2,7 +2,7 @@

#pragma once

#include "mlx/backend/cuda/kernels/utils.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

namespace mlx::core::cu {
@@ -1,7 +1,7 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
@@ -1,7 +1,7 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <thrust/device_ptr.h>
mlx/backend/cuda/rms_norm.cu (new file, 343 lines)
@@ -0,0 +1,343 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

inline __device__ float2 plus_f2(const float2& a, const float2& b) {
  return {a.x + b.x, a.y + b.y};
}

// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
struct BlockBroadcastReduce {
  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
  static_assert(BLOCK_DIM % WARP_SIZE == 0);
  using TempStorage = T[BLOCK_DIM / WARP_SIZE];

  cg::thread_block& block;
  TempStorage& temp;

  template <typename Op>
  __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
    auto warp = cg::tiled_partition<WARP_SIZE>(block);
    T x = cg::reduce(warp, input, op);
    if (warp.thread_rank() == 0) {
      temp[warp.meta_group_rank()] = x;
    }
    block.sync();
    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
                                                    : init_value;
    return cg::reduce(warp, x, op);
  }

  __device__ T Sum(const T& input) {
    return Reduce(input, cg::plus<T>{}, T{});
  }
};

template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
    const T* x,
    const T* w,
    T* out,
    float eps,
    int32_t axis_size,
    int64_t w_stride) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduceT::TempStorage temp;

  x += grid.block_rank() * axis_size;
  out += grid.block_rank() * axis_size;

  // Normalizer.
  float normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS];
    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
    for (int i = 0; i < N_READS; ++i) {
      float t = static_cast<float>(xn[i]);
      normalizer += t * t;
    }
  }
  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
  normalizer = rsqrt(normalizer / axis_size + eps);

  // Outputs.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS];
    T wn[N_READS];
    cub::LoadDirectBlocked(index, x, xn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
    for (int i = 0; i < N_READS; ++i) {
      float norm = static_cast<float>(xn[i]) * normalizer;
      xn[i] = wn[i] * static_cast<T>(norm);
    }
    cub::StoreDirectBlocked(index, out, xn, axis_size);
  }
}

template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
    const T* x,
    const T* w,
    const T* g,
    T* gx,
    T* gw,
    float eps,
    int32_t axis_size,
    int64_t w_stride) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
  __shared__ union {
    typename BlockReduceF::TempStorage f;
    typename BlockReduceF2::TempStorage f2;
  } temp;

  x += grid.block_rank() * axis_size;
  g += grid.block_rank() * axis_size;
  gx += grid.block_rank() * axis_size;
  gw += grid.block_rank() * axis_size;

  // Normalizer.
  float2 factors = {};
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    T xn[N_READS];
    T wn[N_READS] = {};
    T gn[N_READS] = {};
    auto index = r * BLOCK_DIM + block.thread_rank();
    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
    cub::LoadDirectBlocked(index, g, gn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
    for (int i = 0; i < N_READS; i++) {
      float t = static_cast<float>(xn[i]);
      float wi = wn[i];
      float gi = gn[i];
      float wg = wi * gi;
      factors = plus_f2(factors, {wg * t, t * t});
    }
  }
  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
  float meangwx = factors.x / axis_size;
  float normalizer = rsqrt(factors.y / axis_size + eps);
  float normalizer3 = normalizer * normalizer * normalizer;

  // Outputs.
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    T xn[N_READS];
    T wn[N_READS];
    T gn[N_READS];
    cub::LoadDirectBlocked(index, x, xn, axis_size);
    cub::LoadDirectBlocked(index, g, gn, axis_size);
    cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
    for (int i = 0; i < N_READS; i++) {
      float xi = xn[i];
      float wi = wn[i];
      float gi = gn[i];
      xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
      if constexpr (HAS_W) {
        wn[i] = static_cast<T>(gi * xi * normalizer);
      }
    }
    cub::StoreDirectBlocked(index, gx, xn, axis_size);
    if constexpr (HAS_W) {
      cub::StoreDirectBlocked(index, gw, wn, axis_size);
    }
  }
}

} // namespace cu
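For reference, the math these kernels implement, transcribed directly from the code above (per row of length $D$, cotangent $g$, accumulation in float):

$$n = \Big(\tfrac{1}{D}\textstyle\sum_j x_j^2 + \epsilon\Big)^{-1/2}$$

$$\text{forward:}\quad y_i = w_i \, x_i \, n$$

$$\text{VJP:}\quad \frac{\partial L}{\partial x_i} = n\,w_i\,g_i - x_i\,n^3\,\underbrace{\tfrac{1}{D}\textstyle\sum_j w_j g_j x_j}_{\texttt{meangwx}},
\qquad \frac{\partial L}{\partial w_i} = g_i\,x_i\,n$$

The per-row $\partial L/\partial w_i$ values are written to a temporary and summed over rows afterwards by the col_reduce call in RMSNormVJP::eval_gpu below.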
namespace fast {

bool RMSNorm::use_fallback(Stream s) {
  return s.device == Device::cpu;
}

// TODO: There is duplicated code with backend/metal/normalization.cpp
void RMSNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("RMSNorm::eval_gpu");
  auto& s = stream();
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous.
  auto set_output = [&s, &out](const array& x) {
    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
    if (no_copy && x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      if (x.is_donatable()) {
        out.copy_shared_buffer(x);
      } else {
        out.set_data(
            allocator::malloc(x.data_size() * x.itemsize()),
            x.data_size(),
            x.strides(),
            x.flags());
      }
      return x;
    } else {
      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      out.copy_shared_buffer(x_copy);
      return x_copy;
    }
  };

  const array x = set_output(inputs[0]);
  const array& w = inputs[1];

  int32_t axis_size = x.shape().back();
  int32_t n_rows = x.data_size() / axis_size;
  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr uint32_t N_READS = 4;
      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
        auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
            x.data<DataType>(),
            w.data<DataType>(),
            out.data<DataType>(),
            eps_,
            axis_size,
            w_stride);
      });
    });
  });
}

void RMSNormVJP::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("RMSNormVJP::eval_gpu");
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Ensure row contiguity. We could relax this step by checking that the array
  // is contiguous (no broadcasts or holes) and that the input strides are the
  // same as the cotangent strides, but for now this is simpler.
  auto check_input = [&s](const array& x) -> std::pair<array, bool> {
    if (x.flags().row_contiguous) {
      return {x, false};
    }
    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy_gpu(x, x_copy, CopyType::General, s);
    return {x_copy, true};
  };
  bool donate_x = inputs[0].is_donatable();
  bool donate_g = inputs[2].is_donatable();
  auto [x, copied] = check_input(inputs[0]);
  donate_x |= copied;
  const array& w = inputs[1];
  auto [g, g_copied] = check_input(inputs[2]);
  donate_g |= g_copied;
  array& gx = outputs[0];
  array& gw = outputs[1];

  // Check whether we had a weight.
  bool has_w = w.ndim() != 0;

  // Allocate space for the outputs.
  bool g_in_gx = false;
  if (donate_x) {
    gx.copy_shared_buffer(x);
  } else if (donate_g) {
    gx.copy_shared_buffer(g);
    g_in_gx = true;
  } else {
    gx.set_data(allocator::malloc(gx.nbytes()));
  }
  if (g_copied && !g_in_gx) {
    encoder.add_temporary(g);
  }

  int32_t axis_size = x.shape().back();
  int32_t n_rows = x.data_size() / axis_size;
  int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;

  // Allocate a temporary to store the gradients for w and allocate the output
  // gradient accumulators.
  array gw_temp =
      (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
  if (has_w) {
    if (!g_in_gx && donate_g) {
      gw_temp.copy_shared_buffer(g);
    } else {
      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
      encoder.add_temporary(gw_temp);
    }
  }
  gw.set_data(allocator::malloc(gw.nbytes()));

  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_input_array(g);
  encoder.set_output_array(gx);
  encoder.set_output_array(gw_temp);
  encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
    MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
      using DataType = cuda_type_t<CTYPE>;
      constexpr int N_READS = 4;
      MLX_SWITCH_BOOL(has_w, HAS_W, {
        MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
          auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
          kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
              x.data<DataType>(),
              w.data<DataType>(),
              g.data<DataType>(),
              gx.data<DataType>(),
              gw_temp.data<DataType>(),
              eps_,
              axis_size,
              w_stride);
        });
      });
    });
  });

  if (has_w) {
    ReductionPlan plan(
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
    col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan);
  }
}

} // namespace fast

} // namespace mlx::core
@@ -1,9 +1,9 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cast_op.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -2,10 +2,10 @@

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernels/unary_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -2,6 +2,7 @@

#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"

#include <fmt/format.h>

@@ -23,4 +24,20 @@ void check_cuda_error(const char* name, cudaError_t err) {
  }
}

const char* dtype_to_cuda_type(const Dtype& dtype) {
  if (dtype == float16) {
    return "__half";
  }
  if (dtype == bfloat16) {
    return "__nv_bfloat16";
  }
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
  if (dtype == DTYPE) {                           \
    return #CPP_TYPE;                             \
  }
  MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
  return nullptr;
}

} // namespace mlx::core
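As a quick illustration, a JIT source builder might use this mapping as below (a hypothetical snippet; float16 and bfloat16 are special-cased above because their C++ names differ from the CUDA type names):

// dtype_to_cuda_type(float16) yields "__half".
std::string param =
    fmt::format("const {}* input", dtype_to_cuda_type(float16));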
@@ -12,6 +12,8 @@ namespace cu {
class Device;
}

struct Dtype;

// Cuda stream managed with RAII.
class CudaStream {
 public:
@@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err);
// The macro version that prints the command that failed.
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))

// Convert Dtype to CUDA C++ types.
const char* dtype_to_cuda_type(const Dtype& dtype);

} // namespace mlx::core
@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.

#include "mlx/primitives.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
@@ -170,6 +171,41 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
  slice_gpu(in, out, start_indices_, strides_, stream());
}

void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
  auto [data_offset, out_strides] =
      prepare_slice(out, start_indices_, strides_);

  // Do copy
  copy_gpu_inplace(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const Shape& data_shape = */ upd.shape(),
      /* const Strides& i_strides = */ upd.strides(),
      /* const Strides& o_strides = */ out_strides,
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ data_offset,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      /* const Stream& s = */ stream());
}

void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Squeeze::eval_gpu");
  eval(inputs, out);
@@ -322,41 +322,6 @@ void DynamicSliceUpdate::eval_gpu(
      /* const std::optional<array>& dynamic_o_offset = */ out_offset);
}

void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
  auto [data_offset, out_strides] =
      prepare_slice(out, start_indices_, strides_);

  // Do copy
  copy_gpu_inplace(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const Shape& data_shape = */ upd.shape(),
      /* const Strides& i_strides = */ upd.strides(),
      /* const Strides& o_strides = */ out_strides,
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ data_offset,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      /* const Stream& s = */ stream());
}

void QRF::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
|
@ -225,6 +225,8 @@ struct MPIWrapper {
|
||||
return mpi_bfloat16_;
|
||||
case float64:
|
||||
return mpi_double_;
|
||||
default:
|
||||
throw std::runtime_error("Invalid type");
|
||||
}
|
||||
}
|
||||
|
||||
|