mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
format
This commit is contained in:
parent
14531cb14f
commit
91817a165b
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
|
|||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
|
@ -43,8 +43,8 @@ void copy_contiguous(
|
|||||||
if (ctype == CopyType::Vector) {
|
if (ctype == CopyType::Vector) {
|
||||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
|
@ -57,7 +57,6 @@ struct CastOp<
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Return an iterator that casts the value to DstT using CastOp.
|
// Return an iterator that casts the value to DstT using CastOp.
|
||||||
template <typename DstT, typename Iterator>
|
template <typename DstT, typename Iterator>
|
||||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||||
|
@ -167,7 +167,8 @@ inline std::tuple<dim3, uint> get_launch_args(
|
|||||||
const array& arr,
|
const array& arr,
|
||||||
bool large,
|
bool large,
|
||||||
int work_per_thread = 1) {
|
int work_per_thread = 1) {
|
||||||
return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
return get_launch_args(
|
||||||
|
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
|
|||||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
|
@ -28,8 +28,8 @@ constexpr bool supports_unary_op() {
|
|||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||||
std::is_same_v<Op, Sigmoid> ||
|
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
||||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
std::is_same_v<Op, Rsqrt>) {
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||||
|
Loading…
Reference in New Issue
Block a user