mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -64,8 +64,8 @@ void explicit_gemm_conv_ND_gpu(
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
|
||||
Shape wt_reshape{implicit_K, implicit_N};
|
||||
Strides wt_restride{1, implicit_K};
|
||||
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
||||
auto wt_flags = wt.flags();
|
||||
wt_flags.row_contiguous = false;
|
||||
@@ -147,10 +147,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
array wt_view(
|
||||
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt,
|
||||
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
|
||||
wt.flags(),
|
||||
wt.size());
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
|
||||
Reference in New Issue
Block a user