mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops * fix bug * fix all of copy
This commit is contained in:
parent
c52d1600f0
commit
40b6d67333
@ -21,10 +21,43 @@ namespace mlx::core {
|
|||||||
|
|
||||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||||
|
|
||||||
|
// Assemble the name of the binary kernel to look up.
//
// The result is a layout prefix followed by `op` and then whatever
// `type_to_name(a)` produces. The prefix depends on `bopt`:
//   ScalarScalar -> "ss"
//   ScalarVector -> "sv"  ("sv2" when use_2d is set)
//   VectorScalar -> "vs"  ("vs2" when use_2d is set)
//   VectorVector -> "vv"  ("vv2" when use_2d is set)
//   General      -> "g" followed by ndim when
//                   ndim <= MAX_BINARY_SPECIALIZED_DIMS, otherwise "gn"
std::string get_kernel_name(
    BinaryOpType bopt,
    const std::string& op,
    const array& a,
    bool use_2d,
    int ndim) {
  std::ostringstream name;

  if (bopt == BinaryOpType::General) {
    // General kernels are specialized per rank up to a fixed limit and fall
    // back to the "n" variant beyond it.
    name << "g";
    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      name << "n";
    } else {
      name << ndim;
    }
  } else {
    // Contiguous layouts: only the vector variants have a "2" flavour.
    const char* layout = "";
    switch (bopt) {
      case BinaryOpType::ScalarScalar:
        layout = "ss";
        break;
      case BinaryOpType::ScalarVector:
        layout = use_2d ? "sv2" : "sv";
        break;
      case BinaryOpType::VectorScalar:
        layout = use_2d ? "vs2" : "vs";
        break;
      case BinaryOpType::VectorVector:
        layout = use_2d ? "vv2" : "vv";
        break;
      default:
        break;
    }
    name << layout;
  }

  name << op << type_to_name(a);
  return name.str();
}
|
||||||
|
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@ -41,35 +74,8 @@ void binary_op_gpu_inplace(
|
|||||||
auto& strides_b = strides[1];
|
auto& strides_b = strides[1];
|
||||||
auto& strides_out = strides[2];
|
auto& strides_out = strides[2];
|
||||||
|
|
||||||
std::string kernel_name;
|
bool use_2d = out.data_size() > UINT32_MAX;
|
||||||
{
|
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
||||||
std::ostringstream kname;
|
|
||||||
switch (bopt) {
|
|
||||||
case BinaryOpType::ScalarScalar:
|
|
||||||
kname << "ss";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::ScalarVector:
|
|
||||||
kname << "sv";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::VectorScalar:
|
|
||||||
kname << "vs";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::VectorVector:
|
|
||||||
kname << "vv";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::General:
|
|
||||||
kname << "g";
|
|
||||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
|
||||||
kname << shape.size();
|
|
||||||
} else {
|
|
||||||
kname << "n";
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
kname << op << type_to_name(a);
|
|
||||||
kernel_name = kname.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
auto kernel =
|
auto kernel =
|
||||||
@ -117,9 +123,11 @@ void binary_op_gpu_inplace(
|
|||||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
// Launch a 1D grid of threads
|
// Launch a 1D or 2D grid of threads
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
MTL::Size grid_dims = use_2d
|
||||||
|
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
@ -132,7 +140,7 @@ void binary_op_gpu_inplace(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
@ -146,7 +154,7 @@ void binary_op_gpu(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string op) {
|
const std::string& op) {
|
||||||
auto& s = outputs[0].primitive().stream();
|
auto& s = outputs[0].primitive().stream();
|
||||||
binary_op_gpu(inputs, outputs, op, s);
|
binary_op_gpu(inputs, outputs, op, s);
|
||||||
}
|
}
|
||||||
@ -154,7 +162,7 @@ void binary_op_gpu(
|
|||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
@ -169,35 +177,8 @@ void binary_op_gpu_inplace(
|
|||||||
auto& strides_b = strides[1];
|
auto& strides_b = strides[1];
|
||||||
auto& strides_out = strides[2];
|
auto& strides_out = strides[2];
|
||||||
|
|
||||||
std::string kernel_name;
|
bool use_2d = out.data_size() > UINT32_MAX;
|
||||||
{
|
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
||||||
std::ostringstream kname;
|
|
||||||
switch (bopt) {
|
|
||||||
case BinaryOpType::ScalarScalar:
|
|
||||||
kname << "ss";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::ScalarVector:
|
|
||||||
kname << "sv";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::VectorScalar:
|
|
||||||
kname << "vs";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::VectorVector:
|
|
||||||
kname << "vv";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::General:
|
|
||||||
kname << "g";
|
|
||||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
|
||||||
kname << shape.size();
|
|
||||||
} else {
|
|
||||||
kname << "n";
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
kname << op << type_to_name(a);
|
|
||||||
kernel_name = kname.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||||
@ -237,10 +218,11 @@ void binary_op_gpu_inplace(
|
|||||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
// Launch a 1D grid of threads
|
// Launch a 1D or 2D grid of threads
|
||||||
size_t nthreads =
|
|
||||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
@ -253,7 +235,7 @@ void binary_op_gpu_inplace(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
@ -266,7 +248,7 @@ void binary_op_gpu(
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op) {
|
const std::string& op) {
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
binary_op_gpu(inputs, out, op, s);
|
binary_op_gpu(inputs, out, op, s);
|
||||||
}
|
}
|
||||||
|
@ -9,25 +9,25 @@ namespace mlx::core {
|
|||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void binary_op_gpu(
|
void binary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
void binary_op_gpu_inplace(
|
void binary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const std::string op,
|
const std::string& op,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -64,16 +64,17 @@ void copy_gpu_inplace(
|
|||||||
auto& strides_in_ = strides[0];
|
auto& strides_in_ = strides[0];
|
||||||
auto& strides_out_ = strides[1];
|
auto& strides_out_ = strides[1];
|
||||||
|
|
||||||
|
bool use_2d = out.data_size() > UINT32_MAX;
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
{
|
{
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
switch (ctype) {
|
switch (ctype) {
|
||||||
case CopyType::Scalar:
|
case CopyType::Scalar:
|
||||||
kname << "s";
|
kname << (use_2d ? "s2" : "s");
|
||||||
break;
|
break;
|
||||||
case CopyType::Vector:
|
case CopyType::Vector:
|
||||||
kname << "v";
|
kname << (use_2d ? "v2" : "v");
|
||||||
break;
|
break;
|
||||||
case CopyType::General:
|
case CopyType::General:
|
||||||
kname << "g";
|
kname << "g";
|
||||||
@ -139,7 +140,8 @@ void copy_gpu_inplace(
|
|||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
} else {
|
} else {
|
||||||
size_t nthreads = out.data_size();
|
size_t nthreads = out.data_size();
|
||||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||||
|
: MTL::Size(nthreads, 1, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
|
@ -36,6 +36,39 @@ template <typename T, typename U, typename Op>
|
|||||||
c[index] = Op()(a[index], b[index]);
|
c[index] = Op()(a[index], b[index]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scalar-vector binary kernel for a 2D grid of threads.
// Each thread combines a[0] with one element of b and stores the result in c
// at the same linear position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  c[pos] = Op()(a[0], b[pos]);
}
|
||||||
|
|
||||||
|
// Vector-scalar binary kernel for a 2D grid of threads.
// Each thread combines one element of a with b[0] and stores the result in c
// at the same linear position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  c[pos] = Op()(a[pos], b[0]);
}
|
||||||
|
|
||||||
|
// Vector-vector binary kernel for a 2D grid of threads.
// Each thread combines the elements of a and b at its linear position and
// stores the result in c at that position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  c[pos] = Op()(a[pos], b[pos]);
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
[[kernel]] void binary_g_nd1(
|
[[kernel]] void binary_g_nd1(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
||||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
||||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
||||||
|
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
|
||||||
|
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
|
||||||
|
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
||||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
|
@ -48,6 +48,48 @@ template <typename T, typename U, typename Op>
|
|||||||
d[index] = out[1];
|
d[index] = out[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Two-output scalar-vector binary kernel for a 2D grid of threads.
// Each thread applies Op to a[0] and one element of b, then writes component
// 0 of the result to c and component 1 to d at the same linear position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  auto result = Op()(a[0], b[pos]);
  c[pos] = result[0];
  d[pos] = result[1];
}
|
||||||
|
|
||||||
|
// Two-output vector-scalar binary kernel for a 2D grid of threads.
// Each thread applies Op to one element of a and b[0], then writes component
// 0 of the result to c and component 1 to d at the same linear position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  auto result = Op()(a[pos], b[0]);
  c[pos] = result[0];
  d[pos] = result[1];
}
|
||||||
|
|
||||||
|
// Two-output vector-vector binary kernel for a 2D grid of threads.
// Each thread applies Op to the elements of a and b at its linear position,
// then writes component 0 of the result to c and component 1 to d there.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  auto result = Op()(a[pos], b[pos]);
  c[pos] = result[0];
  d[pos] = result[1];
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
[[kernel]] void binary_g_nd1(
|
[[kernel]] void binary_g_nd1(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
|
@ -12,6 +12,9 @@
|
|||||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
||||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
||||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
||||||
|
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
|
||||||
|
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
|
||||||
|
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
|
||||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
||||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||||
|
@ -16,6 +16,26 @@ template <typename T, typename U>
|
|||||||
dst[index] = static_cast<U>(src[index]);
|
dst[index] = static_cast<U>(src[index]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scalar copy kernel for a 2D grid of threads.
// Each thread writes src[0], converted to U, into dst at its linear position.
template <typename T, typename U>
[[kernel]] void copy_s2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  dst[pos] = static_cast<U>(src[0]);
}
|
||||||
|
|
||||||
|
// Vector copy kernel for a 2D grid of threads.
// Each thread converts the element of src at its linear position to U and
// writes it to the same position in dst.
template <typename T, typename U>
[[kernel]] void copy_v2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Flatten (x, y) into one offset; the product is evaluated in size_t
  // rather than uint.
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  dst[pos] = static_cast<U>(src[pos]);
}
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
[[kernel]] void copy_g_nd1(
|
[[kernel]] void copy_g_nd1(
|
||||||
device const T* src [[buffer(0)]],
|
device const T* src [[buffer(0)]],
|
||||||
|
@ -5,95 +5,23 @@
|
|||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/copy.h"
|
#include "mlx/backend/metal/kernels/copy.h"
|
||||||
|
|
||||||
#define instantiate_copy(name, itype, otype, ctype) \
|
|
||||||
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
uint index [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
|
||||||
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
|
|
||||||
copy_g_nd<itype, otype, dims>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
uint3 index [[thread_position_in_grid]], \
|
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
|
||||||
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
|
|
||||||
copy_gg_nd<itype, otype, dims>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
|
||||||
uint3 index [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
|
||||||
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int64_t& src_stride [[buffer(3)]], \
|
|
||||||
uint index [[thread_position_in_grid]]); \
|
|
||||||
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
uint2 index [[thread_position_in_grid]], \
|
|
||||||
uint2 grid_dim [[threads_per_grid]]); \
|
|
||||||
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
uint3 index [[thread_position_in_grid]], \
|
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
|
||||||
template [[host_name("gg1_" name )]] [[kernel]] void \
|
|
||||||
copy_gg_nd1<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int64_t& src_stride [[buffer(3)]], \
|
|
||||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
|
||||||
uint index [[thread_position_in_grid]]); \
|
|
||||||
template [[host_name("gg2_" name)]] [[kernel]] void \
|
|
||||||
copy_gg_nd2<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
|
||||||
uint2 index [[thread_position_in_grid]]); \
|
|
||||||
template [[host_name("gg3_" name)]] [[kernel]] void \
|
|
||||||
copy_gg_nd3<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
|
||||||
uint3 index [[thread_position_in_grid]]); \
|
|
||||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
|
||||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
|
||||||
|
|
||||||
#define instantiate_copy_g(name, itype, otype) \
|
|
||||||
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
constant const int& ndim [[buffer(5)]], \
|
|
||||||
uint3 index [[thread_position_in_grid]], \
|
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
|
||||||
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
|
|
||||||
device const itype* src [[buffer(0)]], \
|
|
||||||
device otype* dst [[buffer(1)]], \
|
|
||||||
constant const int* src_shape [[buffer(2)]], \
|
|
||||||
constant const int64_t* src_strides [[buffer(3)]], \
|
|
||||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
|
||||||
constant const int& ndim [[buffer(5)]], \
|
|
||||||
uint3 index [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
#define instantiate_copy_all(tname, itype, otype) \
|
#define instantiate_copy_all(tname, itype, otype) \
|
||||||
instantiate_copy("s_copy" #tname, itype, otype, s) \
|
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||||
instantiate_copy("v_copy" #tname, itype, otype, v) \
|
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||||
instantiate_copy_g("copy" #tname, itype, otype) \
|
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||||
instantiate_copy_g_nd("copy" #tname, itype, otype)
|
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||||
|
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||||
|
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
|
||||||
|
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
|
||||||
|
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||||
|
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||||
|
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||||
|
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
|
||||||
|
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
|
||||||
|
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
|
||||||
|
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
|
||||||
|
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
|
||||||
|
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
|
||||||
|
|
||||||
#define instantiate_copy_itype(itname, itype) \
|
#define instantiate_copy_itype(itname, itype) \
|
||||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||||
|
@ -34,7 +34,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
threadgroup float local_mean[1];
|
threadgroup float local_mean[1];
|
||||||
threadgroup float local_normalizer[1];
|
threadgroup float local_normalizer[1];
|
||||||
|
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
b += b_stride * lid * N_READS;
|
b += b_stride * lid * N_READS;
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float normalizer = local_normalizer[0];
|
float normalizer = local_normalizer[0];
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
out += gid * axis_size + lid * N_READS;
|
out += gid * size_t(axis_size) + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||||
@ -131,7 +131,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
threadgroup float local_mean[1];
|
threadgroup float local_mean[1];
|
||||||
threadgroup float local_normalizer[1];
|
threadgroup float local_normalizer[1];
|
||||||
|
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
b += b_stride * lid * N_READS;
|
b += b_stride * lid * N_READS;
|
||||||
|
|
||||||
@ -188,7 +188,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float normalizer = local_normalizer[0];
|
float normalizer = local_normalizer[0];
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
out += gid * axis_size + lid * N_READS;
|
out += gid * size_t(axis_size) + lid * N_READS;
|
||||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -223,8 +223,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
// Advance the input pointers
|
// Advance the input pointers
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
g += gid * axis_size + lid * N_READS;
|
g += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
|
|
||||||
// Allocate registers for the computation and accumulators
|
// Allocate registers for the computation and accumulators
|
||||||
@ -321,8 +321,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float normalizer2 = normalizer * normalizer;
|
float normalizer2 = normalizer * normalizer;
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
gx += gid * axis_size + lid * N_READS;
|
gx += gid * size_t(axis_size) + lid * N_READS;
|
||||||
gw += gid * axis_size + lid * N_READS;
|
gw += gid * size_t(axis_size) + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
thread_x[i] = (thread_x[i] - mean) * normalizer;
|
||||||
@ -360,8 +360,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
// Advance the input pointers
|
// Advance the input pointers
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
g += gid * axis_size + lid * N_READS;
|
g += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
|
|
||||||
// Allocate registers for the accumulators
|
// Allocate registers for the accumulators
|
||||||
@ -457,8 +457,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float normalizer2 = normalizer * normalizer;
|
float normalizer2 = normalizer * normalizer;
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
gx += gid * axis_size + lid * N_READS;
|
gx += gid * size_t(axis_size) + lid * N_READS;
|
||||||
gw += gid * axis_size + lid * N_READS;
|
gw += gid * size_t(axis_size) + lid * N_READS;
|
||||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
@ -24,7 +24,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
float acc = 0;
|
float acc = 0;
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -62,7 +62,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
out += gid * axis_size + lid * N_READS;
|
out += gid * size_t(axis_size) + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
|
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
|
||||||
@ -92,7 +92,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
float acc = 0;
|
float acc = 0;
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||||
@ -132,7 +132,7 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
out += gid * axis_size + lid * N_READS;
|
out += gid * size_t(axis_size) + lid * N_READS;
|
||||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
@ -165,8 +165,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
// Advance the input pointers
|
// Advance the input pointers
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
g += gid * axis_size + lid * N_READS;
|
g += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
|
|
||||||
// Allocate registers for the computation and accumulators
|
// Allocate registers for the computation and accumulators
|
||||||
@ -233,8 +233,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float normalizer3 = normalizer * normalizer * normalizer;
|
float normalizer3 = normalizer * normalizer * normalizer;
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
gx += gid * axis_size + lid * N_READS;
|
gx += gid * size_t(axis_size) + lid * N_READS;
|
||||||
gw += gid * axis_size + lid * N_READS;
|
gw += gid * size_t(axis_size) + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
gx[i] = static_cast<T>(
|
gx[i] = static_cast<T>(
|
||||||
@ -270,8 +270,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
// Advance the input pointers
|
// Advance the input pointers
|
||||||
x += gid * axis_size + lid * N_READS;
|
x += gid * size_t(axis_size) + lid * N_READS;
|
||||||
g += gid * axis_size + lid * N_READS;
|
g += gid * size_t(axis_size) + lid * N_READS;
|
||||||
w += w_stride * lid * N_READS;
|
w += w_stride * lid * N_READS;
|
||||||
|
|
||||||
// Allocate registers for the accumulators
|
// Allocate registers for the accumulators
|
||||||
@ -337,8 +337,8 @@ template <typename T, int N_READS = RMS_N_READS>
|
|||||||
float normalizer3 = normalizer * normalizer * normalizer;
|
float normalizer3 = normalizer * normalizer * normalizer;
|
||||||
|
|
||||||
// Write the outputs
|
// Write the outputs
|
||||||
gx += gid * axis_size + lid * N_READS;
|
gx += gid * size_t(axis_size) + lid * N_READS;
|
||||||
gw += gid * axis_size + lid * N_READS;
|
gw += gid * size_t(axis_size) + lid * N_READS;
|
||||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
|
@ -25,7 +25,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
|
|
||||||
AccT ld[N_READS];
|
AccT ld[N_READS];
|
||||||
|
|
||||||
in += gid * axis_size + lid * N_READS;
|
in += gid * size_t(axis_size) + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
ld[i] = AccT(in[i]);
|
ld[i] = AccT(in[i]);
|
||||||
@ -83,7 +83,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
normalizer = 1 / local_normalizer[0];
|
normalizer = 1 / local_normalizer[0];
|
||||||
|
|
||||||
// Normalize and write to the output
|
// Normalize and write to the output
|
||||||
out += gid * axis_size + lid * N_READS;
|
out += gid * size_t(axis_size) + lid * N_READS;
|
||||||
if (lid * N_READS + N_READS <= axis_size) {
|
if (lid * N_READS + N_READS <= axis_size) {
|
||||||
for (int i = 0; i < N_READS; i++) {
|
for (int i = 0; i < N_READS; i++) {
|
||||||
out[i] = T(ld[i] * normalizer);
|
out[i] = T(ld[i] * normalizer);
|
||||||
@ -107,7 +107,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
uint lsize [[threads_per_threadgroup]],
|
uint lsize [[threads_per_threadgroup]],
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||||
in += gid * axis_size;
|
in += gid * size_t(axis_size);
|
||||||
|
|
||||||
constexpr int SIMD_SIZE = 32;
|
constexpr int SIMD_SIZE = 32;
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
|||||||
|
|
||||||
// Finally given the normalizer and max value we can directly write the
|
// Finally given the normalizer and max value we can directly write the
|
||||||
// softmax output
|
// softmax output
|
||||||
out += gid * axis_size;
|
out += gid * size_t(axis_size);
|
||||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||||
r++) {
|
r++) {
|
||||||
int offset = r * lsize * N_READS + lid * N_READS;
|
int offset = r * lsize * N_READS + lid * N_READS;
|
||||||
|
@ -10,6 +10,18 @@ template <typename T, typename Op>
|
|||||||
d[index] = Op()(a[index], b[index], c[index]);
|
d[index] = Op()(a[index], b[index], c[index]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Elementwise ternary kernel for launches that use a 2D grid.
//
// Each thread handles one element. The flat element offset is rebuilt from
// the 2D thread position as `index.y * grid_dim.x + index.x`; the row term is
// widened to size_t before the multiply so the product is computed in size_t
// rather than in 32-bit uint arithmetic.
//
//   a        boolean operand passed first to Op
//   b, c     value operands passed second and third to Op
//   d        output buffer, written at the same flat offset that is read
//   index    this thread's position in the grid
//   grid_dim total threads per grid dimension; only .x is used (row width)
template <typename T, typename Op>
[[kernel]] void ternary_v2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  d[pos] = Op()(a[pos], b[pos], c[pos]);
}
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op>
|
||||||
[[kernel]] void ternary_g_nd1(
|
[[kernel]] void ternary_g_nd1(
|
||||||
device const bool* a,
|
device const bool* a,
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#define instantiate_ternary_all(op, tname, type) \
|
#define instantiate_ternary_all(op, tname, type) \
|
||||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||||
|
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||||
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
|
||||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
|
||||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
|
||||||
|
@ -8,6 +8,16 @@ template <typename T, typename Op>
|
|||||||
out[index] = Op()(in[index]);
|
out[index] = Op()(in[index]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Elementwise unary kernel for launches that use a 2D grid.
//
// Each thread handles one element. The flat element offset is rebuilt from
// the 2D thread position as `index.y * grid_dim.x + index.x`; the row term is
// widened to size_t before the multiply so the product is computed in size_t
// rather than in 32-bit uint arithmetic.
//
//   in       input buffer
//   out      output buffer, written at the same flat offset that is read
//   index    this thread's position in the grid
//   grid_dim total threads per grid dimension; only .x is used (row width)
template <typename T, typename Op>
[[kernel]] void unary_v2(
    device const T* in,
    device T* out,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const size_t row_start = size_t(index.y) * grid_dim.x;
  const size_t pos = row_start + index.x;
  out[pos] = Op()(in[pos]);
}
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op>
|
||||||
[[kernel]] void unary_g(
|
[[kernel]] void unary_g(
|
||||||
device const T* in,
|
device const T* in,
|
||||||
|
@ -5,8 +5,9 @@
|
|||||||
#include "mlx/backend/metal/kernels/unary_ops.h"
|
#include "mlx/backend/metal/kernels/unary_ops.h"
|
||||||
#include "mlx/backend/metal/kernels/unary.h"
|
#include "mlx/backend/metal/kernels/unary.h"
|
||||||
|
|
||||||
#define instantiate_unary_all(op, tname, type) \
|
#define instantiate_unary_all(op, tname, type) \
|
||||||
instantiate_kernel("v" #op #tname, unary_v, type, op) \
|
instantiate_kernel("v" #op #tname, unary_v, type, op) \
|
||||||
|
instantiate_kernel("v2" #op #tname, unary_v2, type, op) \
|
||||||
instantiate_kernel("g" #op #tname, unary_g, type, op)
|
instantiate_kernel("g" #op #tname, unary_g, type, op)
|
||||||
|
|
||||||
#define instantiate_unary_float(op) \
|
#define instantiate_unary_float(op) \
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
@ -32,6 +32,7 @@ void ternary_op_gpu_inplace(
|
|||||||
auto& strides_c = strides[2];
|
auto& strides_c = strides[2];
|
||||||
auto& strides_out = strides[3];
|
auto& strides_out = strides[3];
|
||||||
|
|
||||||
|
// Select the 2D-grid ("v2") kernel only when the element count does not fit
// in a single 32-bit grid dimension. Assigning data_size() directly to a bool
// would make the flag true for every non-empty output.
bool use_2d = out.data_size() > UINT32_MAX;
|
||||||
std::string kernel_name;
|
std::string kernel_name;
|
||||||
{
|
{
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
@ -40,6 +41,8 @@ void ternary_op_gpu_inplace(
|
|||||||
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
|
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
|
||||||
kname << shape.size();
|
kname << shape.size();
|
||||||
}
|
}
|
||||||
|
} else if (use_2d) {
|
||||||
|
kname << "v2";
|
||||||
} else {
|
} else {
|
||||||
kname << "v";
|
kname << "v";
|
||||||
}
|
}
|
||||||
|
@ -25,11 +25,14 @@ void unary_op_gpu_inplace(
|
|||||||
|
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
|
size_t nthreads = contig ? in.data_size() : in.size();
|
||||||
|
bool use_2d = nthreads > UINT32_MAX;
|
||||||
|
std::string kernel_name =
|
||||||
|
(contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out);
|
||||||
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
|
||||||
|
|
||||||
size_t nthreads = contig ? in.data_size() : in.size();
|
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
|
||||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
: MTL::Size(nthreads, 1, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
if (thread_group_size > nthreads) {
|
if (thread_group_size > nthreads) {
|
||||||
thread_group_size = nthreads;
|
thread_group_size = nthreads;
|
||||||
|
@ -104,6 +104,35 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
|||||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes a 2D grid where each element is < UINT_MAX
|
||||||
|
// Assumes:
|
||||||
|
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||||
|
// - shape and strides correspond to a contiguous (no holes) but
|
||||||
|
// possibly broadcasted array
|
||||||
|
MTL::Size get_2d_grid_dims(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
const std::vector<size_t>& strides) {
|
||||||
|
// Dims with strides of 0 are ignored as they
|
||||||
|
// correspond to broadcasted dimensions
|
||||||
|
size_t grid_x = 1;
|
||||||
|
size_t grid_y = 1;
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
if (strides[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (grid_x * shape[i] < UINT32_MAX) {
|
||||||
|
grid_x *= shape[i];
|
||||||
|
} else {
|
||||||
|
grid_y *= shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||||
|
throw std::runtime_error("Unable to safely factor shape.");
|
||||||
|
}
|
||||||
|
return MTL::Size(
|
||||||
|
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||||
|
}
|
||||||
|
|
||||||
inline NS::String* make_string(std::ostringstream& os) {
|
inline NS::String* make_string(std::ostringstream& os) {
|
||||||
std::string string = os.str();
|
std::string string = os.str();
|
||||||
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
|
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
|
||||||
|
@ -273,7 +273,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
// Check for the number of indices passed
|
// Check for the number of indices passed
|
||||||
if (non_none_indices > src.ndim()) {
|
if (non_none_indices > src.ndim()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -585,7 +585,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
|
|
||||||
if (non_none_indices > src.ndim()) {
|
if (non_none_indices > src.ndim()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -840,7 +840,7 @@ auto mlx_slice_update(
|
|||||||
// Dimension check
|
// Dimension check
|
||||||
if (non_none_indices > src.ndim()) {
|
if (non_none_indices > src.ndim()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user