mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix tests on large arrays
This commit is contained in:
@@ -20,7 +20,7 @@ namespace cg = cooperative_groups;
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -44,7 +44,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -70,7 +70,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -96,7 +96,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -52,7 +52,7 @@ template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -85,7 +85,7 @@ template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -118,7 +118,7 @@ template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace cg = cooperative_groups;
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -37,7 +37,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void copy_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ template <typename Op, typename T, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace cg = cooperative_groups;
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void unary_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
- int remaining = size - index * N_READS;
+ IdxT remaining = size - index * N_READS;
|
||||
if (remaining <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user