Remove backend APIs

Author: Cheng
Date: 2025-07-17 23:46:09 +00:00
parent ad44c4bcd9
commit 6571df6ad7
4 changed files with 57 additions and 145 deletions

View File

@@ -212,7 +212,7 @@ jobs:
       name: Install Python package
       command: |
         sudo apt-get update
-        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev cudnn9-cuda-12
+        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
         python3 -m venv env
         source env/bin/activate
         CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \

View File

@@ -14,30 +14,17 @@
 #include <cassert>
 #include <numeric>
 
-// cudnn_frontend.h redefines this macro.
-#undef CHECK_CUDNN_ERROR
-#undef CHECK_CUDNN_FRONTEND_ERROR
-
 namespace mlx::core {
 
 namespace cu {
 
 using namespace cudnn_frontend;
 
-#define CHECK_CUDNN_FRONTEND_ERROR(cmd)                        \
+#define CHECK_CUDNN_FE_ERROR(cmd)                              \
   if (cmd.is_bad()) {                                          \
     throw std::runtime_error(fmt::format("{} failed.", #cmd)); \
   }
 
-#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
-
-void check_cudnn_error(const char* name, cudnnStatus_t err) {
-  if (err != CUDNN_STATUS_SUCCESS) {
-    throw std::runtime_error(
-        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
-  }
-}
-
 auto swapaxes(const array& in, int axis1, int axis2) {
   std::vector<int> axes(in.ndim());
   std::iota(axes.begin(), axes.end(), 0);
@@ -63,7 +50,8 @@ class Convolution {
       const std::vector<int64_t>& output_shape,
       const std::vector<int64_t>& output_strides,
       const std::vector<int64_t>& stride,
-      const std::vector<int64_t>& padding,
+      const std::vector<int64_t>& padding_lo,
+      const std::vector<int64_t>& padding_hi,
       const std::vector<int64_t>& dilation,
       int groups)
       : handle_(device.cudnn_handle()) {
@@ -73,119 +61,31 @@ class Convolution {
     graph_.set_io_data_type(cudnn_type)
         .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
     input_attr_ = graph_.tensor(graph::Tensor_attributes()
-                                    .set_name("input")
                                     .set_dim(input_shape)
                                     .set_stride(input_strides));
     filter_attr_ = graph_.tensor(graph::Tensor_attributes()
-                                     .set_name("filter")
                                      .set_dim(filter_shape)
                                      .set_stride(filter_strides));
     auto conv_options = graph::Conv_fprop_attributes()
-                            .set_padding(padding)
+                            .set_pre_padding(padding_lo)
+                            .set_post_padding(padding_hi)
                             .set_stride(stride)
                             .set_dilation(dilation);
     output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
     output_attr_->set_output(true);
+    output_attr_->set_data_type(cudnn_type);
+    output_attr_->set_dim(output_shape);
+    output_attr_->set_stride(output_strides);
 
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.validate());
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(handle_));
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.create_execution_plans({HeurMode_t::A}));
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.check_support(handle_));
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.build_plans(handle_));
-
-#if 0
-    int ndim = input_shape.size();
-    CHECK_CUDNN_ERROR(cudnnCreateTensorDescriptor(&input_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
-        input_desc_,
-        cudnn_type,
-        ndim,
-        input_shape.data(),
-        input_strides.data()));
-    CHECK_CUDNN_ERROR(cudnnCreateFilterDescriptor(&filter_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
-        filter_desc_,
-        cudnn_type,
-        CUDNN_TENSOR_NCHW,
-        ndim,
-        filter_shape.data()));
-    CHECK_CUDNN_ERROR(cudnnCreateTensorDescriptor(&output_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
-        output_desc_,
-        cudnn_type,
-        ndim,
-        output_shape.data(),
-        output_strides.data()));
-    CHECK_CUDNN_ERROR(cudnnCreateConvolutionDescriptor(&conv_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetConvolutionGroupCount(conv_desc_, groups));
-    CHECK_CUDNN_ERROR(cudnnSetConvolutionNdDescriptor(
-        conv_desc_,
-        ndim - 2,
-        padding.data(),
-        stride.data(),
-        dilation.data(),
-        CUDNN_CROSS_CORRELATION,
-        is_half ? CUDNN_DATA_FLOAT : cudnn_type));
-    if (is_half) {
-      CHECK_CUDNN_ERROR(
-          cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
-    } else if (dtype == float32) {
-      CHECK_CUDNN_ERROR(
-          cudnnSetConvolutionMathType(conv_desc_, CUDNN_FMA_MATH));
-    } else {
-      CHECK_CUDNN_ERROR(
-          cudnnSetConvolutionMathType(conv_desc_, CUDNN_DEFAULT_MATH));
-    }
-
-    std::vector<int> expected_output_shape(ndim);
-    CHECK_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
-        conv_desc_,
-        input_desc_,
-        filter_desc_,
-        ndim,
-        expected_output_shape.data()));
-    std::cout << "expected_output_shape: " << expected_output_shape
-              << std::endl;
-
-    cudnnConvolutionFwdAlgoPerf_t results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
-    int count;
-    CHECK_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
-        handle_,
-        input_desc_,
-        filter_desc_,
-        conv_desc_,
-        output_desc_,
-        std::size(results),
-        &count,
-        results));
-    for (int i = 0; i < count; ++i) {
-      if (results[i].status == CUDNN_STATUS_SUCCESS) {
-        algo_ = results[i].algo;
-        std::cout << "Found algorithm" << std::endl;
-        break;
-      }
-    }
-
-    CHECK_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
-        handle_,
-        input_desc_,
-        filter_desc_,
-        conv_desc_,
-        output_desc_,
-        algo_,
-        &workspace_size_));
-#endif
-  }
-
-  ~Convolution() {
-#if 0
-    cudnnDestroyTensorDescriptor(input_desc_);
-    cudnnDestroyFilterDescriptor(filter_desc_);
-    cudnnDestroyTensorDescriptor(output_desc_);
-    cudnnDestroyConvolutionDescriptor(conv_desc_);
-#endif
+    CHECK_CUDNN_FE_ERROR(graph_.validate());
+    CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
+    CHECK_CUDNN_FE_ERROR(graph_.create_execution_plans({HeurMode_t::A}));
+    CHECK_CUDNN_FE_ERROR(graph_.check_support(handle_));
+    CHECK_CUDNN_FE_ERROR(graph_.build_plans(handle_));
+    CHECK_CUDNN_FE_ERROR(graph_.get_workspace_size(workspace_size_));
   }
 
   void run(
@@ -208,25 +108,8 @@ class Convolution {
         {output_attr_->get_uid(), output}};
 
     auto capture = encoder.capture_context();
-    CHECK_CUDNN_ERROR(cudnnSetStream(handle_, encoder.stream()));
-    CHECK_CUDNN_FRONTEND_ERROR(
+    CHECK_CUDNN_FE_ERROR(
         graph_.execute(handle_, ptr_map, workspace.data<void>()));
-
-#if 0
-    CHECK_CUDNN_ERROR(cudnnConvolutionForward(
-        handle_,
-        &alpha,
-        input_desc_,
-        input,
-        filter_desc_,
-        filter,
-        conv_desc_,
-        algo_,
-        workspace.data<void>(),
-        workspace_size_,
-        &beta,
-        output_desc_,
-        output));
-#endif
   }
 
  private:
@@ -264,7 +147,7 @@ class Convolution {
   std::shared_ptr<graph::Tensor_attributes> input_attr_;
   std::shared_ptr<graph::Tensor_attributes> filter_attr_;
   std::shared_ptr<graph::Tensor_attributes> output_attr_;
-  size_t workspace_size_{0};
+  int64_t workspace_size_{0};
 };
 
 } // namespace cu
@@ -295,6 +178,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       output_strides,
       std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
       std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
+      std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
       std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
       groups_);
   conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
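
Note: with the legacy cuDNN "backend" descriptor API deleted, this file relies entirely on the cuDNN frontend graph API. The sketch below is illustrative only, not code from this repo: it condenses the same build/execute sequence into one free function, where handle, x_ptr, w_ptr, y_ptr, and workspace_ptr are placeholder names and the NCHW dims/strides are example values.

#include <cudnn_frontend.h>

#include <cstdint>
#include <stdexcept>
#include <unordered_map>

namespace fe = cudnn_frontend;

// Throw if a cudnn_frontend call reports an error (same idea as
// CHECK_CUDNN_FE_ERROR above).
#define CHECK_FE(cmd)                         \
  if ((cmd).is_bad()) {                       \
    throw std::runtime_error(#cmd " failed"); \
  }

// Build a forward-convolution graph and run it once. workspace_ptr must
// point to at least the number of bytes reported by get_workspace_size().
void conv_fprop_sketch(
    cudnnHandle_t handle,
    void* x_ptr,
    void* w_ptr,
    void* y_ptr,
    void* workspace_ptr) {
  fe::graph::Graph graph;
  graph.set_io_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);

  // Input and filter described by dims and strides (example values).
  auto x = graph.tensor(fe::graph::Tensor_attributes()
                            .set_dim({1, 2, 2, 2})
                            .set_stride({8, 4, 2, 1}));
  auto w = graph.tensor(fe::graph::Tensor_attributes()
                            .set_dim({4, 2, 2, 2})
                            .set_stride({8, 4, 2, 1}));

  // Asymmetric padding is expressed as separate pre/post padding; this is
  // what the new padding_lo/padding_hi constructor arguments feed.
  auto y = graph.conv_fprop(
      x,
      w,
      fe::graph::Conv_fprop_attributes()
          .set_pre_padding({0, 0})
          .set_post_padding({0, 0})
          .set_stride({1, 1})
          .set_dilation({1, 1}));
  y->set_output(true);
  y->set_dim({1, 4, 1, 1});
  y->set_stride({4, 1, 1, 1});

  // The staged frontend pipeline, in the same order as the constructor.
  CHECK_FE(graph.validate());
  CHECK_FE(graph.build_operation_graph(handle));
  CHECK_FE(graph.create_execution_plans({fe::HeurMode_t::A}));
  CHECK_FE(graph.check_support(handle));
  CHECK_FE(graph.build_plans(handle));

  int64_t workspace_size = 0;
  CHECK_FE(graph.get_workspace_size(workspace_size));

  // Bind device buffers by tensor uid and launch on the handle's stream.
  std::unordered_map<int64_t, void*> ptr_map = {
      {x->get_uid(), x_ptr}, {w->get_uid(), w_ptr}, {y->get_uid(), y_ptr}};
  CHECK_FE(graph.execute(handle, ptr_map, workspace_ptr));
}

Once the plans are built they can be reused; only the uid-to-pointer bindings change per call, which matches how Convolution::run above merely assembles ptr_map and calls execute.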

View File

@@ -9,12 +9,23 @@
 #include <future>
 #include <unordered_set>
 
-namespace mlx::core {
+namespace mlx::core::cu {
+
+namespace {
 
 // Can be tuned with MLX_MAX_OPS_PER_BUFFER
 // This should be less than 255
 constexpr int default_max_nodes_per_graph = 20;
 
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 int cuda_graph_cache_size() {
   static int cache_size = []() {
     return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
@@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
   return cache_size;
 }
 
-namespace cu {
+} // namespace
 
 Device::Device(int device) : device_(device) {
   CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
@@ -42,11 +53,11 @@ Device::Device(int device) : device_(device) {
   make_current();
   cublasLtCreate(&lt_);
   // The cudnn handle is used by Convolution.
-  cudnnCreate(&cudnn_);
+  CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
 }
 
 Device::~Device() {
-  cudnnDestroy(cudnn_);
+  CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
   cublasLtDestroy(lt_);
 }
@@ -180,6 +191,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 
 CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
   CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  CHECK_CUDNN_ERROR(cudnnSetStream(d.cudnn_handle(), stream()));
 }
 
 void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
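
Note: this hunk, together with the run() change in conv.cpp above, binds the cuDNN handle to the encoder's CUDA stream once at CommandEncoder construction, replacing the cudnnSetStream call that was previously issued before every graph execution.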
@@ -334,6 +346,4 @@ CommandEncoder& get_command_encoder(Stream s) {
   return device(s.device).get_command_encoder(s);
 }
 
-} // namespace cu
-
-} // namespace mlx::core
+} // namespace mlx::core::cu
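
Side note on the namespace change: `namespace mlx::core::cu {` is the C++17 nested-namespace definition, so a single block replaces the previous two-level nesting, and the new anonymous namespace gives the cuDNN error helper and the tunables internal linkage. A minimal illustration of the pattern (placeholder function name):

namespace mlx::core::cu {

namespace {
// Internal linkage: visible only within this translation unit.
constexpr int default_max_nodes_per_graph = 20;
} // namespace

int max_nodes_per_graph() {
  return default_max_nodes_per_graph; // internal helpers stay usable here
}

} // namespace mlx::core::cu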

View File

@@ -3688,7 +3688,17 @@ TEST_CASE("test conv1d") {
 }
 
 TEST_CASE("test conv2d") {
-  array in = zeros({1, 2, 2, 3}, float32);
+  auto in = array(
+      {0.57429284,
+       -0.21628855,
+       -0.18673691,
+       -0.3793517,
+       0.3059678,
+       -0.8137168,
+       0.6168841,
+       -0.26912728},
+      {1, 2, 2, 2});
 
   std::pair<int, int> kernel{2, 2};
   std::pair<int, int> stride{1, 1};
@@ -3697,7 +3707,15 @@ TEST_CASE("test conv2d") {
 
   {
     int groups = 1;
-    array wt = ones({1, 2, 2, 3}, float32);
+    auto wt = array(
+        {0.3190391,   -0.24937038, 1.4621079,   -2.0601406,  -0.3224172,
+         -0.38405436, 1.1337694,   -1.0998913,  -0.1724282,  -0.8778584,
+         0.04221375,  0.58281523,  -1.1006192,  1.1447237,   0.9015907,
+         0.50249434,  0.90085596,  -0.68372786, -0.12289023, -0.93576944,
+         -0.26788807, 0.53035545,  -0.69166076, -0.39675352, -0.6871727,
+         -0.84520566, -0.6712461,  -0.0126646,  -1.1173104,  0.2344157,
+         1.6598022,   0.74204415},
+        {4, 2, 2, 2});
     auto expected =
         array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
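
The hard-coded expected values can be reproduced outside the framework: with a {1, 2, 2, 2} input, a {4, 2, 2, 2} filter, stride 1 and no padding, there is a single 2x2 output position, so each output channel is the dot product of the flattened input with the corresponding flattened filter. A throwaway reference check (plain C++, not part of the test suite):

#include <array>
#include <cstdio>

// Each output channel o is dot(in, wt[o]) over the flattened H*W*C = 8
// elements; prints approximately {1.9549234, -0.98542136, 0.2097499,
// 0.20991313}, matching `expected` in the test above.
int main() {
  const std::array<float, 8> in = {
      0.57429284f, -0.21628855f, -0.18673691f, -0.3793517f,
      0.3059678f,  -0.8137168f,  0.6168841f,   -0.26912728f};
  const std::array<float, 32> wt = {
      0.3190391f,   -0.24937038f, 1.4621079f,   -2.0601406f,  -0.3224172f,
      -0.38405436f, 1.1337694f,   -1.0998913f,  -0.1724282f,  -0.8778584f,
      0.04221375f,  0.58281523f,  -1.1006192f,  1.1447237f,   0.9015907f,
      0.50249434f,  0.90085596f,  -0.68372786f, -0.12289023f, -0.93576944f,
      -0.26788807f, 0.53035545f,  -0.69166076f, -0.39675352f, -0.6871727f,
      -0.84520566f, -0.6712461f,  -0.0126646f,  -1.1173104f,  0.2344157f,
      1.6598022f,   0.74204415f};
  for (int o = 0; o < 4; ++o) {
    float acc = 0.0f;
    for (int i = 0; i < 8; ++i) {
      acc += in[i] * wt[o * 8 + i]; // filter o, flattened HWC layout
    }
    std::printf("out[%d] = %.7f\n", o, acc);
  }
  return 0;
}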