@@ -35,6 +35,7 @@ enum MatchCountPriority : int {
   MATCH_COUNT_PRIORITY_BEGIN = 0,
   MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
   MATCH_FORMAT_COUNT,
+  MATCH_SPECIAL_FORMAT_COUNT,
   MATCH_5D_FORMAT_COUNT,
   MATCH_OUTPUT_DTYPE_COUNT,
   MATCH_COUNT_PRIORITY_END
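The enumerators double as indices into each candidate's match-count vector, so their relative order fixes the comparison priority; placing MATCH_SPECIAL_FORMAT_COUNT between the plain format count and the 5D format count is what gives special-format matches their own weight. A minimal, standalone sketch of that indexing (not the real selection pass):

```cpp
#include <vector>

// Same enumerator order as in the hunk above; values double as vector indices.
enum MatchCountPriority : int {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_FORMAT_COUNT,
  MATCH_SPECIAL_FORMAT_COUNT,
  MATCH_5D_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};

int main() {
  // One counter per priority slot; every candidate kernel build info gets one such vector.
  std::vector<int> counts(MATCH_COUNT_PRIORITY_END, 0);
  counts[MATCH_SPECIAL_FORMAT_COUNT]++;  // e.g. a feature-map input matched a special format
  return counts[MATCH_SPECIAL_FORMAT_COUNT];
}
```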
@@ -81,6 +82,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
     }
     return true;
   };
+  if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" ||
+      AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") {
+    if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) {
+      return true;
+    }
+  }
   if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
     return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
            AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
@@ -154,7 +161,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best
       return false;
     }
   }
-  return false;
+  return true;
 }
 
 void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
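The changed return value only matters when the loop above falls through, i.e. when every count of the candidate equals the current best. A sketch of that lexicographic walk, with an illustrative function name rather than the actual implementation:

```cpp
#include <vector>

// Illustrative stand-in for the priority comparison: walk the count vectors
// from the highest-priority slot down; the first differing slot decides.
bool ChooseByPriority(const std::vector<int> &cur, std::vector<int> *best) {
  for (size_t i = 0; i < cur.size(); ++i) {
    if (cur[i] > (*best)[i]) {
      *best = cur;   // strictly better in the most significant differing slot
      return true;
    }
    if (cur[i] < (*best)[i]) {
      return false;  // strictly worse; keep the current best
    }
  }
  // All slots equal: with the fix the candidate is reported as acceptable
  // (previously a full tie returned false).
  return true;
}
```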
@@ -174,12 +181,11 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
         continue;
       }
     }
-    if (input_anf_node->isa<ValueNode>()) {
-      if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
-        continue;
-      }
-    }
     if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
+      if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
+          kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
+        (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
+      }
       (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
     }
     if (kernel_build_info.GetInputDeviceType(input_index) ==
@@ -203,7 +209,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
       (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
     }
   }
-}
+}  // namespace
 
 void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
@@ -195,6 +195,9 @@ const std::set<std::string> kOptOperatorSet = {
   kApplyRMSPropOpName,
 };
 
+const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
+                                                 kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+
 static inline void ChangeFileMode(const std::string& file_name, mode_t mode) {
   if (access(file_name.c_str(), F_OK) != 0) {
     MS_LOG(DEBUG) << "File `" << file_name << "` does not exist.";
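For reference, the lookup that consumes this set in the MATCH_SPECIAL_FORMAT_COUNT bump above reduces to a plain std::set membership test. The string literals below merely stand in for the kOpFormat_* constants (whose actual values are defined elsewhere), so this is only a sketch:

```cpp
#include <set>
#include <string>

// Placeholder values standing in for kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0,
// kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ and kOpFormat_C1HWNCoC0.
const std::set<std::string> kSpecialFormats = {"FracZ", "NC1KHKWHWC0", "NC1HWC0", "FRACTAL_NZ", "C1HWNCoC0"};

// Mirrors the check guarding the MATCH_SPECIAL_FORMAT_COUNT increment.
bool IsSpecialFormat(const std::string &format) {
  return kSpecialFormats.find(format) != kSpecialFormats.end();
}
```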
@@ -32,10 +32,10 @@ from mindspore.ops.op_info_register import op_info_register
         {
             "index": 0,
             "dtype": [
-                "float16","float","float16","float16","float16","float16","float","float","float","float"
+                "float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
             ],
             "format": [
-                "FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
+                "FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
             ],
             "name": "x",
             "need_compile": false,
@@ -47,10 +47,10 @@ from mindspore.ops.op_info_register import op_info_register
         {
             "index": 0,
             "dtype": [
-                "float16","float","float16","float16","float16","float16","float","float","float","float"
+                "float16","float","float16","float","float16","float16","float16","float16","float","float","float","float"
             ],
             "format": [
-                "FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
+                "FRACTAL_NZ","FRACTAL_NZ","FracZ","FracZ","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat","DefaultFormat","NC1HWC0","DefaultFormat","DefaultFormat"
             ],
             "name": "y",
             "need_compile": true,
@@ -153,8 +153,7 @@ def test_bert_tdt():
     batch_size = int(os.getenv('BATCH_SIZE', '16'))
     config = get_config(version=version, batch_size=batch_size)
     netwithloss = BertNetworkWithLoss(config, True)
-    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=10000, start_learning_rate=1e-4,
-                     end_learning_rate=0.0, power=10.0, warmup_steps=0, decay_filter=lambda x: False)
+    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
     netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
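Swapping Lamb for Momentum removes the decay/warmup schedule from the picture; what remains is a fixed learning rate of 2e-5 and a momentum factor of 0.9. For orientation, the classic momentum update those two numbers feed into looks like this (toy scalar sketch, not MindSpore code):

```cpp
#include <cstdio>

int main() {
  double w = 0.5;                          // a single toy parameter
  double v = 0.0;                          // its momentum accumulator
  const double lr = 2e-5, momentum = 0.9;
  for (int step = 0; step < 3; ++step) {
    const double grad = 2.0 * w;           // gradient of the toy loss w^2
    v = momentum * v + grad;               // accumulate
    w -= lr * v;                           // apply
    std::printf("step %d: w = %.8f\n", step, w);
  }
  return 0;
}
```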
@@ -178,10 +177,10 @@ def test_bert_tdt():
             param.default_input = weight_variable(value.asnumpy().shape)
     model.train(ds.get_repeat_count(), ds, callbacks=parallel_callback, dataset_sink_mode=False)
     loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [12.191790, 11.739655, 11.523477, 11.320723, 11.113152, 11.203759, 10.841681, 10.826849,
-                  10.616718, 10.486609]
+    expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289, 12.165594,
+                  12.824818, 12.38842, 12.604046]
     logger.info("expected loss value output: {}".format(expect_out))
-    assert allclose(loss_value, expect_out, 0.001, 0.001)
+    assert allclose(loss_value, expect_out, 0.00001, 0.00001)
 
 if __name__ == '__main__':
     test_bert_tdt()
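Assuming allclose here is numpy.allclose, the positional arguments are rtol and atol, and the element-wise criterion is |actual - expected| <= atol + rtol * |expected|; tightening both from 1e-3 to 1e-5 shrinks the allowed drift on a loss of about 12 from roughly 1.3e-2 to roughly 1.3e-4. The same check spelled out as a sketch (not the test code):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>

// Element-wise tolerance check matching numpy.allclose semantics.
bool AllClose(const double *actual, const double *expected, std::size_t n, double rtol, double atol) {
  for (std::size_t i = 0; i < n; ++i) {
    if (std::fabs(actual[i] - expected[i]) > atol + rtol * std::fabs(expected[i])) {
      return false;
    }
  }
  return true;
}

int main() {
  const double expected[] = {12.19179, 11.965041};
  const double actual[] = {12.19180, 11.96550};  // off by 1e-5 and ~4.6e-4
  std::printf("loose (1e-3): %d\n", AllClose(actual, expected, 2, 1e-3, 1e-3));  // passes
  std::printf("tight (1e-5): %d\n", AllClose(actual, expected, 2, 1e-5, 1e-5));  // second entry now fails
  return 0;
}
```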