Use SmallVector for shapes and strides (#2454)

* Use SmallVector for shapes and strides

* Convert SmallVector to tuple
This commit is contained in:
Cheng
2025-08-05 09:41:03 +09:00
committed by GitHub
parent 7d86a5c108
commit 828c5f1137
30 changed files with 738 additions and 102 deletions

View File

@@ -60,6 +60,16 @@ struct CommandEncoder {
enc_->updateFence(fence);
}
// Hands the first `nelems` elements of `vec` (that is, `nelems * sizeof(T)`
// bytes starting at `vec.data()`) to `enc_->setBytes` at index `idx`.
// `nelems` is not checked against `vec.size()` here.
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
  const size_t byte_count = nelems * sizeof(T);
  enc_->setBytes(vec.data(), byte_count, idx);
}
// Convenience overload: forwards to the three-argument `set_vector_bytes`
// using `vec.size()` as the element count, so the whole vector is passed.
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
  const auto element_count = vec.size();
  set_vector_bytes(vec, element_count, idx);
}
// TODO: The std::vector overloads below duplicate the SmallVector ones above; they are expected to be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);

View File

// ensure_row_contiguous_matrix: returns `x` unchanged when the stride checks
// below pass; otherwise makes a copy via contiguous_copy_gpu(x, s), registers
// it with d.add_temporary(copy, s.index), and returns the copy.
// NOTE(review): this hunk appears to be a diff whose +/- markers were
// stripped, so removed and added lines are interleaved and the braces do not
// balance as written. Confirm against the actual file before editing.
@@ -32,15 +32,20 @@ inline array ensure_row_contiguous_matrix(
const array& x,
metal::Device& d,
const Stream& s) {
// NOTE(review): the next four lines look like the pre-change body (stride
// check performed unconditionally on the last two dimensions).
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
// Arrays with fewer than two dimensions: only strides()[0] is inspected.
// NOTE(review): when x.ndim() == 0 this still indexes strides()[0]; confirm
// callers never pass a 0-dim array or that strides() is non-empty then.
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
} else {
// NOTE(review): the next three lines look like the pre-change else-branch
// (copy made inside the else).
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
// Two or more dimensions: `x` is returned as-is when the second-to-last
// stride equals the size of the last dimension and the last stride is 1.
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
}
// Fallback reached when neither check returned early: copy, register the copy
// as a temporary on stream index s.index, and return it.
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {