// Forward declaration of the Marlin GPU kernel template, together with the
// macro that spells out its (long) parameter list so that the declaration
// and any definition/instantiation can share a single source of truth.

// Allow the enclosing namespace to be overridden by defining
// MARLIN_NAMESPACE_NAME before this header is included; default to `marlin`.
#ifndef MARLIN_NAMESPACE_NAME
  #define MARLIN_NAMESPACE_NAME marlin
#endif

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

// Parameter list of the Marlin kernel.
//
// Every continuation backslash must be the very last character on its line;
// anything after it (even whitespace plus another token) ends the macro early.
//
// NOTE(review): `uint16_t` is assumed to be made visible by the includes
// above -- confirm, or include <cstdint> explicitly.
// NOTE(review): the roles of the individual pointers/sizes are not visible in
// this header; see the kernel definition for their layout and units.
#define MARLIN_KERNEL_PARAMS                          \
  const int4 *__restrict__ A,                         \
      const int4 *__restrict__ B,                     \
      int4 *__restrict__ C,                           \
      int4 *__restrict__ C_tmp,                       \
      const int4 *__restrict__ scales_ptr,            \
      const uint16_t *__restrict__ scale2_ptr,        \
      const int4 *__restrict__ zp_ptr,                \
      const int *__restrict__ g_idx,                  \
      int num_groups, int prob_m, int prob_n,         \
      int prob_k, int lda, int *locks,                \
      bool use_atomic_add, bool use_fp32_reduce,      \
      int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {

// Declaration only; the kernel body lives elsewhere.
//
// NOTE(review): the per-parameter remarks below are inferred from the
// parameter names alone -- confirm each against the kernel definition.
template <typename scalar_t,  // element type the kernel is instantiated for
          const vllm::ScalarTypeId w_type_id,  // presumably the weight type id
          const int threads,          // presumably threads per thread block
          const int thread_m_blocks,  // presumably tile extent along M
          const int thread_n_blocks,  // presumably tile extent along N
          const int thread_k_blocks,  // presumably tile extent along K
          const bool m_block_size_8,  // presumably selects an M block size of 8
          const int stages,           // presumably the pipeline stage count
          const int group_blocks,     // presumably the quantization group size
          const bool is_zp_float      // presumably: zero points are floating point
          >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}  // namespace MARLIN_NAMESPACE_NAME