Browse Source

fix local memory promotion for large thread numbers

tags/v1.2.0
dabaiji 5 years ago
parent
commit
2980041c17
4 changed files with 122 additions and 129 deletions
  1. +1
    -1
      src/poly/schedule_pass_gpu/register_memory_manager.cc
  2. +6
    -4
      src/poly/tiling/tiling_strategy_manager.h
  3. +53
    -28
      src/poly/tiling/tiling_strategy_manager_gpu.cc
  4. +62
    -96
      third_party/incubator-tvm/src/runtime/cuda/cuda_module.cc

+ 1
- 1
src/poly/schedule_pass_gpu/register_memory_manager.cc View File

@@ -356,7 +356,7 @@ void RegisterMemoryManager::IsOutofMemory(std::vector<BufferDefInfo> promoted_in
auto tensor_size = std::accumulate(box_sizes.begin(), box_sizes.end(), 1, std::multiplies<size_t>());
auto data_bytes = scop_info_.user_config_.GetDataType(promoted_info.tensor_id.get_name());
total_alloc_size += tensor_size * std::max<int>(1, data_bytes / BYTES_PER_REGISTER);
if (total_alloc_size * alloc_threads > MAX_REGISTER_PER_THREAD_BLOCK * REGISTER_ALLOC_RATIO) {
if (total_alloc_size * alloc_threads >= MAX_REGISTER_PER_THREAD_BLOCK * REGISTER_ALLOC_RATIO) {
memory_exceeding_ = true;
break;
}


+ 6
- 4
src/poly/tiling/tiling_strategy_manager.h View File

@@ -58,8 +58,10 @@ class TilingStrategy {

// gpu configs
int64_t warp_sizes_ = 32;
int64_t max_num_blocks_ = 256 * 256;
int64_t max_num_threads_ = 1024;
int64_t max_x_dim_block_ = pow(2, 31) - 1;
int64_t max_y_z_dim_block_ = 65535;
int64_t max_x_y_dim_thread_ = 1024;
int64_t max_z_dim_thread_ = 64;
size_t max_dim_ = 3;
int64_t max_elem_per_thread_ = 1024;
};
@@ -344,6 +346,8 @@ class GpuStrategy : public TilingStrategy {
// Step 3. Transform list of integer into string mapping config.
void SetMappingConfig();


int GetLocalAllocBufCount();
Template template_{Template::DEFAULT};
bool is_reduce_op_[TEMPLATE_BULK] = {false, false, true, true, true, false};

@@ -352,8 +356,6 @@ class GpuStrategy : public TilingStrategy {
std::vector<int64_t> thread_limit_;
std::vector<int64_t> block_cfg_;
std::vector<int64_t> thread_cfg_;
int64_t max_x_y_dim_thread_ = 1024;
int64_t max_z_dim_thread_ = 64;
int block_count_{0}; // number of mapped blocks
int64_t elem_per_thread_[3]{SpItemPerThread::AUTO};
int64_t min_elem_for_io_bound_ = 2;


+ 53
- 28
src/poly/tiling/tiling_strategy_manager_gpu.cc View File

@@ -173,13 +173,13 @@ void ReduceStrategy::AkgReduceLibStrategyOnGpu() {
int64_t min_blocks = square_thread ? 32 : 512;
int64_t min_elem_per_thread = use_local ? 2 : 8;
int64_t min_ty = 8;
if (total_injective_size * total_reduce_size / min_blocks / max_num_threads_ < min_elem_per_thread) {
if (total_injective_size * total_reduce_size / min_blocks / max_x_y_dim_thread_ < min_elem_per_thread) {
min_blocks = 32;
min_ty = square_thread ? min_ty : 1;
}

std::pair<int64_t, int64_t> tx_range{1, max_num_threads_};
std::pair<int64_t, int64_t> ty_range{1, max_num_threads_};
std::pair<int64_t, int64_t> tx_range{1, max_x_y_dim_thread_};
std::pair<int64_t, int64_t> ty_range{1, max_x_y_dim_thread_};
auto AlignToPowerOfTwo = [](int64_t original_factor) -> int64_t {
while ((original_factor) & (original_factor - 1)) {
--original_factor;
@@ -339,9 +339,9 @@ void ReduceStrategy::DealWith4DFusedReduce() {
continue;
}
axis->TileRestrainToSingleValue(CastIntToExpr(last_mod_value), TileLevel::CACHE1);
if (last_mod_value > max_num_threads_) {
if (last_mod_value > max_x_y_dim_thread_) {
LOG(WARNING) << "Cannot bind axis to " << last_mod_value << " threads, maximal thread number is "
<< max_num_threads_
<< max_x_y_dim_thread_
<< ". If fusing more than two axes together, footprint box calculated by isl may not be correct.";
continue;
}
@@ -376,6 +376,18 @@ void ReduceStrategy::DealWithPostReduceTensors() {
}
}

int GpuStrategy::GetLocalAllocBufCount() {
  // Count how many buffers the analyzer has assigned to local (register) scope.
  int local_buf_num = 0;
  for (auto &entry : analyzer_->buf_info_) {
    auto *buf_info = entry.second.get();
    CHECK(buf_info);
    if (buf_info->scope == TilingMemScope::MEM_SCOPE_LOCAL) {
      ++local_buf_num;
    }
  }
  return local_buf_num;
}

void GpuStrategy::ApplyCustomConstraint() {
auto ParseBindingConstraint = [](const std::string constraint, size_t max_size) {
std::vector<std::string> sp = akg::common::Split(constraint, ",");
@@ -514,10 +526,19 @@ void GpuStrategy::AddGpuConstraint() {
axis->TileRestrainToSingleValue(axis->c1_constraints.tile_min_, TileLevel::CACHE0);
});
}
// TODO: This is a very naive strategy to avoid "CUDA launch out of resources" errors;
// it should properly be fixed in the register memory promotion pass.
if (template_ != Template::REDUCTION && template_ != Template::ALL_REDUCE) {
auto local_buf_count = GetLocalAllocBufCount();
auto thread_size = std::accumulate(thread_cfg_.begin(), thread_cfg_.end(), 1, std::multiplies<int>());
if (local_buf_count >= 4 || local_buf_count * 4 * thread_size >= 65536) {
analyzer_->scop_info_.user_config_.SetUseRegisterMemory(false);
}
}
}

void GpuStrategy::InitMappingLimit() {
max_num_threads_ = analyzer_->scop_info_.user_config_.GetMaxElemPerThread();
max_x_y_dim_thread_ = analyzer_->scop_info_.user_config_.GetMaxElemPerThread();
DetermineTemplate();
std::stringstream ss;
reverse_binding_ = analyzer_->scop_info_.user_config_.GetEnableAkgReduceLib() &&
@@ -565,18 +586,18 @@ void GpuStrategy::InitMappingLimit() {

if (template_ == Template::CUSTOM_CONFIG) {
auto block_config = analyzer_->scop_info_.user_config_.GetBlockConfig();
for (int i = block_config->bound - 1; i >= 0; --i) {
for (int i = 0; i < static_cast<int>(block_config->bound) - 1; ++i) {
if (i >= static_cast<int>(depth_)) {
continue;
break;
}
block_limit_.emplace_back(block_config->GetAt(i).second);
}
} else if (template_ <= Template::REDUCTION) {
block_limit_ = {max_num_blocks_, max_num_blocks_, max_num_blocks_};
block_limit_ = {max_x_dim_block_, max_y_z_dim_block_, max_y_z_dim_block_};
} else if (template_ == Template::ALL_REDUCE && !analyzer_->scop_info_.user_config_.GetEnableAkgReduceLib()) {
block_limit_ = {1};
} else {
block_limit_ = {max_num_blocks_, max_num_blocks_, max_num_blocks_};
block_limit_ = {max_x_dim_block_, max_y_z_dim_block_, max_y_z_dim_block_};
}

std::vector<std::string> elem_cfg = common::Split(analyzer_->scop_info_.user_config_.GetElemPerThread(), " ");
@@ -617,13 +638,14 @@ void GpuStrategy::InnerThreadOuterBlock() {
auto block_dim = std::min(block_limit_.size(), max_dim_);

// tile from inner to outer and map to thread
analyzer_->GetTileLogger().AppendLine(GPU_MAPPING, "-----Map to thread-----");
size_t ori_size = pending_axes_.size();
size_t inner_dim = 0;
for (size_t i = 0; i < ori_size; ++i) {
TileAxis *axis;
int64_t shape;
std::tie(axis, shape) = pending_axes_[i];
int64_t rest_threads = std::min(max_num_threads_ / activated_threads, thread_limit_[thread_cfg_.size()]);
int64_t rest_threads = std::min(max_x_y_dim_thread_ / activated_threads, thread_limit_[thread_cfg_.size()]);
ss << "axis " << axis->index << "_" << axis->dim_axis << " shape = " << shape
<< ", rest_threads = " << rest_threads;
auto SkipMapping = [this, &axis, &shape, &ss, &inner_dim, &thread_dim]() {
@@ -642,14 +664,16 @@ void GpuStrategy::InnerThreadOuterBlock() {
if (axis->block_constraints.map_extent_ > 1) {
tile =
std::max(tile, std::max<int64_t>(ceil(static_cast<float>(shape) / axis->block_constraints.map_extent_), 1));
pending_axes_.push_back(std::make_pair(axis, std::max<int64_t>(ceil(static_cast<float>(shape) / tile), 1)));
ss << ", map to block.";
} else {
tile = std::min(tile, shape);
}
}
axis->TileRestrainLower(tile, TileLevel::CACHE1);
ss << ", tile = " << tile;
if (axis->block_constraints.map_extent_ > 1) {
pending_axes_.push_back(std::make_pair(axis, std::max<int64_t>(ceil(static_cast<float>(shape) / tile), 1)));
ss << ", map to block.";
}
analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss);
};

@@ -695,8 +719,9 @@ void GpuStrategy::InnerThreadOuterBlock() {
if (template_ == Template::PURE_ELEM) {
std::map<int64_t, std::vector<size_t>, std::greater<int64_t>> sorted_by_gcd;
for (size_t i = pending_axes_.size() - 1; i >= ori_size; --i) {
auto use = (max_num_blocks_ > 0 && pending_axes_[i].second > 0)
? TilingAnalyzer::FindDivisibleTilingFactor(max_num_blocks_, pending_axes_[i].second)
auto block_limit = i == 0 ? max_x_dim_block_ : max_y_z_dim_block_;
auto use = (block_limit > 0 && pending_axes_[i].second > 0)
? TilingAnalyzer::FindDivisibleTilingFactor(block_limit, pending_axes_[i].second)
: 1;
if (sorted_by_gcd.find(use) == sorted_by_gcd.end()) {
sorted_by_gcd[use] = {i};
@@ -726,14 +751,14 @@ void GpuStrategy::InnerThreadOuterBlock() {
}

// map outer band to block according to predefined indice
analyzer_->GetTileLogger().AppendLine(GPU_MAPPING, "-----Map to block-----");
for (const auto &i : indexing) {
TileAxis *axis;
int64_t shape;
std::tie(axis, shape) = pending_axes_[i];
auto idx = pending_axes_.size() - 1 - i;
idx = reverse_binding_ ? block_limit_.size() - 1 - idx : idx;
auto rest_blocks = std::min(max_num_blocks_ / activated_blocks, block_limit_[idx]);
rest_blocks = std::min(rest_blocks, axis->block_constraints.map_extent_);
auto idx = depth_ - 1 - (pending_axes_.size() - 1 - i);
idx = reverse_binding_ ? std::min(depth_, block_limit_.size()) - 1 - idx : idx;
auto rest_blocks = std::min(block_limit_[idx], axis->block_constraints.map_extent_);
ss << "axis " << axis->index << "_" << axis->dim_axis << " shape = " << shape << ", rest blocks = " << rest_blocks;
if (block_count_ >= static_cast<int>(block_dim)) {
ss << "-> No mapping.";
@@ -888,7 +913,7 @@ int64_t GpuStrategy::TileAfterThreadMapping(TileAxis *axis, size_t inner_dim, in
tile = thread_size;
ss << "tile = thread size, ";
} else {
auto block_dim = reverse_binding_ ? inner_dim : block_limit_.size() - 1 - inner_dim;
auto block_dim = reverse_binding_ ? block_limit_.size() - 1 - inner_dim : inner_dim;
int64_t least_blocks;
if (block_dim >= 0 && block_dim < block_limit_.size()) {
least_blocks = block_limit_[block_dim];
@@ -1038,7 +1063,7 @@ void GpuStrategy::InjectiveSpeedup() {
while (shape % lower != 0) {
--lower;
}
bool is_efficient = lower * 2 > thread_size || total_threads / thread_size * lower * 2 >= max_num_threads_;
bool is_efficient = lower * 2 > thread_size || total_threads / thread_size * lower * 2 >= max_x_y_dim_thread_;
if (is_efficient) {
ss << "align thread from " << thread_size << " to " << lower << " according to shape " << shape;
analyzer_->GetTileLogger().AppendLog(GPU_MAPPING, ss);
@@ -1058,8 +1083,8 @@ void GpuStrategy::InjectiveSpeedup() {
auto coaleasced_size = injective_axes.back()->thread_constraints.map_extent_;
auto proposal_blocks = coaleasced_size >= warp_sizes_ ? 256 : 512;
auto proposal_threads = (coaleasced_size >= warp_sizes_ && injective_axes.size() > 1U) ? 128
: coaleasced_size < max_num_threads_ ? 512
: max_num_threads_;
: coaleasced_size < max_x_y_dim_thread_ ? 512
: max_x_y_dim_thread_;
auto total_blocks = std::accumulate(block_cfg_.begin(), block_cfg_.end(), 1, std::multiplies<int>());
auto proposal_elem_per_thread = coaleasced_size < warp_sizes_ ? 1
: total_blocks < proposal_blocks * 8 ? min_elem_for_io_bound_
@@ -1226,7 +1251,7 @@ void GpuStrategy::GpuVectorBroadcastStrategy() {
int total_injective_size = 1;
auto broadcast_innermost = broadcast_idx_.find(original_shape.size() - 1) != broadcast_idx_.end();
for (size_t i = 0; i < original_shape.size(); ++i) {
if (original_shape[i] * possible_threads <= max_num_threads_) {
if (original_shape[i] * possible_threads <= max_x_y_dim_thread_) {
possible_threads *= original_shape[i];
}
auto rev_idx = original_shape.size() - 1 - i;
@@ -1235,7 +1260,7 @@ void GpuStrategy::GpuVectorBroadcastStrategy() {
coalesced_size = coalesced_size == 0 ? original_shape[i] : coalesced_size;
if (broadcast_innermost) {
auto prev_extent = axis->thread_constraints.map_extent_ > 0 ? axis->thread_constraints.map_extent_ : 1;
auto thread_limit = max_num_threads_ / prev_extent;
auto thread_limit = max_x_y_dim_thread_ / prev_extent;
auto coef = analyzer_->FindDivisibleTilingFactor(thread_limit, original_shape[i]);
axis->thread_constraints.map_extent_ = prev_extent * coef;
possible_threads = axis->thread_constraints.map_extent_;
@@ -1243,7 +1268,7 @@ void GpuStrategy::GpuVectorBroadcastStrategy() {
} else if (broadcast_innermost) {
auto prev_extent = axis->thread_constraints.map_extent_ > 0 ? axis->thread_constraints.map_extent_ : 1;
axis->thread_constraints.map_extent_ =
prev_extent * original_shape[i] <= max_num_threads_ ? prev_extent * original_shape[i] : prev_extent;
prev_extent * original_shape[i] <= max_x_y_dim_thread_ ? prev_extent * original_shape[i] : prev_extent;
possible_threads = axis->thread_constraints.map_extent_;
}
coalesced_size = coalesced_size == 0 ? 1 : coalesced_size;
@@ -1256,10 +1281,10 @@ void GpuStrategy::GpuVectorBroadcastStrategy() {
std::min(elem_per_thread, std::max<int>((fused_size_ / possible_threads / min_block + 1) / 2 * 2, 1));
ss << "thread for-loop speedup = " << axis->thread_constraints.item_process_;
} else if (total_injective_size > min_block) {
while (possible_threads % warp_sizes_ != 0 && possible_threads < max_num_threads_) {
while (possible_threads % warp_sizes_ != 0 && possible_threads < max_x_y_dim_thread_) {
++possible_threads;
}
int elem_per_block = std::max<int>(16 / (max_num_threads_ / possible_threads), 1);
int elem_per_block = std::max<int>(16 / (max_x_y_dim_thread_ / possible_threads), 1);
auto proposal_blocks = std::max(min_block, std::max<int>(fused_size_ / possible_threads / elem_per_block, 1));
axis->block_constraints.map_extent_ = proposal_blocks;
axis->thread_constraints.map_extent_ = possible_threads;


+ 62
- 96
third_party/incubator-tvm/src/runtime/cuda/cuda_module.cc View File

@@ -20,23 +20,25 @@
/*!
* \file cuda_module.cc
* 2020.09.19 - Modify operator() for kc_air.
* 2020.09.22 - Separate the implementation of KC and GPU.
* 2020.09.22 - Separate the implementation of KC and GPU.
*/
#include "cuda_module.h"

#include <tvm/runtime/registry.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <tvm/runtime/registry.h>

#include <array>
#include <string>
#include <mutex>
#include <string>
#include <unordered_map>
#include "cuda_common.h"
#include <vector>

#include "../file_util.h"
#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
#include "cuda_common.h"

namespace air {
namespace runtime {
@@ -47,8 +49,7 @@ namespace runtime {
// The modules will be lazily loaded
class CUDAModuleNode : public runtime::ModuleNode {
public:
explicit CUDAModuleNode(std::string data,
std::string fmt,
explicit CUDAModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source)
: data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) {
@@ -65,16 +66,11 @@ class CUDAModuleNode : public runtime::ModuleNode {
}
}

const char* type_key() const final {
return "cuda";
}
const char* type_key() const final { return "cuda"; }

PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final;
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

void SaveToFile(const std::string& file_name,
const std::string& format) final {
void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "cu") {
@@ -82,8 +78,7 @@ class CUDAModuleNode : public runtime::ModuleNode {
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, cuda_source_);
} else {
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
@@ -106,11 +101,18 @@ class CUDAModuleNode : public runtime::ModuleNode {
}

// get a CUfunction from primary context in device_id
CUfunction GetFunc(int device_id, const std::string& func_name) {
CUfunction GetFunc(int device_id, const std::string& func_name, ThreadWorkLoad wl) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
CUjit_option options[1];
options[0] = CU_JIT_MAX_REGISTERS;
void* values[1];
long register_nums =
MAX_REGISTER_PER_THREAD_BLOCK / wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
values[0] = (void*)register_nums;
CUDA_DRIVER_CALL(
cuModuleLoadDataEx(&(module_[device_id]), data_.c_str(), 1, options, values));
}
CUresult result = CUDA_SUCCESS;
CUfunction func = nullptr;
@@ -122,11 +124,9 @@ class CUDAModuleNode : public runtime::ModuleNode {
#endif
}
if (result != CUDA_SUCCESS) {
const char *msg;
const char* msg;
cuGetErrorName(result, &msg);
LOG(FATAL)
<< "CUDAError: cuModuleGetFunction " << func_name
<< " failed with error: " << msg;
LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg;
}
#ifdef USE_KC_AIR
return func_[device_id];
@@ -135,9 +135,7 @@ class CUDAModuleNode : public runtime::ModuleNode {
#endif
}
// get a global var from primary context in device_id
CUdeviceptr GetGlobal(int device_id,
const std::string& global_name,
size_t expect_nbytes) {
CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
@@ -146,15 +144,12 @@ class CUDAModuleNode : public runtime::ModuleNode {
CUdeviceptr global;
size_t nbytes;

CUresult result = cuModuleGetGlobal(&global, &nbytes,
module_[device_id], global_name.c_str());
CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str());
CHECK_EQ(nbytes, expect_nbytes);
if (result != CUDA_SUCCESS) {
const char *msg;
const char* msg;
cuGetErrorName(result, &msg);
LOG(FATAL)
<< "CUDAError: cuModuleGetGlobal " << global_name
<< " failed with error: " << msg;
LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg;
}
return global;
}
@@ -173,17 +168,15 @@ class CUDAModuleNode : public runtime::ModuleNode {
// internal mutex when updating the module
std::mutex mutex_;
std::array<CUfunction, kMaxNumGPUs> func_;
const int MAX_REGISTER_PER_THREAD_BLOCK = 65536;
};

// a wrapped function class to get packed func.
class CUDAWrappedFunc {
public:
// initialize the CUDA function.
void Init(CUDAModuleNode* m,
ObjectPtr<Object> sptr,
const std::string& func_name,
size_t num_void_args,
std::vector<size_t> arg_size,
void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_void_args, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
m_ = m;
sptr_ = sptr;
@@ -194,65 +187,49 @@ class CUDAWrappedFunc {
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args,
TVMRetValue* rv,
void** void_args) const {
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
fcache_[device_id] = m_->GetFunc(device_id, func_name_, wl);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
CUresult result;

#ifdef USE_KC_AIR
size_t raw_size = num_void_args_;
void** raw_args = new (std::nothrow) void*[raw_size];
void** raw_args = new (std::nothrow) void*[raw_size];
if (*raw_args == nullptr) {
LOG(FATAL) << "Memory alloc fail.";
LOG(FATAL) << "Memory alloc fail.";
}
size_t args_size = 0;
for (size_t i = 0; i < raw_size; ++i)
{
for (size_t i = 0; i < raw_size; ++i) {
args_size += arg_size_[i];
void** ptr = reinterpret_cast<void**>(void_args[i]);
raw_args[i] = *ptr;
}
result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
wl.grid_dim(1),
wl.grid_dim(2),
wl.block_dim(0),
wl.block_dim(1),
wl.block_dim(2),
(static_cast<uint32_t>(args_size)/sizeof(void *)), strm, raw_args, 0);
result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2),
wl.block_dim(0), wl.block_dim(1), wl.block_dim(2),
(static_cast<uint32_t>(args_size) / sizeof(void*)), strm, raw_args, 0);
if (raw_args != NULL) {
free(raw_args);
raw_args = NULL;
}
#else
result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
wl.grid_dim(1),
wl.grid_dim(2),
wl.block_dim(0),
wl.block_dim(1),
wl.block_dim(2),
0, strm, void_args, 0);
result =
cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2),
wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), 0, strm, void_args, 0);
#endif

if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char *msg;
const char* msg;
cuGetErrorName(result, &msg);
std::ostringstream os;
os << "CUDALaunch Error: " << msg << "\n"
<< " grid=(" << wl.grid_dim(0) << ","
<< wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
<< " block=(" << wl.block_dim(0) << ","
<< wl.block_dim(1) << "," << wl.block_dim(2) << ")\n";
<< " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
<< " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2)
<< ")\n";
std::string cuda = m_->GetSource("");
if (cuda.length() != 0) {
os << "// func_name=" << func_name_ << "\n"
@@ -283,9 +260,7 @@ class CUDAWrappedFunc {

class CUDAPrepGlobalBarrier {
public:
CUDAPrepGlobalBarrier(CUDAModuleNode* m,
ObjectPtr<Object> sptr)
: m_(m), sptr_(sptr) {
CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr<Object> sptr) : m_(m), sptr_(sptr) {
std::fill(pcache_.begin(), pcache_.end(), 0);
}

@@ -293,8 +268,8 @@ class CUDAPrepGlobalBarrier {
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
if (pcache_[device_id] == 0) {
pcache_[device_id] = m_->GetGlobal(
device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
pcache_[device_id] =
m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
}
CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
}
@@ -308,12 +283,10 @@ class CUDAPrepGlobalBarrier {
mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};

PackedFunc CUDAModuleNode::GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
PackedFunc CUDAModuleNode::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
if (name == symbol::tvm_prepare_global_barrier) {
return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
}
@@ -322,7 +295,7 @@ PackedFunc CUDAModuleNode::GetFunction(
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
std::vector<size_t> arg_size(info.arg_types.size());
for (int i=0; i<static_cast<int>(info.arg_types.size()); ++i){
for (int i = 0; i < static_cast<int>(info.arg_types.size()); ++i) {
TVMType t = info.arg_types[i];
CHECK_EQ(t.lanes, 1U);
uint32_t bits = t.bits;
@@ -333,18 +306,15 @@ PackedFunc CUDAModuleNode::GetFunction(
return PackFuncVoidAddr(f, info.arg_types);
}

Module CUDAModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source) {
Module CUDAModuleCreate(std::string data, std::string fmt,
                        std::unordered_map<std::string, FunctionInfo> fmap,
                        std::string cuda_source) {
  // Wrap the compiled CUDA artifacts (binary data, format, function metadata,
  // and the original CUDA source) in a module node and return it as a Module.
  return Module(make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source));
}

// Load module from module.
Module CUDAModuleLoadFile(const std::string& file_name,
const std::string& format) {
Module CUDAModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
@@ -365,14 +335,10 @@ Module CUDAModuleLoadBinary(void* strm) {
return CUDAModuleCreate(data, fmt, fmap, std::string());
}

TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile);

TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile);

TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
.set_body_typed(CUDAModuleLoadBinary);
TVM_REGISTER_GLOBAL("module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary);
} // namespace runtime
} // namespace air


Loading…
Cancel
Save