mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix
This commit is contained in:
parent
850ad01914
commit
f07eb684a6
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/ternary.h"
|
#include "mlx/backend/common/ternary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||||
@ -80,7 +79,6 @@ void ternary_op_gpu_inplace(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
const auto& b = inputs[1];
|
const auto& b = inputs[1];
|
||||||
const auto& c = inputs[2];
|
const auto& c = inputs[2];
|
||||||
@ -94,7 +92,7 @@ void ternary_op_gpu_inplace(
|
|||||||
encoder.set_input_array(c);
|
encoder.set_input_array(c);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, {
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
|
||||||
using DType = cuda_type_t<CTYPE>;
|
using DType = cuda_type_t<CTYPE>;
|
||||||
|
|
||||||
auto topt = get_ternary_op_type(a, b, c);
|
auto topt = get_ternary_op_type(a, b, c);
|
||||||
@ -110,7 +108,7 @@ void ternary_op_gpu_inplace(
|
|||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
auto kernel = &cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] =
|
||||||
get_launch_args(kernel, out, large);
|
get_launch_args(kernel, out, large);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
|
Loading…
Reference in New Issue
Block a user