mirror of https://github.com/ml-explore/mlx.git

Commit: tmp
		| @@ -90,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") | ||||
| target_compile_options(mlx | ||||
|                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>") | ||||
|  | ||||
| # Keep ptx around for inspection | ||||
| target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>") | ||||
|  | ||||
| # Enable calling host constexpr functions from device. This is needed because | ||||
| # the constexpr version of isnan is host only. | ||||
| target_compile_options( | ||||
|   | ||||
| @@ -5,8 +5,29 @@ | ||||
| #include "mlx/backend/cuda/steel/gemm.cuh" | ||||
| #include "mlx/dtype_utils.h" | ||||
|  | ||||
| #include <iostream> | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename Kernel> | ||||
| static void configure_smem(Kernel kernel, int SM) { | ||||
|   static bool done = false; | ||||
|   if (done) { | ||||
|     return; | ||||
|   } | ||||
|   std::cout << "configuring" << std::endl; | ||||
|   cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM); | ||||
|   cudaFuncSetAttribute( | ||||
|       kernel, | ||||
|       cudaFuncAttributePreferredSharedMemoryCarveout, | ||||
|       cudaSharedmemCarveoutMaxShared); | ||||
|   done = true; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| void simple_gemm( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
| @@ -23,17 +44,20 @@ void simple_gemm( | ||||
|     constexpr int BM = 128; | ||||
|     constexpr int BN = 128; | ||||
|     constexpr int BK = 32; | ||||
|     constexpr int PIPE = 3; | ||||
|     constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK); | ||||
|     constexpr int WM = 2; | ||||
|     constexpr int WN = 4; | ||||
|  | ||||
|     auto kernel = ab_t_aligned<DataType, BM, BN, BK>; | ||||
|     cudaFuncSetAttribute( | ||||
|         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); | ||||
|     auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>; | ||||
|     configure_smem(kernel, SM); | ||||
|  | ||||
|     dim3 grid(N / BN, M / BM); | ||||
|     enc.add_kernel_node( | ||||
|         kernel, | ||||
|         grid, | ||||
|         8 * WARP_SIZE, | ||||
|         4 * sizeof(DataType) * (BM * BK + BN * BK), | ||||
|         WM * WN * WARP_SIZE, | ||||
|         SM, | ||||
|         a.data<DataType>(), | ||||
|         b.data<DataType>(), | ||||
|         out.data<DataType>(), | ||||
|   | ||||
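The new `configure_smem` helper raises the kernel's dynamic shared-memory limit once per process. With `BM = BN = 128`, `BK = 32`, `PIPE = 3` and a 2-byte element type, the pipelined A and B tiles come to `3 * 2 * (128 * 32 + 128 * 32)` bytes = 48 KB per block, and any deeper pipeline or larger tile would exceed the default 48 KB cap unless `cudaFuncAttributeMaxDynamicSharedMemorySize` is raised first. A minimal sketch of the same opt-in pattern, with hypothetical kernel and launcher names:

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

constexpr int BM = 128, BN = 128, BK = 32, PIPE = 3;
constexpr int SMEM_BYTES = PIPE * sizeof(__half) * (BM * BK + BN * BK);

__global__ void gemm_like_kernel(const __half* a, const __half* b, __half* y,
                                 int N, int K) {
  extern __shared__ char shmem[];  // sized by the launch: PIPE pairs of tiles
  // ... tile loads and MMAs would go here; omitted in this sketch.
  (void)shmem; (void)a; (void)b; (void)y; (void)N; (void)K;
}

void launch(const __half* a, const __half* b, __half* y, int M, int N, int K) {
  // Done once; launches requesting more than the default 48 KB of dynamic
  // shared memory fail unless this attribute is raised beforehand.
  cudaFuncSetAttribute(gemm_like_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_BYTES);
  cudaFuncSetAttribute(gemm_like_kernel,
                       cudaFuncAttributePreferredSharedMemoryCarveout,
                       cudaSharedmemCarveoutMaxShared);
  dim3 grid(N / BN, M / BM);
  gemm_like_kernel<<<grid, 8 * 32, SMEM_BYTES>>>(a, b, y, N, K);
}
```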
| @@ -16,6 +16,11 @@ namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| int get_test_gemm() { | ||||
|   static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0); | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| std::tuple<bool, int64_t, array> | ||||
| check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { | ||||
|   auto stx = arr.strides()[arr.ndim() - 2]; | ||||
| @@ -99,15 +104,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   } | ||||
|  | ||||
|   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed && | ||||
|       b_transposed && batch_count == 1 && | ||||
|       env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) { | ||||
|       b_transposed && batch_count == 1 && get_test_gemm() == 1) { | ||||
|     cu::simple_gemm(a, b, out, M, N, K, encoder); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed && | ||||
|       b_transposed && batch_count == 1 && | ||||
|       env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 2) { | ||||
|       b_transposed && batch_count == 1 && get_test_gemm() == 2) { | ||||
|     cu::cutlass_gemm(a, b, out, M, N, K, encoder); | ||||
|     return; | ||||
|   } | ||||
|   | ||||
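Moving the `MLX_ENABLE_TEST_GEMM` lookup into `get_test_gemm()` caches the value in a function-local static, so the environment is queried once rather than on every matmul dispatch. A rough standard-library-only equivalent of the pattern (the real code goes through mlx's `env::get_var` helper):

```cpp
#include <cstdlib>

// Sketch: read an integer environment variable once and cache it for the
// lifetime of the process. The static initializer runs on first call only
// and is thread-safe since C++11.
static int get_test_gemm_cached() {
  static const int value = [] {
    const char* v = std::getenv("MLX_ENABLE_TEST_GEMM");
    return v ? std::atoi(v) : 0;
  }();
  return value;
}
```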
| @@ -8,20 +8,19 @@ template <typename T, int BM, int BN, int BK, int WM, int WN> | ||||
| __device__ inline void gemm_ab_t( | ||||
|     RegisterTile<float, BM / WM, BN / WN>& C, | ||||
|     SharedTile<T, BM, BK>& As, | ||||
|     SharedTile<T, BM, BK>& Bs, | ||||
|     int lane_row_a, | ||||
|     int lane_row_b, | ||||
|     int lane_col) { | ||||
|     SharedTile<T, BN, BK>& Bs, | ||||
|     RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a, | ||||
|     RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) { | ||||
|   RegisterTile<T, BM / WM, 16> A[2]; | ||||
|   RegisterTile<T, BN / WN, 16> B[2]; | ||||
|  | ||||
|   A[0].load(As, As.base_addr(), lane_row_a, lane_col); | ||||
|   B[0].load(Bs, Bs.base_addr(), lane_row_b, lane_col); | ||||
|   rloader_a.load(A[0], As.base_addr(), 0); | ||||
|   rloader_b.load(B[0], Bs.base_addr(), 0); | ||||
|  | ||||
|   MLX_UNROLL | ||||
|   for (int k = 1; k < BK / 16; k++) { | ||||
|     A[k & 1].load(As, As.base_addr(), lane_row_a, lane_col + k * 16); | ||||
|     B[k & 1].load(Bs, Bs.base_addr(), lane_row_b, lane_col + k * 16); | ||||
|     rloader_a.load(A[k & 1], As.base_addr(), k); | ||||
|     rloader_b.load(B[k & 1], Bs.base_addr(), k); | ||||
|  | ||||
|     mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]); | ||||
|   } | ||||
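`gemm_ab_t` now receives prebuilt `RegisterTileLoader`s instead of raw lane coordinates, but the ping-pong structure is unchanged: two register fragments per operand, with the load of K-slice `k` overlapping the MMA on slice `k - 1`. A self-contained toy kernel showing just that double-buffer indexing (scalar multiply-adds stand in for the tensor-core `mma_t`, and the drain of the final slice is included for completeness):

```cpp
#include <cuda_runtime.h>

constexpr int BK = 32;    // K extent of one shared-memory block
constexpr int SLICE = 16; // K extent consumed per MMA step

// Toy ping-pong: the A[2]/B[2] double buffer lets the fetch of slice k hide
// behind the math on slice k - 1; the last slice is drained after the loop.
__global__ void pingpong_sketch(const float* as, const float* bs, float* c) {
  float A[2], B[2], acc = 0.f;
  A[0] = as[0];
  B[0] = bs[0];
  for (int k = 1; k < BK / SLICE; k++) {
    A[k & 1] = as[k];                        // prefetch slice k
    B[k & 1] = bs[k];
    acc += A[(k - 1) & 1] * B[(k - 1) & 1];  // compute on slice k - 1
  }
  acc += A[(BK / SLICE - 1) & 1] * B[(BK / SLICE - 1) & 1];
  c[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}
```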
| @@ -33,25 +32,91 @@ __device__ inline void gemm_ab_t( | ||||
|  * | ||||
|  * Computes A @ B.T when A and B are all aligned with the block sizes. | ||||
|  */ | ||||
| template <typename T, int BM, int BN, int BK> | ||||
| __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|   constexpr int WARPS_M = 4; | ||||
|   constexpr int WARPS_N = 2; | ||||
|   constexpr int NUM_WARPS = WARPS_M * WARPS_N; | ||||
|   constexpr int WARP_STEP_M = BM / WARPS_M; | ||||
|   constexpr int WARP_STEP_N = BN / WARPS_N; | ||||
|   constexpr int PIPE = 4; | ||||
| // template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE> | ||||
| //__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1) | ||||
| // void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
| //   constexpr int NUM_WARPS = WM * WN; | ||||
| //   constexpr int WARP_STEP_M = BM / WM; | ||||
| //   constexpr int WARP_STEP_N = BN / WN; | ||||
| // | ||||
| //   // Precompute some offsets for each thread | ||||
| //   const int warpid = threadIdx.x / 32; | ||||
| //   const int laneid = threadIdx.x % 32; | ||||
| //   const int wm = warpid / WN; | ||||
| //   const int wn = warpid % WN; | ||||
| //   const int offset_m = wm * WARP_STEP_M; | ||||
| //   const int offset_n = wn * WARP_STEP_N; | ||||
| // | ||||
| //   // Allocate shared memory | ||||
| //   extern __shared__ char shmem[]; | ||||
| //   SharedTile<T, BM, BK>(&as)[PIPE] = | ||||
| //       *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]); | ||||
| //   SharedTile<T, BN, BK>(&bs)[PIPE] = | ||||
| //       *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]); | ||||
| // | ||||
| //   // Move the global pointers to the tile | ||||
| //   a += blockIdx.y * BM * K; | ||||
| //   b += blockIdx.x * BN * K; | ||||
| //   y += blockIdx.y * BM * N + blockIdx.x * BN; | ||||
| // | ||||
| //   // Make the loaders to/from SMEM | ||||
| //   SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K); | ||||
| //   SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K); | ||||
| //   RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid); | ||||
| //   RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid); | ||||
| // | ||||
| //   // Start the SM pipeline | ||||
| //   MLX_UNROLL | ||||
| //   for (int i = 0; i < PIPE - 1; i++) { | ||||
| //     sloader_a.load_async(as[i].base_addr()); | ||||
| //     sloader_b.load_async(bs[i].base_addr()); | ||||
| //     cp_async_commit(); | ||||
| //     sloader_a.next(); | ||||
| //     sloader_b.next(); | ||||
| //   } | ||||
| // | ||||
| //   // Allocate and zero the MMA accumulator | ||||
| //   RegisterTile<float, BM / WM, BN / WN> C; | ||||
| //   C.fill(0); | ||||
| // | ||||
| //   // Matmul loop | ||||
| //   int num_blocks = K / BK; | ||||
| //   int sread = 0; | ||||
| //   int swrite = PIPE - 1; | ||||
| //   for (int i = 0; i < num_blocks; i++) { | ||||
| //     cp_async_wait<PIPE - 1>(); | ||||
| // | ||||
| //     gemm_ab_t<T, BM, BN, BK, WM, WN>( | ||||
| //         C, as[sread], bs[sread], rloader_a, rloader_b); | ||||
| // | ||||
| //     sloader_a.load_async(as[swrite].base_addr()); | ||||
| //     sloader_b.load_async(bs[swrite].base_addr()); | ||||
| //     cp_async_commit(); | ||||
| //     sloader_a.next(i + PIPE < num_blocks); | ||||
| //     sloader_b.next(i + PIPE < num_blocks); | ||||
| // | ||||
| //     swrite = sread; | ||||
| //     sread = (sread + 1) % PIPE; | ||||
| //   } | ||||
| // | ||||
| //   C.store_global(y, N, offset_m, offset_n); | ||||
| // } | ||||
|  | ||||
| template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE> | ||||
| __global__ __launch_bounds__( | ||||
|     WM* WN* WARP_SIZE, | ||||
|     1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|   constexpr int NUM_WARPS = WM * WN; | ||||
|   constexpr int WARP_STEP_M = BM / WM; | ||||
|   constexpr int WARP_STEP_N = BN / WN; | ||||
|  | ||||
|   // Precompute some offsets for each thread | ||||
|   const int warpid = threadIdx.x / 32; | ||||
|   const int laneid = threadIdx.x % 32; | ||||
|   const int wm = warpid / WARPS_N; | ||||
|   const int wn = warpid % WARPS_N; | ||||
|   const int wm = warpid / WN; | ||||
|   const int wn = warpid % WN; | ||||
|   const int offset_m = wm * WARP_STEP_M; | ||||
|   const int offset_n = wn * WARP_STEP_N; | ||||
|   const int lane_row_a = offset_m + (laneid & 15); | ||||
|   const int lane_row_b = offset_n + (laneid & 15); | ||||
|   const int lane_col = (laneid >> 4) << 3; | ||||
|  | ||||
|   // Allocate shared memory | ||||
|   extern __shared__ char shmem[]; | ||||
| @@ -65,34 +130,59 @@ __global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { | ||||
|   b += blockIdx.x * BN * K; | ||||
|   y += blockIdx.y * BM * N + blockIdx.x * BN; | ||||
|  | ||||
|   // Make the loaders to/from SMEM | ||||
|   using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>; | ||||
|   constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK; | ||||
|   const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW; | ||||
|   const int scol = | ||||
|       (threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD; | ||||
|   a += srow * K + scol; | ||||
|   b += srow * K + scol; | ||||
|   uint32_t sm_offsets[PIPE][2]; | ||||
|   MLX_UNROLL | ||||
|   for (int s = 0; s < PIPE; s++) { | ||||
|     sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol); | ||||
|     sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol); | ||||
|   } | ||||
|   RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid); | ||||
|   RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid); | ||||
|  | ||||
|   // Start the SM pipeline | ||||
|   MLX_UNROLL | ||||
|   for (int i = 0; i < PIPE - 1; i++) { | ||||
|     load_async<NUM_WARPS>(as[i], as[i].base_addr(), a + i * BK, K); | ||||
|     load_async<NUM_WARPS>(bs[i], bs[i].base_addr(), b + i * BK, K); | ||||
|   for (int s = 0; s < PIPE - 1; s++) { | ||||
|     MLX_UNROLL | ||||
|     for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) { | ||||
|       cp_async<16>(sm_offsets[s][0] + l * SSTEP, a); | ||||
|       cp_async<16>(sm_offsets[s][1] + l * SSTEP, b); | ||||
|       a += sloader::STEP_ROWS * K; | ||||
|       b += sloader::STEP_ROWS * K; | ||||
|     } | ||||
|     cp_async_commit(); | ||||
|   } | ||||
|  | ||||
|   // Allocate and zero the MMA accumulator | ||||
|   RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C; | ||||
|   RegisterTile<float, BM / WM, BN / WN> C; | ||||
|   C.fill(0); | ||||
|  | ||||
|   // Matmul loop | ||||
|   int num_blocks = K / BK; | ||||
|   int k_block = (PIPE - 1) * BK; | ||||
|   int sread = 0; | ||||
|   int swrite = PIPE - 1; | ||||
|   for (int i = 0; i < num_blocks; i++) { | ||||
|     cp_async_wait<PIPE - 2>(); | ||||
|     cp_async_wait<PIPE - 1>(); | ||||
|  | ||||
|     if (k_block < K) { | ||||
|       load_async<NUM_WARPS>(as[swrite], as[swrite].base_addr(), a + k_block, K); | ||||
|       load_async<NUM_WARPS>(bs[swrite], bs[swrite].base_addr(), b + k_block, K); | ||||
|     gemm_ab_t<T, BM, BN, BK, WM, WN>( | ||||
|         C, as[sread], bs[sread], rloader_a, rloader_b); | ||||
|  | ||||
|     if (false) { | ||||
|       MLX_UNROLL | ||||
|       for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) { | ||||
|         cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a); | ||||
|         cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b); | ||||
|         a += sloader::STEP_ROWS * K; | ||||
|         b += sloader::STEP_ROWS * K; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     gemm_ab_t<T, BM, BN, BK, WARPS_M, WARPS_N>( | ||||
|         C, as[sread], bs[sread], lane_row_a, lane_row_b, lane_col); | ||||
|  | ||||
|     cp_async_commit(); | ||||
|  | ||||
|     swrite = sread; | ||||
|   | ||||
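The restructured main loop is built around `PIPE` shared-memory stages: `PIPE - 1` stages are filled before the loop, and in the steady state each iteration waits for the oldest outstanding copy, runs the block GEMM out of that stage, and issues a copy for a later iteration into the stage it has just freed. A self-contained sketch of the same stage rotation using the CUDA pipeline intrinsics (a per-thread reduction stands in for the tile GEMM; tile size and names are illustrative):

```cpp
#include <cuda_pipeline.h>
#include <cuda_runtime.h>

constexpr int PIPE = 3;   // number of shared-memory stages
constexpr int TILE = 256; // floats per stage; launch with TILE threads and
                          // PIPE * TILE * sizeof(float) bytes of dynamic smem

__global__ void pipelined_sum(const float* __restrict__ x, float* out,
                              int num_tiles) {
  extern __shared__ float smem[]; // PIPE stages of TILE floats
  float acc = 0.f;

  // Prime the pipeline with the first PIPE - 1 tiles.
  for (int s = 0; s < PIPE - 1; s++) {
    __pipeline_memcpy_async(&smem[s * TILE + threadIdx.x],
                            &x[s * TILE + threadIdx.x], sizeof(float));
    __pipeline_commit();
  }

  int sread = 0, swrite = PIPE - 1;
  for (int i = 0; i < num_tiles; i++) {
    __pipeline_wait_prior(PIPE - 2); // oldest stage is now resident
    __syncthreads();                 // make it visible to the whole block

    acc += smem[sread * TILE + threadIdx.x]; // "compute" on stage sread

    if (i + PIPE - 1 < num_tiles) {  // refill for iteration i + PIPE - 1
      __pipeline_memcpy_async(&smem[swrite * TILE + threadIdx.x],
                              &x[(i + PIPE - 1) * TILE + threadIdx.x],
                              sizeof(float));
    }
    __pipeline_commit();             // empty groups at the tail keep the
                                     // wait_prior accounting uniform
    swrite = sread;
    sread = (sread + 1) % PIPE;
  }
  out[threadIdx.x] = acc;            // single-block toy
}
```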
| @@ -225,6 +225,8 @@ struct RegisterTile { | ||||
|  | ||||
| template <typename T, int ROWS_, int COLS_> | ||||
| struct SharedTile { | ||||
|   using value_type = T; | ||||
|  | ||||
|   static constexpr int ROWS = ROWS_; | ||||
|   static constexpr int COLS = COLS_; | ||||
|   static constexpr int TILES_X = COLS / 16; | ||||
| @@ -266,23 +268,26 @@ struct SharedTile { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Return the location of the element at (row, col) using the swizzle. | ||||
|   __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) { | ||||
|   __device__ static inline uint32_t offset(int row, int col) { | ||||
|     if constexpr (swizzle_bytes > 0) { | ||||
|       static constexpr int swizzle_repeat = swizzle_bytes * 8; | ||||
|       static constexpr int subtile_cols = swizzle_bytes / sizeof(T); | ||||
|       const int outer_idx = col / subtile_cols; | ||||
|       const uint32_t addr = ptr + | ||||
|           sizeof(T) * | ||||
|               (outer_idx * ROWS * subtile_cols + row * subtile_cols + | ||||
|                col % subtile_cols); | ||||
|       const uint32_t addr = sizeof(T) * | ||||
|           (outer_idx * ROWS * subtile_cols + row * subtile_cols + | ||||
|            col % subtile_cols); | ||||
|       const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; | ||||
|       return (addr ^ swizzle); | ||||
|     } else { | ||||
|       return ptr + sizeof(T) * (row * COLS + col); | ||||
|       return sizeof(T) * (row * COLS + col); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Return the location of the element at (row, col) using the swizzle. | ||||
|   __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) { | ||||
|     return ptr + offset(row, col); | ||||
|   } | ||||
|  | ||||
|   // Convenience functions to edit elements going through the swizzle. | ||||
|   __device__ inline T& operator()(int row, int col) { | ||||
|     return *ptr(data, row, col); | ||||
| @@ -313,6 +318,76 @@ struct SharedTile { | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <int NUM_WARPS, typename Tile> | ||||
| struct SharedTileLoader { | ||||
|   using T = typename Tile::value_type; | ||||
|  | ||||
|   static constexpr int NUM_THREADS = NUM_WARPS * 32; | ||||
|   static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T); | ||||
|   static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD; | ||||
|   static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; | ||||
|   static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD; | ||||
|   static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; | ||||
|  | ||||
|   const T* x_; | ||||
|   int N_; | ||||
|   uint32_t offset_; | ||||
|  | ||||
|   __device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) { | ||||
|     const int row = threadIdx.x / NUM_LOADS_PER_ROW; | ||||
|     const int col = threadIdx.x % NUM_LOADS_PER_ROW; | ||||
|  | ||||
|     x_ += row * N + col * ELEMENTS_PER_LOAD; | ||||
|     offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD); | ||||
|   } | ||||
|  | ||||
|   __device__ inline void load_async(uint32_t base_address) { | ||||
|     MLX_UNROLL | ||||
|     for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { | ||||
|       cp_async<16>( | ||||
|           base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS, | ||||
|           x_ + i * STEP_ROWS * N_); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __device__ inline void next() { | ||||
|     x_ += Tile::COLS; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename Tile> | ||||
| struct RegisterTileLoader { | ||||
|   using T = typename Tile::value_type; | ||||
|  | ||||
|   uint32_t offset_[Tile::COLS / 16]; | ||||
|  | ||||
|   __device__ RegisterTileLoader(int offset_row, int laneid) { | ||||
|     const int row = offset_row + laneid & 15; | ||||
|     const int col = (laneid >> 4) << 3; | ||||
|  | ||||
|     MLX_UNROLL | ||||
|     for (int i = 0; i < Tile::COLS / 16; i++) { | ||||
|       offset_[i] = Tile::offset(row, col + i * 16); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   template <typename T, int ROWS, int COLS> | ||||
|   __device__ inline void | ||||
|   load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) { | ||||
|     constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y; | ||||
|     constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X; | ||||
|  | ||||
|     MLX_UNROLL | ||||
|     for (int i = 0; i < TILES_Y; i++) { | ||||
|       MLX_UNROLL | ||||
|       for (int j = 0; j < TILES_X; j++) { | ||||
|         x.data[i * TILES_X + j].load( | ||||
|             base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T)); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| /** | ||||
|  * Load the tile from global memory by loading 16 bytes at a time and storing | ||||
|  * them immediately. | ||||
|   | ||||
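Splitting the old `loc()` into a pure `offset(row, col)` plus a thin `loc()` wrapper is what lets `SharedTileLoader` and `RegisterTileLoader` precompute each thread's swizzled byte offsets once in their constructors; at use time they only add a stage's `base_addr()`, instead of redoing the swizzle arithmetic on every load. A host-side sketch of the swizzle computation itself, following the formula above (tile shape, element size and `swizzle_bytes` are illustrative assumptions):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative parameters: a 128x32 tile of 2-byte elements with a 64-byte
// swizzle period (the real values come from the SharedTile template).
constexpr int ROWS = 128;
constexpr int elem_size = 2;
constexpr int swizzle_bytes = 64;

// Byte offset of (row, col) inside the tile, with the XOR swizzle folded in
// so that different rows do not all map onto the same shared-memory banks.
uint32_t swizzled_offset(int row, int col) {
  constexpr int swizzle_repeat = swizzle_bytes * 8;
  constexpr int subtile_cols = swizzle_bytes / elem_size;
  const int outer_idx = col / subtile_cols;
  const uint32_t addr = elem_size *
      (outer_idx * ROWS * subtile_cols + row * subtile_cols +
       col % subtile_cols);
  const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
  return addr ^ swizzle;
}

int main() {
  // Rows 0..7 of column 0: the low bits rotate as the row index grows.
  for (int r = 0; r < 8; r++) {
    std::printf("(%d, 0) -> byte offset %u\n", r, swizzled_offset(r, 0));
  }
  return 0;
}
```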
| @@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) { | ||||
| #if defined(MLX_CUDA_SM_80_ENABLED) | ||||
|   if constexpr (N == 16) { | ||||
|     asm volatile( | ||||
|         "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address), | ||||
|         "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address), | ||||
|         "l"(reinterpret_cast<const int4*>(x))); | ||||
|   } else if constexpr (N == 8) { | ||||
|     asm volatile( | ||||
|         "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address), | ||||
|         "cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address), | ||||
|         "l"(reinterpret_cast<const int2*>(x))); | ||||
|   } else if constexpr (N == 4) { | ||||
|     asm volatile( | ||||
|         "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address), | ||||
|         "cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address), | ||||
|         "l"(reinterpret_cast<const int*>(x))); | ||||
|   } | ||||
| #endif | ||||
|   | ||||
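Switching the copies from `cp.async.ca` to `cp.async.cg` changes the caching policy of the global-to-shared transfers: `.cg` streams through L2 and bypasses L1, the usual choice for bulk staging data that is read back only from shared memory, whereas `.ca` also allocates in L1. Note that the PTX ISA defines `.cg` only for 16-byte copies; the 4- and 8-byte forms of `cp.async` support only `.ca`. A standalone sketch of a 16-byte `.cg` copy into dynamic shared memory (SM80 and newer; kernel and names are hypothetical):

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Toy kernel: each thread stages one 16-byte element into shared memory with
// cp.async.cg (L2-only), waits for it, and copies it back out. Launch with
// blockDim.x * sizeof(float4) bytes of dynamic shared memory.
__global__ void stage_through_smem(const float4* __restrict__ in, float4* out) {
#if __CUDA_ARCH__ >= 800
  extern __shared__ float4 smem[];
  const uint32_t dst =
      static_cast<uint32_t>(__cvta_generic_to_shared(&smem[threadIdx.x]));
  asm volatile(
      "cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(dst),
      "l"(in + threadIdx.x));
  asm volatile("cp.async.commit_group;\n");
  asm volatile("cp.async.wait_group 0;\n" ::: "memory");
  // Each thread reads back only its own element, so no block sync is needed.
  out[threadIdx.x] = smem[threadIdx.x];
#endif
}
```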
Angelos Katharopoulos