mirror of https://github.com/ml-explore/mlx.git
@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
     const array& a,
     const array& b) {
   if (a.ndim() == 2) {
-    return {{1}, {0}, {0}};
+    return {Shape{1}, Strides{0}, Strides{0}};
   }
 
   Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
 inline std::tuple<Shape, Strides, Strides, Strides>
 collapse_batches(const array& a, const array& b, const array& c) {
   if (a.ndim() == 2) {
-    return {{1}, {0}, {0}, {0}};
+    return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
   }
 
   Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
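Both collapse_batches overloads make the same change: the tuple elements are now constructed by explicit type name instead of bare nested braces. A minimal sketch of the difference, assuming Shape and Strides are the vector aliases used in mlx/core (re-declared locally here so the snippet stands alone):

// Minimal sketch, not MLX's actual headers: Shape and Strides are assumed
// to be the vector aliases used in mlx/core.
#include <cstdint>
#include <tuple>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

std::tuple<Shape, Strides, Strides> implicit_form() {
  // Each braced list is implicitly converted to the matching tuple element;
  // legal, but the element types are invisible at the return site.
  return {{1}, {0}, {0}};
}

std::tuple<Shape, Strides, Strides> explicit_form() {
  // The form the diff switches to: each element is constructed by name,
  // so no nested-brace deduction or implicit conversion is involved.
  return {Shape{1}, Strides{0}, Strides{0}};
}

int main() {
  auto [shape, a_strides, b_strides] = explicit_form();
  return (shape[0] == 1 && a_strides[0] == 0 && b_strides[0] == 0) ? 0 : 1;
}

The explicit form documents the element types where the value is built and sidesteps the initializer-list-to-vector conversions the bare braces rely on.
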
@@ -382,20 +382,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   }
 
   if (op_graph) {
-    // Setup inputs and outputs.
-    register_args(encoder, backend_type, in, wt, out, out_);
-
     // Find a plan for the graph and execute it.
     auto plan = find_cudnn_plan_from_op_graph(
         encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
-    if (!plan) {
-      throw std::runtime_error("[conv] Unable to find an execution plan.");
-    }
-    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-    if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-      conv_cache().emplace(
-          cache_key, std::make_pair(backend_type, std::move(*plan)));
-      return;
+    if (plan) {
+      // Setup inputs and outputs.
+      register_args(encoder, backend_type, in, wt, out, out_);
+
+      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
+      if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
+        conv_cache().emplace(
+            cache_key, std::make_pair(backend_type, std::move(*plan)));
+        return;
+      }
     }
   }
 
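The reordering in this hunk changes the failure mode: the old code threw "[conv] Unable to find an execution plan." when no plan was found, while the new code treats a missing plan like a failed encode_cudnn_plan and falls through past the if (op_graph) block (presumably to a fallback path; the code after the block is outside this hunk). It also defers register_args until a plan actually exists. A minimal sketch of the new control flow, with hypothetical stand-ins (Plan, build_plan, encode_plan, fallback) for the cuDNN helpers:

// Minimal sketch of the revised control flow; Plan, build_plan, encode_plan,
// and fallback are hypothetical stand-ins for the helpers in the diff.
#include <optional>

struct Plan {};

std::optional<Plan> build_plan(bool available) {
  return available ? std::optional<Plan>(Plan{}) : std::nullopt;
}

bool encode_plan(const Plan&) { return true; }

void fallback() {
  // Whatever follows the if (op_graph) block in the real function.
}

void eval(bool available) {
  if (auto plan = build_plan(available)) {
    if (encode_plan(*plan)) {
      return; // Success: the real code also caches the plan here.
    }
  }
  fallback(); // The old code threw on !plan; the new code falls through.
}

int main() {
  eval(true);
  eval(false);
}
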
@@ -210,6 +210,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
     Dtype dtype,
     cudnn_frontend::OperationGraph& op_graph) {
   auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
+  if (engine_configs.empty()) {
+    return std::nullopt;
+  }
   return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
 }
 
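This guard is what lets the caller above receive std::nullopt cleanly: with zero candidate engine configs, the function now reports "no plan" itself instead of handing an empty list to find_cudnn_plan_from_engine_configs. A minimal sketch of the early-exit pattern, with hypothetical Config, Plan, and pick_best stand-ins:

// Minimal sketch of the early-exit guard; Config, Plan, and pick_best are
// hypothetical stand-ins for the cuDNN frontend types in the diff.
#include <optional>
#include <vector>

struct Config {};
struct Plan {};

std::optional<Plan> pick_best(const std::vector<Config>&) { return Plan{}; }

std::optional<Plan> find_plan(const std::vector<Config>& configs) {
  if (configs.empty()) {
    return std::nullopt; // Nothing to choose from: report "no plan" directly.
  }
  return pick_best(configs);
}

int main() {
  return find_plan({}).has_value() ? 1 : 0; // Empty input yields no plan.
}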