mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides
* Convert SmallVector to tuple
This commit is contained in:
@@ -55,13 +55,13 @@ auto& conv_cache() {
|
||||
return cache;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
|
||||
return std::vector<T>(vec.begin(), vec.end());
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
@@ -78,7 +78,7 @@ auto nhwc_to_nchw(const array& x) {
|
||||
auto strides = convert_vector<int64_t>(x.strides());
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(shape, strides);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
@@ -298,10 +298,10 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const std::vector<int64_t>& stride,
|
||||
const std::vector<int64_t>& padding_lo,
|
||||
const std::vector<int64_t>& padding_hi,
|
||||
const std::vector<int64_t>& dilation) {
|
||||
const SmallVector<int64_t>& stride,
|
||||
const SmallVector<int64_t>& padding_lo,
|
||||
const SmallVector<int64_t>& padding_hi,
|
||||
const SmallVector<int64_t>& dilation) {
|
||||
try {
|
||||
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
@@ -468,7 +468,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
|
||||
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||
// convolution, so we make a best guess and then try.
|
||||
std::vector<cudnnBackendDescriptorType_t> try_backends;
|
||||
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
|
||||
if (flip_) {
|
||||
// When weight is flipped, we assume it is backward input convolution.
|
||||
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||
|
||||
@@ -29,12 +29,12 @@ void append_indices_arg(
|
||||
const std::vector<array>& inputs,
|
||||
int nidx,
|
||||
int idx_ndim) {
|
||||
std::vector<const void*> indices(nidx);
|
||||
SmallVector<const void*> indices(nidx);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
indices[i] = inputs[i + 1].data<void>();
|
||||
}
|
||||
args.append(std::move(indices));
|
||||
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].shape().begin(),
|
||||
@@ -42,7 +42,7 @@ void append_indices_arg(
|
||||
indices_shape.data() + i * idx_ndim);
|
||||
}
|
||||
args.append(std::move(indices_shape));
|
||||
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].strides().begin(),
|
||||
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
args.append<int32_t>(src.ndim());
|
||||
args.append_ndim(slice_sizes_);
|
||||
args.append(slice_size);
|
||||
args.append(axes_);
|
||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
args.append_ndim(out.shape());
|
||||
args.append_ndim(out.strides());
|
||||
args.append<int32_t>(out.ndim());
|
||||
args.append(axes_);
|
||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
|
||||
@@ -40,19 +40,14 @@ struct KernelArgs {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append(std::vector<T> vec) {
|
||||
if (vec.empty()) {
|
||||
// The nullptr can not be used as arg, pass something not null.
|
||||
append(std::monostate{});
|
||||
} else {
|
||||
append_ptr(vec.data());
|
||||
storage_.emplace_back(std::move(vec));
|
||||
}
|
||||
void append(SmallVector<T> vec) {
|
||||
storage_.emplace_back(std::move(vec));
|
||||
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim(std::vector<T> vec) {
|
||||
void append_ndim(SmallVector<T> vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
@@ -76,9 +71,9 @@ struct KernelArgs {
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
std::vector<const void*>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int64_t>>;
|
||||
SmallVector<const void*>,
|
||||
SmallVector<int32_t>,
|
||||
SmallVector<int64_t>>;
|
||||
std::deque<Arg> storage_;
|
||||
};
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
|
||||
@@ -28,15 +28,20 @@ inline array ensure_row_contiguous_matrix(
|
||||
const array& x,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
if (x.ndim() < 2) {
|
||||
if (x.strides()[0] == 1) {
|
||||
return x;
|
||||
}
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user