mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Share more common code in Compiled (#2240)
* Share more common code in Compiled * Remove build_lib_name
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant);
|
||||
|
||||
// Return whether the kernel should use large index.
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user