mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 10:02:12 +08:00
40 lines
920 B
C
40 lines
920 B
C
![]() |
// Copyright © 2023 - 2024 Apple Inc.
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include "mlx/backend/common/reduce.h"
|
||
|
#include "mlx/backend/metal/device.h"
|
||
|
#include "mlx/stream.h"
|
||
|
|
||
|
namespace mlx::core {
|
||
|
|
||
|
void all_reduce_dispatch(
|
||
|
const array& in,
|
||
|
array& out,
|
||
|
const std::string& op_name,
|
||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||
|
metal::Device& d,
|
||
|
const Stream& s);
|
||
|
|
||
|
void row_reduce_general_dispatch(
|
||
|
const array& in,
|
||
|
array& out,
|
||
|
const std::string& op_name,
|
||
|
const ReductionPlan& plan,
|
||
|
const std::vector<int>& axes,
|
||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||
|
metal::Device& d,
|
||
|
const Stream& s);
|
||
|
|
||
|
void strided_reduce_general_dispatch(
|
||
|
const array& in,
|
||
|
array& out,
|
||
|
const std::string& op_name,
|
||
|
const ReductionPlan& plan,
|
||
|
const std::vector<int>& axes,
|
||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||
|
metal::Device& d,
|
||
|
const Stream& s);
|
||
|
|
||
|
} // namespace mlx::core
|