Merge pull request !852 from lianliguang/master (tags/v0.3.0-alpha)
| @@ -342,7 +342,7 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB | |||||
| std::vector<int> *node_mix_precision_datatype_index) { | std::vector<int> *node_mix_precision_datatype_index) { | ||||
| MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); | ||||
| bool add_node_datatype_flag = false; | bool add_node_datatype_flag = false; | ||||
| if (node_mix_precision_datatype->size() == 0) { | |||||
| if (node_mix_precision_datatype->empty()) { | |||||
| add_node_datatype_flag = true; | add_node_datatype_flag = true; | ||||
| } | } | ||||
| for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { | ||||
| @@ -464,8 +464,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void SelectKernelInfo(const CNodePtr &kernel_node) { | |||||
| int SelectKernelInfo(const CNodePtr &kernel_node) { | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; | ||||
| int status = kStatusAllMatched; | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| bool precision_reduce = false; | bool precision_reduce = false; | ||||
| std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr; | ||||
| @@ -486,11 +487,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) { | |||||
| << "] cannot find valid kernel info, not supported the type" << buffer.str(); | << "] cannot find valid kernel info, not supported the type" << buffer.str(); | ||||
| } else { | } else { | ||||
| PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); | PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); | ||||
| status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; | |||||
| } | } | ||||
| } | } | ||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); | AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); | ||||
| // Set format and data type for input tensor. | // Set format and data type for input tensor. | ||||
| SetTensorDeviceInfo(*selected_kernel_info, kernel_node); | SetTensorDeviceInfo(*selected_kernel_info, kernel_node); | ||||
| return status; | |||||
| } | } | ||||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, | ||||
| @@ -21,7 +21,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| void SelectKernelInfo(const CNodePtr &kernel_node); | |||||
| int SelectKernelInfo(const CNodePtr &kernel_node); | |||||
| bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info); | bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info); | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| @@ -325,10 +325,25 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr | |||||
| // compile graph steps | // compile graph steps | ||||
| void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { | void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { | ||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| size_t raise_precision_count = 0; | |||||
| size_t reduce_precision_count = 0; | |||||
| for (const auto &cnode : kernel_graph.execution_order()) { | for (const auto &cnode : kernel_graph.execution_order()) { | ||||
| device::ascend::SelectKernelInfo(cnode); | |||||
| auto status = device::ascend::SelectKernelInfo(cnode); | |||||
| if (status == kStatusRaisePrecision) { | |||||
| raise_precision_count++; | |||||
| } else if (status == kStatusReducePrecision) { | |||||
| reduce_precision_count++; | |||||
| } | |||||
| MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); | MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); | ||||
| } | } | ||||
| if (raise_precision_count > 0) { | |||||
| MS_LOG(WARNING) << "There has " << raise_precision_count | |||||
| << " node/nodes used raise precision to selected the kernel!"; | |||||
| } | |||||
| if (reduce_precision_count > 0) { | |||||
| MS_LOG(WARNING) << "There has " << reduce_precision_count | |||||
| << " node/nodes used reduce precision to selected the kernel!"; | |||||
| } | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| @@ -186,7 +186,10 @@ constexpr auto kControlDependBehindIndex = 2; | |||||
| // index define of depend | // index define of depend | ||||
| constexpr auto kRealInputIndexInDepend = 1; | constexpr auto kRealInputIndexInDepend = 1; | ||||
| constexpr auto kDependAttachNodeIndex = 2; | constexpr auto kDependAttachNodeIndex = 2; | ||||
| // status of kernel select result | |||||
| const int kStatusReducePrecision = -1; | |||||
| const int kStatusRaisePrecision = 1; | |||||
| const int kStatusAllMatched = 0; | |||||
| // format | // format | ||||
| constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; | constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; | ||||
| constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; | constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; | ||||