// Source path: mlx/mlx/backend/metal/reduce.h
// Snapshot: 2024-11-04 22:25:16 -08:00 (42 lines, 914 B, C++)

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/stream.h"
namespace mlx::core {

// Bring the Metal command encoder into mlx::core so the declarations below
// (and translation units that include this header) can name it unqualified.
using metal::CommandEncoder;

// Dispatches a reduction of `in` into `out` on the Metal backend.
// Judging by the name, this covers the case where every element of `in` is
// reduced (NOTE(review): confirm against the definition).
//
//   in              - input array; taken by const reference, so only read.
//   out             - output array; taken by mutable reference, presumably
//                     receives the reduction result.
//   op_name         - name of the reduction operation (assumed to select the
//                     kernel variant -- TODO confirm).
//   compute_encoder - command encoder, presumably used to record the kernel
//                     launch.
//   d               - Metal device.
//   s               - stream the work is associated with.
void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);

// Dispatches the general row-reduction path on the Metal backend.
// In addition to the arguments of all_reduce_dispatch, it receives the
// ReductionPlan and the list of reduced axes. What the plan encodes is
// defined in mlx/backend/common/reduce.h (not visible here -- confirm there).
//
//   plan - reduction plan describing how the reduction is laid out.
//   axes - axes of `in` being reduced.
void row_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);

// Dispatches the general strided-reduction path on the Metal backend.
// Same parameter list as row_reduce_general_dispatch. Judging by the name,
// it handles reductions over strided (non-row) layouts (NOTE(review):
// confirm against the definition).
void strided_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s);

} // namespace mlx::core