Compare commits

...

4 Commits

Author SHA1 Message Date
Cheng
1ba18ff7d9 [CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file

* [CUDA] Fix conv grads with groups

* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Cheng
37b440faa8 Clean up code handling both std::vector and SmallVector (#2493) 2025-08-16 09:01:10 +09:00
Cheng
888b13ed63 Remove the hack around SmallVector in cpu compile (#2494) 2025-08-16 08:17:24 +09:00
Cheng
4abb218d21 The naive_conv_2d is no longer used (#2496) 2025-08-16 07:57:30 +09:00
14 changed files with 166 additions and 252 deletions

View File

@@ -228,31 +228,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -196,9 +196,6 @@ void shared_buffer_reshape(
const Strides& out_strides,
array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));

View File

@@ -157,10 +157,12 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
@@ -175,8 +177,8 @@ inline void build_kernel(
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
}
}
@@ -186,10 +188,8 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
// Add output size
if (contiguous) {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -288,17 +288,8 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
@@ -306,9 +297,6 @@ void Compiled::eval_cpu(
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
}
// Get the kernel name from the lib
@@ -343,16 +331,20 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (!contiguous) {
args.push_back((void*)shape.data());
} else {
if (contiguous) {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable { fun(args.data()); });
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
}
} // namespace mlx::core

View File

@@ -336,6 +336,42 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
}
}
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
const array& x,
int groups,
int group_dim,
int axis1,
int axis2,
Stream s) {
if (groups == 1) {
return swapaxes_in_eval(x, axis1, axis2);
}
int ndim = x.ndim();
if (group_dim < 0) {
group_dim += ndim;
}
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
if (group_dim <= axis1) {
axis1 += 1;
}
if (group_dim <= axis2) {
axis2 += 1;
}
auto shape = x.shape();
shape.insert(shape.begin() + group_dim, groups);
shape[group_dim + 1] = shape[group_dim + 1] / groups;
array x_trans = reshape_in_eval(x, std::move(shape), s);
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
return x_trans;
}
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
@@ -345,13 +381,14 @@ std::tuple<array, array, array> prepare_args(
array in,
array wt,
array out,
int groups,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = swapaxes_in_eval(wt, 0, -1);
wt = group_transpose(wt, groups, 0, 0, -1, s);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = swapaxes_in_eval(in, 0, -1);
in = group_transpose(in, groups, -1, 0, -1, s);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped.
@@ -457,7 +494,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
@@ -490,7 +528,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, s);
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,

View File

@@ -1,6 +1,5 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"

View File

@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
return arr_copy;
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
int ndim = x.ndim();
if (start_axis < 0) {
start_axis += ndim;
}
if (end_axis < 0) {
end_axis += ndim;
}
start_axis = std::max(0, start_axis);
end_axis = std::min(ndim - 1, end_axis);
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
}
array reshape_in_eval(const array& x, Shape shape, Stream s) {
array out(std::move(shape), x.dtype(), nullptr, {});
reshape_gpu(x, out, s);
return out;
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
// Copy data from |in| and transpose to |out|'s shape.
void reshape_gpu(const array& in, array& out, Stream s);
// Like the normal ops but safe to call in eval_gpu.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
array reshape_in_eval(const array& x, Shape shape, Stream s);
array swapaxes_in_eval(const array& x, int axis1, int axis2);
} // namespace mlx::core

View File

@@ -20,29 +20,6 @@
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
@@ -124,7 +101,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
reshape_gpu(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -150,7 +127,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
reshape_gpu(inputs[0], out, stream());
}
void Split::eval_gpu(
@@ -224,7 +201,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
reshape_gpu(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -60,22 +60,12 @@ struct CommandEncoder {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
// TODO: Code is duplicated but they should be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}

View File

@@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
out += tid.z * params.out_strides[0];
in += tid.z * params.in_strides[0];
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for (int h = 0; h < params.wS[0]; ++h) {
for (int w = 0; w < params.wS[1]; ++w) {
for (int c = 0; c < params.C; ++c) {
// Local in
for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c]
: T(0);
}
// Accumulate
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
(out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -440,6 +440,7 @@ class SmallVector {
end_ = begin_;
}
private:
// Grows the backing store by a factor of two, and at least to {min_capacity}.
// TODO: Move to private after removing external code using this method.
MLX_NOINLINE void grow(size_t min_capacity = 0) {
@@ -469,7 +470,6 @@ class SmallVector {
end_of_storage_ = new_storage + new_capacity;
}
private:
MLX_NOINLINE void free_storage() {
std::destroy_n(begin_, end_ - begin_);
if (is_big()) {
@@ -519,6 +519,18 @@ class SmallVector {
std::is_trivially_destructible<T>::value;
};
template <typename>
struct is_vector : std::false_type {};
template <typename T, size_t Size, typename Allocator>
struct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};
template <typename T, typename Allocator>
struct is_vector<std::vector<T, Allocator>> : std::true_type {};
template <typename Vec>
inline constexpr bool is_vector_v = is_vector<Vec>::value;
#undef MLX_HAS_BUILTIN
#undef MLX_HAS_ATTRIBUTE
#undef MLX_LIKELY

View File

@@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
// TODO: Code is duplicated but they should be deleted soon.
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
namespace env {
int get_var(const char* name, int default_value) {

View File

@@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}
@@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
os << "(";
for (auto it = v.begin(); it != v.end(); ++it) {
os << *it;
if (it != std::prev(v.end())) {
os << ",";
}
}
os << ")";
return os;
}
inline bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}

View File

@@ -17,7 +17,6 @@ cuda_skip = {
"TestConv.test_1d_conv_with_2d",
"TestConv.test_conv_1d_groups_flipped",
"TestConv.test_conv_general_flip_grad",
"TestConv.test_conv_groups_grad",
"TestConv.test_torch_conv_2D",
"TestConv.test_torch_conv_depthwise",
"TestConv.test_torch_conv_general",