// File: mlx/mlx/transforms_impl.h
// Last commit: db487e6b1a "format" by Awni Hannun (2023-11-30 11:50:50 -08:00)
// Listing: 18 lines, 542 B, C++

// Copyright © 2023 Apple Inc.
#pragma once

// Standard headers for the std:: types named in the declarations below, so
// this header does not depend on what its includer happened to include first.
#include <functional>
#include <utility>
#include <vector>

// Internal declarations backing the vmap transformation.
//
// NOTE(review): `array` is not declared in this header; it presumably comes
// from mlx/array.h, which the includer must pull in before this file — confirm.
namespace mlx::core::detail {

/**
 * @brief Trace `fun` applied to `inputs` for use by vmap.
 *
 * @param fun      Callable mapping a list of arrays to a list of arrays.
 * @param inputs   Arrays to trace `fun` on.
 * @param in_axes  Per-input axis information; presumably one entry per element
 *                 of `inputs` — confirm against the definition.
 * @return A pair of array lists; presumably the (s_inputs, s_outputs) pair
 *         that vmap_replace consumes — confirm against the definition.
 */
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& inputs,
    const std::vector<int>& in_axes);

/**
 * @brief Produce vmapped outputs from a previously traced graph.
 *
 * @param inputs     The actual input arrays.
 * @param s_inputs   Traced inputs; presumably the first element returned by
 *                   vmap_trace — confirm.
 * @param s_outputs  Traced outputs; presumably the second element returned by
 *                   vmap_trace — confirm.
 * @param in_axes    Per-input axis information, as passed to vmap_trace.
 * @param out_axes   Per-output axis information; presumably one entry per
 *                   element of `s_outputs` — confirm.
 * @return The resulting list of arrays.
 */
std::vector<array> vmap_replace(
    const std::vector<array>& inputs,
    const std::vector<array>& s_inputs,
    const std::vector<array>& s_outputs,
    const std::vector<int>& in_axes,
    const std::vector<int>& out_axes);

} // namespace mlx::core::detail