2024-03-27 07:35:34 +08:00
|
|
|
// Copyright @ 2023 - 2024 Apple Inc.
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "mlx/backend/common/reduce.h"
|
|
|
|
#include "mlx/backend/metal/device.h"
|
|
|
|
#include "mlx/stream.h"
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
2024-04-11 12:45:31 +08:00
|
|
|
using metal::CommandEncoder;
|
|
|
|
|
2024-03-27 07:35:34 +08:00
|
|
|
void all_reduce_dispatch(
|
|
|
|
const array& in,
|
|
|
|
array& out,
|
|
|
|
const std::string& op_name,
|
2024-04-11 12:45:31 +08:00
|
|
|
CommandEncoder& compute_encoder,
|
2024-03-27 07:35:34 +08:00
|
|
|
metal::Device& d,
|
|
|
|
const Stream& s);
|
|
|
|
|
|
|
|
void row_reduce_general_dispatch(
|
|
|
|
const array& in,
|
|
|
|
array& out,
|
|
|
|
const std::string& op_name,
|
|
|
|
const ReductionPlan& plan,
|
|
|
|
const std::vector<int>& axes,
|
2024-04-11 12:45:31 +08:00
|
|
|
CommandEncoder& compute_encoder,
|
2024-03-27 07:35:34 +08:00
|
|
|
metal::Device& d,
|
|
|
|
const Stream& s);
|
|
|
|
|
|
|
|
void strided_reduce_general_dispatch(
|
|
|
|
const array& in,
|
|
|
|
array& out,
|
|
|
|
const std::string& op_name,
|
|
|
|
const ReductionPlan& plan,
|
|
|
|
const std::vector<int>& axes,
|
2024-04-11 12:45:31 +08:00
|
|
|
CommandEncoder& compute_encoder,
|
2024-03-27 07:35:34 +08:00
|
|
|
metal::Device& d,
|
|
|
|
const Stream& s);
|
|
|
|
|
|
|
|
} // namespace mlx::core
|