mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Compile stride bug (#812)
* fix compile stride bug
* revert sdpa fix
* fix cpu
* fix bug with simplifying outputs
This commit is contained in:
parent
a4d290adb9
commit
7c441600fe
@ -162,12 +162,23 @@ void array::copy_shared_buffer(const array& other) {
|
|||||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::move_shared_buffer(array other) {
|
void array::move_shared_buffer(
|
||||||
|
array other,
|
||||||
|
const std::vector<size_t>& strides,
|
||||||
|
Flags flags,
|
||||||
|
size_t data_size,
|
||||||
|
size_t offset /* = 0 */) {
|
||||||
array_desc_->data = std::move(other.array_desc_->data);
|
array_desc_->data = std::move(other.array_desc_->data);
|
||||||
array_desc_->strides = other.strides();
|
array_desc_->strides = strides;
|
||||||
array_desc_->flags = other.flags();
|
array_desc_->flags = flags;
|
||||||
array_desc_->data_size = other.data_size();
|
array_desc_->data_size = data_size;
|
||||||
array_desc_->data_ptr = other.array_desc_->data_ptr;
|
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||||
|
array_desc_->data_ptr = static_cast<void*>(
|
||||||
|
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::move_shared_buffer(array other) {
|
||||||
|
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||||
|
@ -339,6 +339,13 @@ class array {
|
|||||||
|
|
||||||
void copy_shared_buffer(const array& other);
|
void copy_shared_buffer(const array& other);
|
||||||
|
|
||||||
|
void move_shared_buffer(
|
||||||
|
array other,
|
||||||
|
const std::vector<size_t>& strides,
|
||||||
|
Flags flags,
|
||||||
|
size_t data_size,
|
||||||
|
size_t offset = 0);
|
||||||
|
|
||||||
void move_shared_buffer(array other);
|
void move_shared_buffer(array other);
|
||||||
|
|
||||||
void overwrite_descriptor(const array& other) {
|
void overwrite_descriptor(const array& other) {
|
||||||
|
@ -385,7 +385,9 @@ void Compiled::eval_cpu(
|
|||||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||||
in.is_donatable() &&
|
in.is_donatable() &&
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||||
outputs[o++].copy_shared_buffer(in);
|
outputs[o].copy_shared_buffer(
|
||||||
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
|
o++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (; o < outputs.size(); ++o) {
|
for (; o < outputs.size(); ++o) {
|
||||||
|
@ -329,7 +329,9 @@ void Compiled::eval_gpu(
|
|||||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||||
in.is_donatable() &&
|
in.is_donatable() &&
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||||
outputs[o++].move_shared_buffer(in);
|
outputs[o].move_shared_buffer(
|
||||||
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
|
o++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (; o < outputs.size(); ++o) {
|
for (; o < outputs.size(); ++o) {
|
||||||
|
@ -13,12 +13,10 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
|
|||||||
device float* O_partials [[buffer(5)]],
|
device float* O_partials [[buffer(5)]],
|
||||||
device float* p_lse [[buffer(6)]],
|
device float* p_lse [[buffer(6)]],
|
||||||
device float* p_maxes [[buffer(7)]],
|
device float* p_maxes [[buffer(7)]],
|
||||||
|
threadgroup T* threadgroup_block [[threadgroup(0)]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
|
|
||||||
threadgroup T threadgroup_block[32768 / sizeof(T)];
|
|
||||||
|
|
||||||
constexpr const size_t DK = 128;
|
constexpr const size_t DK = 128;
|
||||||
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
|
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
|
||||||
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
|
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
|
||||||
@ -358,6 +356,7 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
|
|||||||
device float* O_partials [[buffer(5)]], \
|
device float* O_partials [[buffer(5)]], \
|
||||||
device float* p_lse [[buffer(6)]], \
|
device float* p_lse [[buffer(6)]], \
|
||||||
device float* p_maxes [[buffer(7)]], \
|
device float* p_maxes [[buffer(7)]], \
|
||||||
|
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]]);
|
uint3 tid [[threadgroup_position_in_grid]]);
|
||||||
|
@ -97,6 +97,8 @@ void sdpa_metal(
|
|||||||
set_array_buffer(compute_encoder, p_lse, 6);
|
set_array_buffer(compute_encoder, p_lse, 6);
|
||||||
set_array_buffer(compute_encoder, p_rowmaxes, 7);
|
set_array_buffer(compute_encoder, p_rowmaxes, 7);
|
||||||
|
|
||||||
|
constexpr const uint tgroupMemorySize = 32768;
|
||||||
|
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -439,7 +439,8 @@ void compile_simplify(
|
|||||||
}
|
}
|
||||||
auto& src = parents->second[j].first;
|
auto& src = parents->second[j].first;
|
||||||
auto& dst = parents->second[i].first;
|
auto& dst = parents->second[i].first;
|
||||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
||||||
|
output_set.find(src.id()) == output_set.end()) {
|
||||||
merge(dst, src, parents_map);
|
merge(dst, src, parents_map);
|
||||||
mask[j] = true;
|
mask[j] = true;
|
||||||
}
|
}
|
||||||
@ -456,7 +457,6 @@ void compile_simplify(
|
|||||||
return output_set.find(a.id()) == output_set.end();
|
return output_set.find(a.id()) == output_set.end();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool discard = maybe_merge_parents(arr);
|
bool discard = maybe_merge_parents(arr);
|
||||||
for (auto& s : arr.siblings()) {
|
for (auto& s : arr.siblings()) {
|
||||||
discard &= maybe_merge_parents(s);
|
discard &= maybe_merge_parents(s);
|
||||||
|
@ -605,6 +605,14 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
out = fun(mx.array(0.0), y=MyClass())
|
out = fun(mx.array(0.0), y=MyClass())
|
||||||
|
|
||||||
|
def test_compile_create_list(self):
|
||||||
|
@mx.compile
|
||||||
|
def fun():
|
||||||
|
return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))]
|
||||||
|
|
||||||
|
out = fun()
|
||||||
|
mx.eval(out)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -703,3 +703,18 @@ TEST_CASE("test shapeless compile") {
|
|||||||
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto compile_broadcast_add(const std::vector<array>& inputs) {
|
||||||
|
auto b = zeros({8, 8});
|
||||||
|
return std::vector<array>{inputs[0] + b};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile strides") {
|
||||||
|
{
|
||||||
|
auto cfun = compile(compile_broadcast_add);
|
||||||
|
auto a = zeros({1, 8, 8});
|
||||||
|
auto out = cfun({a})[0];
|
||||||
|
eval(out);
|
||||||
|
CHECK_EQ(out.strides().size(), 3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user