Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-15 13:01:17 +08:00)

Fix compilation with CUDA 11 (#2331)

commit 2ca533b279 (parent 4a9b29a875)
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
@@ -264,19 +264,26 @@ void CommandEncoder::commit() {
     graph_key_ += std::to_string(graph_node_count_);
     graph_key_ += ".";
     graph_key_ += std::to_string(empty_node_count_);
-    auto [it, _] = graph_cache_.emplace(graph_key_, nullptr);
-    auto& graph_exec = it->second;

-    if (graph_exec != NULL) {
-      cudaGraphExecUpdateResultInfo update_result;
-      cudaGraphExecUpdate(graph_exec, graph_, &update_result);
-      if (update_result.result != cudaGraphExecUpdateSuccess) {
-        cudaGetLastError();
+    cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
+    if (graph_exec != nullptr) {
+      cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+      cudaGraphExecUpdateResultInfo info;
+      cudaGraphExecUpdate(graph_exec, graph_, &info);
+      update_result = info.result;
+#else
+      cudaGraphNode_t error_node;
+      cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+      if (update_result != cudaGraphExecUpdateSuccess) {
+        cudaGetLastError(); // reset error
         CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-        graph_exec = NULL;
+        graph_exec = nullptr;
       }
     }
-    if (graph_exec == NULL) {
+    if (graph_exec == nullptr) {
       CHECK_CUDA_ERROR(
           cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
     }
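Context for the new #if (hedged; based on the CUDA runtime API rather than anything stated in the commit): CUDA 12 replaced the four-argument cudaGraphExecUpdate, which reports the failing node and a cudaGraphExecUpdateResult through separate out-parameters, with a three-argument form that fills a single cudaGraphExecUpdateResultInfo struct. A minimal version-portable sketch of the same check, outside of MLX:

#include <cuda_runtime.h>

// Illustrative helper, not part of the commit: attempt to update an
// instantiated graph in place and report whether it succeeded, handling
// both the CUDA 11 and CUDA 12 signatures of cudaGraphExecUpdate.
inline bool try_update_graph_exec(cudaGraphExec_t exec, cudaGraph_t graph) {
#if CUDART_VERSION >= 12000
  cudaGraphExecUpdateResultInfo info;
  cudaGraphExecUpdate(exec, graph, &info);
  cudaGraphExecUpdateResult result = info.result;
#else
  cudaGraphNode_t error_node;
  cudaGraphExecUpdateResult result;
  cudaGraphExecUpdate(exec, graph, &error_node, &result);
#endif
  if (result != cudaGraphExecUpdateSuccess) {
    cudaGetLastError(); // a failed update also sets the sticky runtime error
    return false;
  }
  return true;
}

If the update cannot be applied, the caller destroys the old executable graph and re-instantiates, exactly as the hunk above does.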
@@ -3,6 +3,8 @@
 #pragma once

 #include <cuComplex.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
 #include <thrust/iterator/transform_iterator.h>

 namespace mlx::core::cu {
@@ -17,6 +19,26 @@ struct CastOp {
   }
 };

+// Castings between complex and boolean.
+// TODO: Should make a custom complex type.
+template <>
+struct CastOp<cuComplex, bool> {
+  static constexpr bool is_castable = true;
+
+  __device__ bool operator()(cuComplex x) {
+    return x.x != 0 && x.y != 0;
+  }
+};
+
+template <>
+struct CastOp<bool, cuComplex> {
+  static constexpr bool is_castable = true;
+
+  __device__ cuComplex operator()(bool x) {
+    return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+  }
+};
+
 // Converting a complex number to real number discards the imaginary part.
 template <typename DstT>
 struct CastOp<
@@ -45,6 +67,7 @@ struct CastOp<
   }
 };

+// Do nothing when no casting is needed.
 template <typename SrcT, typename DstT>
 struct CastOp<
     SrcT,
@@ -57,9 +80,53 @@ struct CastOp<
   }
 };

+// In CUDA 11 the half types do not define conversions between some types,
+// provide fallbacks here.
+#if CUDART_VERSION < 12000
+template <typename SrcT, typename DstT>
+struct CastOp<
+    SrcT,
+    DstT,
+    cuda::std::enable_if_t<
+        !cuda::std::is_convertible_v<SrcT, DstT> &&
+        !cuda::std::is_same_v<SrcT, cuComplex> &&
+        (cuda::std::is_same_v<DstT, __half> ||
+         cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
+  static constexpr bool is_castable = true;
+
+  __device__ DstT operator()(SrcT x) {
+    return DstT(static_cast<float>(x));
+  }
+};
+
+template <typename SrcT, typename DstT>
+struct CastOp<
+    SrcT,
+    DstT,
+    cuda::std::enable_if_t<
+        !cuda::std::is_convertible_v<SrcT, DstT> &&
+        !cuda::std::is_same_v<DstT, cuComplex> &&
+        !cuda::std::is_same_v<DstT, __half> &&
+        !cuda::std::is_same_v<DstT, __nv_bfloat16> &&
+        (cuda::std::is_same_v<SrcT, __half> ||
+         cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
+  static constexpr bool is_castable = true;
+
+  __device__ DstT operator()(SrcT x) {
+    return DstT(static_cast<float>(x));
+  }
+};
+#endif // CUDART_VERSION < 12000
+
+// Helper to deduce the SrcT.
+template <typename DstT, typename SrcT>
+inline __host__ __device__ auto cast_to(SrcT x) {
+  return CastOp<SrcT, DstT>{}(x);
+}
+
 // Return an iterator that cast the value to DstT using CastOp.
 template <typename DstT, typename Iterator>
-__host__ __device__ auto make_cast_iterator(Iterator it) {
+inline __host__ __device__ auto make_cast_iterator(Iterator it) {
   using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
   if constexpr (std::is_same_v<SrcT, DstT>) {
     return it;
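What the new cast_to helper buys (a hedged usage sketch, not code from the commit): callers name only the destination type, the source type is deduced, and the conversion is routed through CastOp, so the complex/bool specializations and the CUDA 11 half-precision fallbacks above all apply automatically. Assuming the MLX CUDA backend headers are on the include path:

#include "mlx/backend/cuda/device/cast_op.cuh"

#include <cuda_bf16.h>
#include <cstdint>

namespace mlx::core::cu {

__global__ void cast_example(const int64_t* in, __nv_bfloat16* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // If this toolkit's __nv_bfloat16 cannot be constructed from int64_t
    // directly (as can happen with CUDA 11), the fallback specialization is
    // selected and the value goes int64_t -> float -> __nv_bfloat16;
    // otherwise the generic CastOp performs the direct conversion.
    out[i] = cast_to<__nv_bfloat16>(in[i]);
  }
}

} // namespace mlx::core::cu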
@@ -99,20 +99,20 @@ struct Limits<
     return cuda::std::numeric_limits<T>::infinity();
   }
   static constexpr __host__ __device__ T min() {
-#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
-    return -cuda::std::numeric_limits<T>::infinity();
-#else
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
     return -cuda::std::numeric_limits<float>::infinity();
+#else
+    return -cuda::std::numeric_limits<T>::infinity();
 #endif
   }
   static constexpr __host__ __device__ T finite_max() {
     return cuda::std::numeric_limits<T>::max();
   }
   static constexpr __host__ __device__ T finite_min() {
-#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
-    return cuda::std::numeric_limits<T>::lowest();
-#else
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
     return cuda::std::numeric_limits<float>::lowest();
+#else
+    return cuda::std::numeric_limits<T>::lowest();
 #endif
   }
 };
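A hedged reading of the revised guard (my inference; the commit gives no rationale): __CUDA_ARCH__ is only defined during device compilation and an undefined identifier evaluates to 0 inside #if, so with CUDA 11 the float-based limits are used on the host pass and for pre-sm_80 device code, while CUDA 12 builds and sm_80+ device code keep the typed cuda::std::numeric_limits values. The same selection logic reduced to one standalone function (illustrative, not MLX code):

#include <cuda_runtime.h>
#include <cuda/std/limits>

// Illustrative only: the revised guard in isolation. __CUDA_ARCH__ is
// undefined on the host pass and therefore compares as 0 in the #if.
template <typename T>
__host__ __device__ T negative_infinity() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
  // CUDA 11, host pass or pre-sm_80 device code: produce the value as float
  // and let it convert to T, as the hunk above does.
  return -cuda::std::numeric_limits<float>::infinity();
#else
  // CUDA 12+, or device code for sm_80 and newer: use the typed limits.
  return -cuda::std::numeric_limits<T>::infinity();
#endif
}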
@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   for (; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
     for (int j = 0; j < N; j++) {
-      accs[0] = op(accs[0], __cast<U, T>(vals[j]));
+      accs[0] = op(accs[0], cast_to<U>(vals[j]));
     }
   }

   if (i < check) {
     cub::LoadDirectBlocked(
-        block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
+        block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
     for (int i = 0; i < N; i++) {
-      accs[0] = op(accs[0], __cast<U, T>(vals[i]));
+      accs[0] = op(accs[0], cast_to<U>(vals[i]));
     }
   }

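The __cast<U, T>(x) → cast_to<U>(x) changes in this and the following reduction kernels are mechanical (hedged summary): the per-file __cast helper from reduce_utils.cuh is removed further down and replaced by the shared cast_to from cast_op.cuh, which keeps the destination type explicit, deduces the source type, and picks up the complex/bool and CUDA 11 fallback specializations in one place. An illustrative before/after, with U as the accumulator type and T the input element type as in the kernel above:

#include "mlx/backend/cuda/device/cast_op.cuh"

namespace mlx::core::cu {

// Illustrative only: one accumulation step written against the new helper.
template <typename U, typename T, typename Op>
__device__ U accumulate_one(Op op, U acc, T val) {
  // Previously: op(acc, __cast<U, T>(val));
  return op(acc, cast_to<U>(val));
}

} // namespace mlx::core::cu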
@@ -3,7 +3,6 @@
 #include <numeric>

 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"

 #include <cooperative_groups.h>
@@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
       T vals[N_READS];
       cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
       for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        totals[i] = op(totals[i], cast_to<U>(vals[i]));
       }
       loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
@@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
       T vals[N_READS];
       cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
       for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        totals[i] = op(totals[i], cast_to<U>(vals[i]));
       }
       loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
@@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
           in + loop.location(),
           vals,
           args.reduction_stride - tile_x * BN,
-          __cast<T, U>(ReduceInit<Op, T>::value()));
+          cast_to<T>(ReduceInit<Op, T>::value()));
       for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+        totals[i] = op(totals[i], cast_to<U>(vals[i]));
       }
       loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
@@ -2,6 +2,8 @@

 #pragma once

+#include "mlx/backend/cuda/device/atomic_ops.cuh"
+#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/reduce/reduce_utils.cuh"

@@ -40,15 +42,15 @@ struct Sum {
   }

   __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }

   __device__ void atomic_update(int* x, int y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }

   __device__ void atomic_update(float* x, float y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }
 };

@@ -152,7 +154,7 @@ struct ReduceInit<Sum, T> {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
       return T{0, 0};
     } else {
-      return typename ReduceResult<Sum, T>::type{0};
+      return cast_to<typename ReduceResult<Sum, T>::type>(0);
     }
   }
 };
@@ -163,7 +165,7 @@ struct ReduceInit<Prod, T> {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
       return T{1, 0};
     } else {
-      return typename ReduceResult<Prod, T>::type{1};
+      return cast_to<typename ReduceResult<Prod, T>::type>(1);
     }
   }
 };
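Hedged context for the atomicAdd → atomic_add switch above: native atomicAdd overloads for 16-bit floating-point types exist only on sufficiently new architectures (the __nv_bfloat16 overload requires compute capability 8.x+), so Sum now calls MLX's own atomic_add from the newly included atomic_ops.cuh. The sketch below shows the textbook compare-and-swap fallback such a wrapper can use when no native overload is available; it is the standard CUDA programming-guide pattern, not MLX's actual implementation:

// Illustrative CAS-based atomic add for a type lacking a native atomicAdd
// overload on the target architecture (shown for double; a 16-bit version
// must additionally read-modify-write the containing 32-bit word).
__device__ double atomic_add_fallback(double* address, double val) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(
        address_as_ull,
        assumed,
        __double_as_longlong(val + __longlong_as_double(assumed)));
    // Retry if another thread changed the value between the read and the CAS.
  } while (assumed != old);
  return __longlong_as_double(old);
}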
@@ -55,22 +55,6 @@ __device__ void atomic_reduce(T* x, T y) {
   }
 }

-// TODO: Should make a custom complex type
-template <typename U, typename T>
-inline __device__ U __cast(T x) {
-  return static_cast<U>(x);
-}
-
-template <>
-inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
-  return x.x != 0 && x.y != 0;
-}
-
-template <>
-inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
-  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
-}
-
 template <typename T, int N, typename Block, typename Warp, typename Op>
 inline __device__ void
 block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
@@ -3,7 +3,6 @@
 #include <numeric>

 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"

 #include <cooperative_groups.h>
@@ -113,7 +112,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + r * (block.size() * N),
           vals[k]);
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
   }
@@ -125,7 +124,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + r * (block.size() * N),
           vals[k]);
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
   }
@@ -138,9 +137,9 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + final_offset,
           vals[k],
           size,
-          __cast<T, U>(init));
+          cast_to<T>(init));
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
   }
@@ -199,7 +198,7 @@ __global__ void row_reduce_looped(
         in + loop.location() + r * BLOCK_DIM * N_READS,
         vals);
     for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
+      total[0] = op(total[0], cast_to<U>(vals[i]));
     }
   }
   if (final_offset < args.row_size) {
@@ -209,9 +208,9 @@ __global__ void row_reduce_looped(
         in + loop.location() + final_offset,
         vals,
         args.row_size - final_offset,
-        __cast<T, U>(init));
+        cast_to<T>(init));
     for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
+      total[0] = op(total[0], cast_to<U>(vals[i]));
     }
   }
   // TODO: Maybe block.sync() here?
@@ -74,7 +74,7 @@ __global__ void rms_norm(
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
     T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
@@ -130,7 +130,7 @@ __global__ void rms_norm_vjp(
     T wn[N_READS] = {};
     T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     cub::LoadDirectBlocked(index, g, gn, axis_size);
     cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
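A hedged note on the changed default argument: the last parameter of cub::LoadDirectBlocked is the value assigned to items past axis_size, and when T is __half or __nv_bfloat16 a plain int 0 may not convert implicitly under CUDA 11, so the zero is produced through the shared cast helper instead. The softmax hunk below makes the analogous change when initializing its accumulator. A minimal call sketch (illustrative; assumes CUB and the MLX headers are available):

#include <cub/block/block_load.cuh>
#include <cuda_bf16.h>

#include "mlx/backend/cuda/device/cast_op.cuh"

namespace mlx::core::cu {

// Illustrative only: a guarded row load where out-of-range items receive an
// explicitly cast zero instead of a bare int literal.
template <int N_READS>
__device__ void load_row_guarded(
    int index,
    const __nv_bfloat16* x,
    __nv_bfloat16 (&xn)[N_READS],
    int axis_size) {
  cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<__nv_bfloat16>(0));
}

} // namespace mlx::core::cu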
@@ -43,7 +43,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   // Thread reduce.
   AccT prevmax;
   AccT maxval = Limits<AccT>::finite_min();
-  AccT normalizer = 0;
+  AccT normalizer = cast_to<AccT>(0);
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
     AccT vals[N_READS];
     cub::LoadDirectBlocked(