mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 09:29:26 +08:00
Compile stride bug (#812)
* fix compile stride bug * revert sdpa fix * fix cpu * fix bug with simplifying outputs
This commit is contained in:
@@ -162,12 +162,23 @@ void array::copy_shared_buffer(const array& other) {
|
||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
void array::move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
array_desc_->data = std::move(other.array_desc_->data);
|
||||
array_desc_->strides = other.strides();
|
||||
array_desc_->flags = other.flags();
|
||||
array_desc_->data_size = other.data_size();
|
||||
array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
array_desc_->data_ptr = static_cast<void*>(
|
||||
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||
|
||||
Reference in New Issue
Block a user