mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides
* Convert SmallVector to tuple
This commit is contained in:
@@ -60,6 +60,16 @@ struct CommandEncoder {
|
||||
enc_->updateFence(fence);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
template <typename T>
|
||||
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
|
||||
return set_vector_bytes(vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
// TODO: Code is duplicated but they should be deleted soon.
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
|
@@ -32,15 +32,20 @@ inline array ensure_row_contiguous_matrix(
|
||||
const array& x,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
if (x.ndim() < 2) {
|
||||
if (x.strides()[0] == 1) {
|
||||
return x;
|
||||
}
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
|
||||
|
Reference in New Issue
Block a user