mlx/mlx/backend/metal/reduce.h

40 lines
920 B
C
Raw Normal View History

// Copyright @ 2023 - 2024 Apple Inc.
#pragma once
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/stream.h"
namespace mlx::core {
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s);
void row_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s);
void strided_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s);
} // namespace mlx::core