| @@ -18,7 +18,7 @@ import os | |||
| import sys | |||
| from te.platform.cce_conf import te_set_version | |||
| from te.platform.fusion_util import fusion_op | |||
| import te | |||
| import tbe.common.context.op_info as operator_info | |||
| sys.path.append(os.path.abspath(os.path.dirname(__file__))) | |||
| # pylint: disable=wrong-import-position | |||
| from tbe_common import check_kernel_info, get_args, get_built_in_impl_path | |||
| @@ -68,6 +68,7 @@ def build_op(build_type, json_str, tune_mode=None): | |||
| check_kernel_info(kernel_info) | |||
| te_set_version(kernel_info["op_info"]["socVersion"]) | |||
| op_name = kernel_info['op_info']['name'] | |||
| op_type = kernel_info['op_info']['Type'] | |||
| try: | |||
| custom_flag = False | |||
| @@ -117,10 +118,13 @@ def build_op(build_type, json_str, tune_mode=None): | |||
| # with te.op.dynamic(): | |||
| import tbe.common.context.op_context as op_context | |||
| with op_context.OpContext("dynamic"): | |||
| op_info = operator_info.OpInfo(op_type, op_type) | |||
| op_context.get_context().add_op_info(op_info) | |||
| op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| compile_info = op_context.get_context().get_compile_info() | |||
| if tune_mode is not None: | |||
| return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name | |||
| return te.op.get_compile_info() | |||
| return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name | |||
| return compile_info | |||
| else: | |||
| res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| if tune_mode is not None: | |||
| @@ -113,19 +113,12 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt | |||
| AddressPtrList kernel_workspaces; | |||
| AddressPtrList kernel_outputs; | |||
| device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||
| auto dynamic_flag = AnfAlgo::IsDynamicShape(cnode_ptr); | |||
| // Get para_size from json | |||
| auto kernel_json_info = kernel_pack_->kernel_json_info(); | |||
| auto op_para_size = kernel_json_info.op_para_size; | |||
| // Get stub_function | |||
| uint32_t block_dim = 1; // default blockdim equal to 1. | |||
| auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); | |||
| if (func_stub == 0) { | |||
| MS_LOG(EXCEPTION) << "GenFuncStub failed."; | |||
| } | |||
| const void *stub_func_ptr = reinterpret_cast<void *>(func_stub); | |||
| // Generate args | |||
| std::vector<void *> runtime_args; | |||
| (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args), | |||
| @@ -146,8 +139,26 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt | |||
| runtime_args.push_back(tiling_data_ptr); | |||
| } | |||
| auto executor = std::make_shared<device::ascend::AiCoreDynamicKernel>( | |||
| stub_func_ptr, block_dim, tiling_data_ptr, op_para_size, stream_ptr, cnode_ptr, runtime_args); | |||
| // Get stub_function | |||
| uint32_t block_dim = 1; // default blockdim equal to 1. | |||
| device::DynamicKernelPtr executor = nullptr; | |||
| std::string origin_key; | |||
| void *handle = nullptr; | |||
| auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim, dynamic_flag, &handle, &origin_key); | |||
| if (dynamic_flag) { | |||
| if (func_stub != 1) { | |||
| MS_LOG(EXCEPTION) << "GenFuncStub failed."; | |||
| } | |||
| executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(handle, block_dim, tiling_data_ptr, op_para_size, | |||
| stream_ptr, cnode_ptr, runtime_args, origin_key); | |||
| } else { | |||
| if (func_stub == 0) { | |||
| MS_LOG(EXCEPTION) << "GenFuncStub failed."; | |||
| } | |||
| const void *stub_func_ptr = reinterpret_cast<void *>(func_stub); | |||
| executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(stub_func_ptr, block_dim, tiling_data_ptr, | |||
| op_para_size, stream_ptr, cnode_ptr, runtime_args); | |||
| } | |||
| return executor; | |||
| } | |||
| @@ -116,8 +116,8 @@ KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::s | |||
| return SearchCache(kernel_name, processor); | |||
| } | |||
| int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, | |||
| const string &magic) { | |||
| int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, const string &magic, | |||
| const bool dynamic_flag) { | |||
| static std::map<string, uint32_t> magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, | |||
| {"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN}, | |||
| {"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU}, | |||
| @@ -132,8 +132,9 @@ int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buf | |||
| } | |||
| dev_bin.magic = iter->second; | |||
| dev_bin.length = kernel_buffer.len; | |||
| dev_bin.version = 2; | |||
| if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) { | |||
| dev_bin.version = 0; | |||
| auto ret = dynamic_flag ? rtRegisterAllKernel(&dev_bin, module) : rtDevBinaryRegister(&dev_bin, module); | |||
| if (RT_ERROR_NONE != ret) { | |||
| MS_LOG(INFO) << "Call runtime rtDevBinaryRegister error."; | |||
| return -1; | |||
| } | |||
| @@ -141,7 +142,8 @@ int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buf | |||
| } | |||
| uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload, | |||
| uint32_t *block_dim) { | |||
| uint32_t *block_dim, const bool dynamic_flag, void **handle, | |||
| std::string *origin_key) { | |||
| auto kernel = kernel_pack.GetKernel(); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr."; | |||
| @@ -162,14 +164,24 @@ uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel | |||
| if (iter != info_table_.end()) { | |||
| auto kernelmeta = iter->second; | |||
| *block_dim = kernelmeta->block_dim_; | |||
| return kernelmeta->func_stub_; | |||
| if (!dynamic_flag) { | |||
| return kernelmeta->func_stub_; | |||
| } | |||
| } | |||
| } | |||
| void *module = nullptr; | |||
| if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) { | |||
| if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic, dynamic_flag) != 0) { | |||
| MS_LOG(INFO) << "Call runtime BinaryRegister error."; | |||
| if (module != nullptr) { | |||
| (void)rtDevBinaryUnRegister(module); | |||
| } | |||
| return 0; | |||
| } | |||
| if (dynamic_flag) { | |||
| *handle = module; | |||
| *origin_key = func_name; | |||
| return 1; | |||
| } | |||
| // to diff different funcs. | |||
| uintptr_t func_stub = ++kernel_stub_gen_; | |||
| if (RT_ERROR_NONE != | |||
| @@ -61,13 +61,16 @@ using KernelMetaPtr = std::shared_ptr<KernelMetaInfo>; | |||
| class KernelManager { | |||
| public: | |||
| static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim); | |||
| static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim, | |||
| const bool dynamic_flag = false, void **handle = nullptr, | |||
| std::string *origin_key = nullptr); | |||
| static std::string GetStubFuncName(const KernelPackPtr &kernel_pack); | |||
| private: | |||
| KernelManager() = default; | |||
| ~KernelManager() = default; | |||
| static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic); | |||
| static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic, | |||
| const bool dynamic_flag); | |||
| static std::unordered_map<string, KernelMetaPtr> info_table_; | |||
| static uintptr_t kernel_stub_gen_; | |||
| }; | |||
| @@ -45,11 +45,23 @@ void AiCoreDynamicKernel::Execute() { | |||
| } | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_LOG(INFO) << "Start Execute node:" << cnode->fullname_with_scope(); | |||
| auto node_info = cnode->fullname_with_scope(); | |||
| MS_LOG(INFO) << "Start Execute node:" << node_info; | |||
| rtL2Ctrl_t *l2ctrl = nullptr; | |||
| auto args_size = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtime_args_.size()); | |||
| if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) { | |||
| MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error."; | |||
| if (handle_ != nullptr) { | |||
| const auto dev_func = | |||
| origin_key_.find("kernel0") != origin_key_.npos ? origin_key_ : origin_key_ + "_" + std::to_string(tiling_key_); | |||
| const auto kernel_info = node_info + "/" + std::to_string(tiling_key_); | |||
| if (RT_ERROR_NONE != rtKernelLaunchWithHandle(handle_, dev_func.c_str(), block_dim_, runtime_args_.data(), | |||
| args_size, l2ctrl, stream_, kernel_info.c_str())) { | |||
| MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunchWithHandle error."; | |||
| } | |||
| } else { | |||
| if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) { | |||
| MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error."; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "End Execute node:" << cnode->fullname_with_scope(); | |||
| } | |||
| @@ -127,6 +139,7 @@ void AiCoreDynamicKernel::ComputeTiling() { | |||
| block_dim_ = op_run_info.block_dim; | |||
| workspaces_size_ = op_run_info.workspaces; | |||
| tiling_data_ = op_run_info.tiling_data.str(); | |||
| tiling_key_ = op_run_info.tiling_key; | |||
| } | |||
| void AiCoreDynamicKernel::AllocateWorkspace() { | |||
| @@ -40,6 +40,15 @@ class AiCoreDynamicKernel : public DynamicKernel { | |||
| tiling_data_ptr_(tiling_data_ptr), | |||
| op_para_size_(op_para_size), | |||
| runtime_args_(runtime_args) {} | |||
| AiCoreDynamicKernel(void *handle, uint32_t block_dim, void *tiling_data_ptr, uint32_t op_para_size, void *stream, | |||
| const CNodePtr &cnode_ptr, const std::vector<void *> &runtime_args, const std::string &ori_key) | |||
| : DynamicKernel(stream, cnode_ptr), | |||
| handle_(handle), | |||
| block_dim_(block_dim), | |||
| tiling_data_ptr_(tiling_data_ptr), | |||
| op_para_size_(op_para_size), | |||
| runtime_args_(runtime_args), | |||
| origin_key_(ori_key) {} | |||
| ~AiCoreDynamicKernel() override; | |||
| void Execute() override; | |||
| @@ -53,6 +62,7 @@ class AiCoreDynamicKernel : public DynamicKernel { | |||
| private: | |||
| const void *stub_func_; | |||
| void *handle_{nullptr}; | |||
| uint32_t block_dim_; | |||
| void *tiling_data_ptr_; // device ptr | |||
| uint32_t op_para_size_; // size of tiling_data_ptr_ | |||
| @@ -62,6 +72,8 @@ class AiCoreDynamicKernel : public DynamicKernel { | |||
| std::vector<DeviceAddressPtr> workspace_addr_; | |||
| std::shared_ptr<nlohmann::json> compile_info_json_; | |||
| optiling::OpCompileInfo op_compile_info_{}; | |||
| uint32_t tiling_key_{0}; | |||
| const std::string origin_key_{""}; | |||
| void ComputeTiling(); | |||
| bool CopyTilingToDevice(); | |||
| @@ -58,7 +58,7 @@ def test_unique_ascend(): | |||
| assert (output[1].asnumpy() == expect2).all() | |||
| @pytest.mark.level2 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| @@ -36,7 +36,7 @@ class NetWithEmbeddingLookUp(nn.Cell): | |||
| return out | |||
| @pytest.mark.level2 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -56,7 +56,7 @@ def test_ftrl_net(): | |||
| [[0.6821311, 0.6821311]], | |||
| [[0.6821311, 0.6821311]]])) | |||
| @pytest.mark.level2 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_onecard | |||
| @@ -161,6 +161,10 @@ RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFa | |||
| return RT_ERROR_NONE; | |||
| } | |||
| RTS_API rtError_t rtRegisterAllKernel(const rtDevBinary_t *bin, void **module) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtDevBinaryUnRegister(void *handle) { return RT_ERROR_NONE; } | |||
| RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) { | |||
| return RT_ERROR_NONE; | |||
| } | |||