mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
qmm
...
56cc858af9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 |
@@ -377,4 +377,10 @@ void copy_cpu_inplace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -30,4 +30,7 @@ void copy_cpu_inplace(
|
|||||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||||
|
|
||||||
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
|||||||
if (arr.flags().row_contiguous) {
|
if (arr.flags().row_contiguous) {
|
||||||
return {arr, false};
|
return {arr, false};
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
return {contiguous_copy_cpu(arr, stream), true};
|
||||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
|
||||||
return {arr_copy, true};
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
|
|||||||
}
|
}
|
||||||
return in;
|
return in;
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_cpu(in, s);
|
||||||
copy_cpu(in, arr_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(arr_copy);
|
out.copy_shared_buffer(arr_copy);
|
||||||
return arr_copy;
|
return arr_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_cpu(x, s);
|
||||||
copy_cpu(x, x_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(x_copy);
|
encoder.add_temporary(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return std::make_tuple(true, sty, arr, false);
|
return std::make_tuple(true, sty, arr, false);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
|
||||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
int64_t stx = arr.shape(-1);
|
int64_t stx = arr.shape(-1);
|
||||||
|
array arr_copy = contiguous_copy_cpu(arr, s);
|
||||||
return std::make_tuple(false, stx, arr_copy, true);
|
return std::make_tuple(false, stx, arr_copy, true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
|
|||||||
if (arr.flags().row_contiguous) {
|
if (arr.flags().row_contiguous) {
|
||||||
return std::make_pair(arr, false);
|
return std::make_pair(arr, false);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
return std::make_pair(contiguous_copy_cpu(arr, s), true);
|
||||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
return std::make_pair(arr_copy, true);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Ensure contiguity
|
// Ensure contiguity
|
||||||
auto in = inputs[0];
|
auto in = inputs[0];
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_cpu(in, stream());
|
||||||
copy_cpu(in, arr_copy, CopyType::General, stream());
|
encoder.add_temporary(in);
|
||||||
in = arr_copy;
|
|
||||||
encoder.add_temporary(arr_copy);
|
|
||||||
}
|
}
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
|||||||
@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_cpu(x, s);
|
||||||
copy_cpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
|
|||||||
T vals[N_READS];
|
T vals[N_READS];
|
||||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||||
cub::LoadDirectBlocked(
|
cub::LoadDirectBlocked(
|
||||||
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
|
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
|
||||||
best = op.reduce_many(best, vals, tid * N_READS);
|
best = op.reduce_many(best, vals, tid * N_READS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
|||||||
to[offset] = vec;
|
to[offset] = vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper for accessing strided data.
|
||||||
|
template <typename T>
|
||||||
|
struct StridedIterator {
|
||||||
|
T it;
|
||||||
|
int64_t stride;
|
||||||
|
|
||||||
|
__host__ __device__ StridedIterator(T it, int64_t stride)
|
||||||
|
: it(it), stride(stride) {}
|
||||||
|
|
||||||
|
__host__ __device__ auto operator[](int i) const {
|
||||||
|
return it[i * stride];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Type limits utils
|
// Type limits utils
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
@@ -1,121 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <thrust/iterator/iterator_adaptor.h>
|
|
||||||
#include <cuda/std/utility>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
// Iterating non-contiguous array.
|
|
||||||
template <typename Iterator, typename IdxT = int64_t>
|
|
||||||
class general_iterator
|
|
||||||
: public thrust::
|
|
||||||
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
|
|
||||||
public:
|
|
||||||
using super_t =
|
|
||||||
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
|
|
||||||
|
|
||||||
using reference = typename super_t::reference;
|
|
||||||
using difference_type = typename super_t::difference_type;
|
|
||||||
|
|
||||||
__host__ __device__ general_iterator(
|
|
||||||
Iterator it,
|
|
||||||
IdxT index,
|
|
||||||
int ndim,
|
|
||||||
Shape shape,
|
|
||||||
Strides strides)
|
|
||||||
: super_t(it),
|
|
||||||
index_(index),
|
|
||||||
ndim_(ndim),
|
|
||||||
shape_(cuda::std::move(shape)),
|
|
||||||
strides_(cuda::std::move(strides)) {}
|
|
||||||
|
|
||||||
__host__ __device__ IdxT index() const {
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ const Shape& shape() const {
|
|
||||||
return shape_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ const Strides& strides() const {
|
|
||||||
return strides_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class thrust::iterator_core_access;
|
|
||||||
|
|
||||||
__host__ __device__ bool equal(const general_iterator& other) const {
|
|
||||||
return this->base() == other.base() && this->index() == other.index();
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void advance(difference_type n) {
|
|
||||||
this->index_ += n;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void increment() {
|
|
||||||
this->index_ += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void decrement() {
|
|
||||||
this->index_ -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ difference_type
|
|
||||||
distance_to(const general_iterator& other) const {
|
|
||||||
_CCCL_ASSERT(
|
|
||||||
this->base() == other.base(),
|
|
||||||
"Underlying iterator must point to same base iterator");
|
|
||||||
return other.index() - this->index();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The dereference is device-only to avoid accidental running in host.
|
|
||||||
__device__ typename super_t::reference dereference() const {
|
|
||||||
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
|
|
||||||
return *(this->base() + offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
IdxT index_;
|
|
||||||
int ndim_;
|
|
||||||
Shape shape_;
|
|
||||||
Strides strides_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename IdxT, typename Iterator>
|
|
||||||
__host__ __device__ auto make_general_iterator(
|
|
||||||
Iterator it,
|
|
||||||
IdxT index,
|
|
||||||
int ndim,
|
|
||||||
Shape shape,
|
|
||||||
Strides strides) {
|
|
||||||
return general_iterator<Iterator, IdxT>(
|
|
||||||
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename IdxT, typename Iterator>
|
|
||||||
auto make_general_iterator(
|
|
||||||
Iterator it,
|
|
||||||
const std::vector<int32_t>& shape,
|
|
||||||
const std::vector<int64_t>& strides) {
|
|
||||||
return make_general_iterator<IdxT>(
|
|
||||||
it, 0, shape.size(), const_param(shape), const_param(strides));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename IdxT, typename Iterator>
|
|
||||||
auto make_general_iterators(
|
|
||||||
Iterator it,
|
|
||||||
IdxT size,
|
|
||||||
const std::vector<int32_t>& shape,
|
|
||||||
const std::vector<int64_t>& strides) {
|
|
||||||
auto ndim = shape.size();
|
|
||||||
auto shape_arg = const_param(shape);
|
|
||||||
auto strides_arg = const_param(strides);
|
|
||||||
return std::make_pair(
|
|
||||||
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
|
|
||||||
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <thrust/iterator/iterator_adaptor.h>
|
|
||||||
#include <thrust/iterator/iterator_facade.h>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
|
||||||
|
|
||||||
// RandomAccessIterator for strided access to array entries.
|
|
||||||
template <typename Iterator, typename Stride = int64_t>
|
|
||||||
class strided_iterator
|
|
||||||
: public thrust::
|
|
||||||
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
|
|
||||||
public:
|
|
||||||
using super_t =
|
|
||||||
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
|
|
||||||
|
|
||||||
using reference = typename super_t::reference;
|
|
||||||
using difference_type = typename super_t::difference_type;
|
|
||||||
|
|
||||||
__host__ __device__ strided_iterator(Iterator it, Stride stride)
|
|
||||||
: super_t(it), stride_(stride) {}
|
|
||||||
|
|
||||||
__host__ __device__ Stride stride() const {
|
|
||||||
return stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class thrust::iterator_core_access;
|
|
||||||
|
|
||||||
__host__ __device__ bool equal(const strided_iterator& other) const {
|
|
||||||
return this->base() == other.base();
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void advance(difference_type n) {
|
|
||||||
this->base_reference() += n * stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void increment() {
|
|
||||||
this->base_reference() += stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ void decrement() {
|
|
||||||
this->base_reference() -= stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
__host__ __device__ difference_type
|
|
||||||
distance_to(const strided_iterator& other) const {
|
|
||||||
const difference_type dist = other.base() - this->base();
|
|
||||||
_CCCL_ASSERT(
|
|
||||||
dist % stride() == 0,
|
|
||||||
"Underlying iterator difference must be divisible by the stride");
|
|
||||||
return dist / stride();
|
|
||||||
}
|
|
||||||
|
|
||||||
Stride stride_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@@ -105,8 +104,8 @@ __global__ void layer_norm(
|
|||||||
T wn[N_READS];
|
T wn[N_READS];
|
||||||
T bn[N_READS];
|
T bn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
|
||||||
@@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
|
|||||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float t = static_cast<float>(xn[i]) - mean;
|
float t = static_cast<float>(xn[i]) - mean;
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
@@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
|
|||||||
T gn[N_READS];
|
T gn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
#include "mlx/backend/cuda/reduce/reduce.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@@ -89,7 +88,7 @@ __global__ void rms_norm(
|
|||||||
T xn[N_READS];
|
T xn[N_READS];
|
||||||
T wn[N_READS];
|
T wn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
float norm = static_cast<float>(xn[i]) * normalizer;
|
float norm = static_cast<float>(xn[i]) * normalizer;
|
||||||
xn[i] = wn[i] * static_cast<T>(norm);
|
xn[i] = wn[i] * static_cast<T>(norm);
|
||||||
@@ -132,7 +131,7 @@ __global__ void rms_norm_vjp(
|
|||||||
auto index = r * BLOCK_DIM + block.thread_rank();
|
auto index = r * BLOCK_DIM + block.thread_rank();
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float t = static_cast<float>(xn[i]);
|
float t = static_cast<float>(xn[i]);
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
@@ -154,7 +153,7 @@ __global__ void rms_norm_vjp(
|
|||||||
T gn[N_READS];
|
T gn[N_READS];
|
||||||
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
cub::LoadDirectBlocked(index, x, xn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
cub::LoadDirectBlocked(index, g, gn, axis_size);
|
||||||
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
|
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
float xi = xn[i];
|
float xi = xn[i];
|
||||||
float wi = wn[i];
|
float wi = wn[i];
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/common/unary.h"
|
#include "mlx/backend/common/unary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|||||||
Reference in New Issue
Block a user