mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix large ops (#1620)
This commit is contained in:
@@ -36,13 +36,13 @@ template <typename T, typename U>
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
@@ -97,15 +97,15 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
|
||||
auto src_idx = elem_to_loc_1<int64_t, IdxT>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<int64_t, IdxT>(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user