From: @hwhewei Reviewed-by: @zhunaipan,@zh_qh Signed-off-by: @zh_qhtags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||
| cmake_minimum_required(VERSION 3.14.1) | |||
| cmake_minimum_required(VERSION 3.14.0) | |||
| project(MindSpore) | |||
| if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) | |||
| @@ -14,18 +14,25 @@ endif() | |||
| if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") | |||
| set(CMAKE_OSX_SYSROOT "") | |||
| set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") | |||
| set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings \ | |||
| -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \ | |||
| -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \ | |||
| -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") | |||
| else() | |||
| set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") | |||
| set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \ | |||
| -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") | |||
| endif() | |||
| if(ENABLE_PYTHON) | |||
| add_compile_definitions(ENABLE_PYTHON) | |||
| endif() | |||
| set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") | |||
| set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer \ | |||
| -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 \ | |||
| -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 \ | |||
| -Werror -Wall -Wno-deprecated-declarations -fPIC") | |||
| set(CMAKE_EXPORT_COMPILE_COMMANDS ON) | |||
| set(PYBIND11_CPP_STANDARD -std=c++17) | |||
| @@ -132,6 +132,16 @@ def Depend(value, expr): | |||
| return value | |||
| def UpdateState(monad, expr): | |||
| """Implement `UpdateState`.""" | |||
| return monad | |||
| def Load(value, u=None): | |||
| """Implement `Load`.""" | |||
| return value | |||
| # only used in PyNative mode | |||
| def make_ref(key, value, ref): | |||
| return value | |||
| @@ -42,14 +42,16 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr< | |||
| std::vector<std::string> inputs_format{}; | |||
| std::vector<TypeId> inputs_type{}; | |||
| if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) { | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_format.emplace_back(kOpFormat_DEFAULT); | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); | |||
| } | |||
| } | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||
| } | |||
| @@ -139,9 +139,9 @@ bool CheckCache(const std::string &kernel_name) { | |||
| std::string kernel_json = bin_map->Search(kernel_name); | |||
| bool ret = (!kernel_json.empty()); | |||
| if (ret) { | |||
| MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed."; | |||
| MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered."; | |||
| } else { | |||
| MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed."; | |||
| MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered."; | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -730,30 +730,6 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann: | |||
| return false; | |||
| } | |||
| void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node_list); | |||
| auto output = func_graph->output(); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::IsRealKernel(output)) { | |||
| // single output. | |||
| node_list->push_back(std::make_pair(output, 0)); | |||
| return; | |||
| } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||
| auto output_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| // multi output. | |||
| auto &inputs = output_cnode->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); | |||
| node_list->push_back(in_with_idx); | |||
| } | |||
| return; | |||
| } | |||
| MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) | |||
| << " of graph: " << func_graph->ToString(); | |||
| } | |||
| bool IsWeightBoundary(const AnfNodePtr &node) { | |||
| if (node->isa<ValueNode>()) { | |||
| return true; | |||
| @@ -776,7 +752,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto axis_attr = primitive->GetAttr(kAxis); | |||
| if (axis_attr == nullptr) { | |||
| MS_LOG(ERROR) << "This node does't have axie attr."; | |||
| MS_LOG(ERROR) << "This node doesn't have axie attr."; | |||
| return std::vector<int64_t>(); | |||
| } | |||
| std::vector<int64_t> axis_list; | |||
| @@ -181,7 +181,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| std::vector<size_t> out_shape; | |||
| out_shape.emplace_back(miss_count); | |||
| std::vector<TypeId> dtypes; | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node_); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, | |||
| @@ -69,7 +69,8 @@ void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| std::vector<size_t> out_shape; | |||
| out_shape.emplace_back(count); | |||
| std::vector<TypeId> dtypes; | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node_); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape, out_shape}, node_.get()); | |||
| @@ -29,5 +29,8 @@ MS_REG_GPU_KERNEL_ONE( | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| AssignGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||
| AssignGpuKernel, int64_t) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -63,13 +63,15 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||
| for (const auto &type : kHcclSupportTypes) { | |||
| std::vector<std::string> inputs_format{}; | |||
| std::vector<TypeId> inputs_type{}; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index)); | |||
| inputs_type.push_back(type); | |||
| } | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) { | |||
| outputs_format.emplace_back(GetKernelFormat(kernel_node, 0)); | |||
| } else { | |||
| @@ -31,7 +31,8 @@ bool IsPyNativeMode() { | |||
| bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); | |||
| hccl_kernel_intput_shape_list->emplace_back(shape_i); | |||
| } | |||
| @@ -42,7 +43,8 @@ bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<siz | |||
| bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list); | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); | |||
| hccl_kernel_output_shape_list->emplace_back(shape_i); | |||
| } | |||
| @@ -53,11 +55,12 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si | |||
| bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(data_type_list); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i); | |||
| auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr); | |||
| if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { | |||
| MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr; | |||
| MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr; | |||
| } | |||
| data_type_list->emplace_back(iter->second); | |||
| } | |||
| @@ -37,13 +37,15 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||
| std::vector<std::string> inputs_format{}; | |||
| std::vector<TypeId> inputs_type{}; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_format.emplace_back(kOpFormat_DEFAULT); | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); | |||
| } | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_format.emplace_back(kOpFormat_DEFAULT); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||
| } | |||
| @@ -30,7 +30,7 @@ std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { | |||
| std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const { | |||
| if (output_index >= outputs_format_.size()) { | |||
| MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node"; | |||
| MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output"; | |||
| return kInvalidFormat; | |||
| } | |||
| return outputs_format_[output_index]; | |||
| @@ -86,6 +86,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel | |||
| builder.SetProcessor(AICORE); | |||
| builder.SetKernelType(RT_KERNEL); | |||
| builder.SetFusionType(OPAQUE); | |||
| // LabelSwitch always return UMonad. | |||
| builder.SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad}); | |||
| label_switch_build_info.emplace_back(builder.Build()); | |||
| } | |||
| return label_switch_build_info; | |||
| @@ -74,11 +74,10 @@ void GetRtKelInfo(const CNodePtr &kernel_node, | |||
| input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i)); | |||
| } | |||
| kernel_build_info_builder->SetInputsDeviceType(input_types); | |||
| // set output info | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT)); | |||
| kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown)); | |||
| // set ohter info | |||
| // Kernel ops in while-list such as 'LabelSet' always return UMonad. | |||
| kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad}); | |||
| // set other info | |||
| kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); | |||
| kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); | |||
| kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); | |||
| @@ -1052,10 +1052,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i | |||
| auto node_name = AnfAlgo::GetCNodeName(cnode); | |||
| auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { | |||
| auto node_inputs_size = cnode->inputs().size(); | |||
| for (auto &input : cnode->inputs()) { | |||
| if (HasAbstractMonad(input)) { | |||
| node_inputs_size--; | |||
| } | |||
| } | |||
| if (op_info->inputs_ptr().size() < (node_inputs_size - 1)) { | |||
| MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); | |||
| } | |||
| return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); | |||
| return (op_info->inputs_ptr().size() + 1 - node_inputs_size); | |||
| } | |||
| std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { | |||
| @@ -1103,6 +1109,9 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, | |||
| bool is_dynamic_input = IsDynamicInput(cnode); | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| auto input = cnode->input(i); | |||
| if (HasAbstractMonad(input)) { | |||
| continue; | |||
| } | |||
| auto kernel_idx = AnfAlgo::VisitKernel(input, 0); | |||
| auto real_node = kernel_idx.first; | |||
| size_t real_idx = kernel_idx.second; | |||
| @@ -112,6 +112,10 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & | |||
| const KernelSelectPtr &kernel_select) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto input_node = AnfAlgo::GetInputNode(node, index); | |||
| if (HasAbstractMonad(input_node)) { | |||
| // No transfer for monad inputs. | |||
| return input_node; | |||
| } | |||
| auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| auto real_input = node_with_index.first; | |||
| @@ -330,8 +334,9 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||
| size_t in_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads. | |||
| for (size_t input_index = 0; input_index < in_num; ++input_index) { | |||
| // Monad inputs keep unchanged from GetTransInputNodePtr(). | |||
| AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| new_inputs.push_back(input_node); | |||
| @@ -352,12 +357,18 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||
| size_t in_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads. | |||
| for (size_t input_index = 0; input_index < in_num; ++input_index) { | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| if (HasAbstractMonad(cur_input)) { | |||
| // No cast for monad inputs. | |||
| new_inputs.push_back(cur_input); | |||
| continue; | |||
| } | |||
| auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index); | |||
| const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); | |||
| TypeId origin_type(kTypeUnknown); | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); | |||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0); | |||
| auto real_input_node = kernel_with_index.first; | |||
| if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| @@ -244,7 +244,9 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, | |||
| if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), | |||
| (*buffer_fusion_infos)[fusion_id].inputs_list.end(), | |||
| in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { | |||
| (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in); | |||
| if (!HasAbstractMonad(in)) { | |||
| (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) { | |||
| return real_node->isa<ValueNode>(); | |||
| } | |||
| void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node, | |||
| const std::vector<AnfNodePtr> &memcpy_async_list) { | |||
| MS_EXCEPTION_IF_NULL(control_depend); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); | |||
| make_tuple_inputs.emplace_back(hccl_node); | |||
| auto make_tuple = graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| control_depend->set_input(IntToSize(index), make_tuple); | |||
| } | |||
| void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node, | |||
| const std::vector<AnfNodePtr> &memcpy_async_list) { | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(tuple_getitem); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager" | |||
| << " trace: " << trace::DumpSourceLines(hccl_node); | |||
| } | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||
| SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list); | |||
| } | |||
| } | |||
| } | |||
| void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list, | |||
| const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(hccl_node); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager" | |||
| << " trace: " << trace::DumpSourceLines(hccl_node); | |||
| } | |||
| // find hccl_node's output which is a control depend | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||
| SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list); | |||
| } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) { | |||
| DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list); | |||
| } | |||
| } | |||
| } | |||
| // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i) | |||
| bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) { | |||
| if (node_users.size() == 1) { | |||
| @@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con | |||
| void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| std::vector<AnfNodePtr> memcpy_async_list; | |||
| bool need_memcpy_async = false; | |||
| std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)}; | |||
| for (size_t i = 1; i < hccl_node->size(); ++i) { | |||
| auto input = hccl_node->input(i); | |||
| @@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co | |||
| if (memcpy_async == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Create memcpy_async op failed."; | |||
| } | |||
| if (AnfAlgo::IsNodeDynamicShape(input)) { | |||
| if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) { | |||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async); | |||
| } | |||
| new_inputs.push_back(memcpy_async); | |||
| memcpy_async_list.push_back(memcpy_async); | |||
| need_memcpy_async = true; | |||
| } else { | |||
| new_inputs.push_back(input); | |||
| } | |||
| } | |||
| if (!memcpy_async_list.empty()) { | |||
| if (need_memcpy_async) { | |||
| CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node); | |||
| new_hccl_node->set_inputs(new_inputs); | |||
| auto manager = graph->manager(); | |||
| @@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co | |||
| MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; | |||
| (void)manager->Replace(hccl_node, new_hccl_node); | |||
| MS_LOG(DEBUG) << "end replace"; | |||
| // transer hccl op's control to the memcpy_async | |||
| TransferControl(new_hccl_node, memcpy_async_list, graph); | |||
| } | |||
| } | |||
| @@ -57,7 +57,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||
| for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) { | |||
| for (size_t input_idx = 0; input_idx < input_num; input_idx++) { | |||
| auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx); | |||
| auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); | |||
| auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx); | |||
| @@ -40,7 +40,8 @@ const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, cons | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, input_index), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (!input_node->isa<CNode>()) { | |||
| @@ -77,7 +78,8 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr | |||
| MS_EXCEPTION_IF_NULL(node_info.first); | |||
| auto cast_out_node = node_info.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cast_out_node); | |||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node); | |||
| for (size_t index = 0; index < input_num; ++index) { | |||
| if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first != | |||
| cast_node) { | |||
| continue; | |||
| @@ -162,7 +162,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_ | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| AbstractBasePtrList abstract_list; | |||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); | |||
| // deal with ref output | |||
| if (ref_infos.count(output_index) != 0) { | |||
| @@ -37,7 +37,7 @@ const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph, | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto split_v = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(split_v); | |||
| auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3); | |||
| auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), kMatMulInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(matmul); | |||
| auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0); | |||
| auto input_node = input_node_with_idx.first; | |||
| @@ -129,9 +129,21 @@ AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| auto mng = sub_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| std::vector<AnfNodePtr> todo; | |||
| std::vector<std::pair<AnfNodePtr, size_t>> graph_rets; | |||
| kernel::GetValidKernelNodes(sub_graph, &todo); | |||
| kernel::GetGraphRealOutput(sub_graph, &graph_rets); | |||
| auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem}); | |||
| std::vector<std::pair<AnfNodePtr, size_t>> graph_rets; | |||
| for (auto &output : outputs) { | |||
| size_t index = 0; | |||
| if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||
| ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)); | |||
| MS_EXCEPTION_IF_NULL(tuple_index_value); | |||
| if (!tuple_index_value->isa<Int64Imm>()) { | |||
| MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64"; | |||
| } | |||
| index = tuple_index_value->cast<Int64ImmPtr>()->value(); | |||
| } | |||
| graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index)); | |||
| } | |||
| for (auto &t : todo) { | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); | |||
| // process input | |||
| @@ -33,7 +33,8 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t index = 0; index < input_num; ++index) { | |||
| auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); | |||
| auto prev_node_out_infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); | |||
| auto input_format = AnfAlgo::GetInputFormat(cnode, index); | |||
| @@ -28,8 +28,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| const size_t kCastInputNum = 2; | |||
| const size_t kTupleGetitemInputNum = 3; | |||
| bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx, | |||
| const std::shared_ptr<kernel::KernelBuildInfo> &candidate_kernel_info) { | |||
| if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) { | |||
| @@ -126,7 +124,8 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size | |||
| auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); | |||
| std::vector<Shape> shapes; | |||
| std::vector<TypeId> types; | |||
| for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t index = 0; index < output_num; ++index) { | |||
| if (cast_index == index) { | |||
| shapes.emplace_back(cast_shape); | |||
| types.emplace_back(cast_dtype); | |||
| @@ -175,7 +174,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co | |||
| << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" | |||
| << (*alternative_kernel_info)->ToString(); | |||
| AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); | |||
| if (node->inputs().size() < kCastInputNum) { | |||
| if (AnfAlgo::GetInputTensorNum(node) < kCastInputTensorNum) { | |||
| MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; | |||
| } | |||
| return node->input(1); | |||
| @@ -188,9 +187,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu | |||
| *prior_op = x_cnode; | |||
| // when x_node is tuple_getitem | |||
| if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) { | |||
| if (x_cnode->inputs().size() < kTupleGetitemInputNum) { | |||
| MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size(); | |||
| } | |||
| CheckCNodeInputSize(x_cnode, kTupleGetItemInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(output_idx); | |||
| AnfNodePtr input1 = x_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(input1); | |||
| @@ -214,9 +211,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu | |||
| AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) { | |||
| MS_EXCEPTION_IF_NULL(cur_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_query); | |||
| if (cur_node->inputs().size() < kCastInputNum) { | |||
| MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:"; | |||
| } | |||
| CheckCNodeInputSize(cur_node, kCastInputTensorNum); | |||
| AnfNodePtr x_node = cur_node->input(1); | |||
| if (IsUsedByOthers(graph, x_node)) { | |||
| return nullptr; | |||
| @@ -69,7 +69,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| CheckCNodeInputSize(cnode, kTransOpInputNum); | |||
| CheckCNodeInputSize(cnode, kTransOpInputTensorNum); | |||
| auto input_node = cnode->input(1); | |||
| if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) { | |||
| kernel_graph->ReplaceInternalOutput(node, input_node); | |||
| @@ -111,8 +111,8 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod | |||
| auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract()); | |||
| MS_EXCEPTION_IF_NULL(bn_abstract_tuple); | |||
| if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " | |||
| if (bn_abstract_tuple->elements().size() != kBnOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is " | |||
| << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn); | |||
| } | |||
| std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3], | |||
| @@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_node); | |||
| const auto &bn_grad_inputs = bn_grad_node->inputs(); | |||
| if (bn_grad_inputs.size() < kBNGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size." | |||
| << " trace: " << trace::DumpSourceLines(bn_grad_node); | |||
| } | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| std::vector<AnfNodePtr> bn_update_grad_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], | |||
| bn_grad_inputs[4], bn_grad_inputs[5]}; | |||
| @@ -60,10 +57,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra | |||
| MS_EXCEPTION_IF_NULL(bn_grad_node); | |||
| MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs); | |||
| const auto &bn_grad_inputs = bn_grad_node->inputs(); | |||
| if (bn_grad_inputs.size() < kBNGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size" | |||
| << " trace: " << trace::DumpSourceLines(bn_grad_node); | |||
| } | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size" | |||
| << " trace: " << trace::DumpSourceLines(bn_grad_node); | |||
| @@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_node); | |||
| auto bn_grad_inputs = bn_grad_node->inputs(); | |||
| if (bn_grad_inputs.size() != kBNGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size" | |||
| << " trace: " << trace::DumpSourceLines(bn_grad_node); | |||
| } | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| std::vector<AnfNodePtr> bn_update_grad_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], | |||
| bn_grad_inputs[4], bn_grad_inputs[5]}; | |||
| @@ -59,10 +56,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_node); | |||
| auto bn_grad_inputs = bn_grad_node->inputs(); | |||
| if (bn_grad_inputs.size() != kBNGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size" | |||
| << " trace: " << trace::DumpSourceLines(bn_grad_node); | |||
| } | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; | |||
| } | |||
| @@ -32,8 +32,8 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr & | |||
| std::vector<AnfNodePtr> *bn_training_reduce_outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||
| if (bn_cnode->inputs().size() != kBnInputNum) { | |||
| MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString(); | |||
| if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) { | |||
| MS_LOG(INFO) << "FusedBatchNorm's input size not equal " << kBnInputTensorNum << ". " << bn_cnode->DebugString(); | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> bn_training_reduce_inputs = { | |||
| @@ -64,10 +64,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod | |||
| const std::vector<AnfNodePtr> &bn_training_reduce_outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||
| if (bn_cnode->inputs().size() != kBnInputNum) { | |||
| MS_LOG(EXCEPTION) << "BN node has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(bn_cnode); | |||
| } | |||
| CheckCNodeInputSize(bn_cnode, kBnInputTensorNum); | |||
| if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { | |||
| MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(bn_cnode); | |||
| @@ -102,8 +99,8 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() < kBnInputNum) { | |||
| MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs."; | |||
| if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) { | |||
| MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs."; | |||
| return nullptr; | |||
| } | |||
| // Create BNTrainingReduce node and get outputs of BNTrainingReduce | |||
| @@ -123,8 +123,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const | |||
| bool CheckInputs(const CNodePtr &origin_node) { | |||
| MS_EXCEPTION_IF_NULL(origin_node); | |||
| if (origin_node->size() != kGatherV2DynInputNum + 1) { | |||
| MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputNum | |||
| if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) { | |||
| MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum | |||
| << ". CNode= " << origin_node->DebugString(); | |||
| return false; | |||
| } | |||
| @@ -28,11 +28,7 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars | |||
| std::vector<AnfNodePtr> *square_sum_all_outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(lars_v2); | |||
| if (lars_v2->size() != kLarsV2InputNum) { | |||
| MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum | |||
| << " trace: " << trace::DumpSourceLines(lars_v2); | |||
| } | |||
| CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName)), lars_v2->input(1), | |||
| lars_v2->input(2)}; | |||
| auto square_sum_all = graph->NewCNode(inputs); | |||
| @@ -55,10 +51,7 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2, | |||
| MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2" | |||
| << " trace: " << trace::DumpSourceLines(lars_v2); | |||
| } | |||
| if (lars_v2->size() != kLarsV2InputNum) { | |||
| MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum | |||
| << " trace: " << trace::DumpSourceLines(lars_v2); | |||
| } | |||
| CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName)), | |||
| lars_v2->input(1), | |||
| lars_v2->input(2), | |||
| @@ -91,7 +91,7 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->inputs().size() != kLayerNormGradInputNum) { | |||
| if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormGradInputTensorNum) { | |||
| return nullptr; | |||
| } | |||
| @@ -110,7 +110,7 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| CheckCNodeInputSize(cnode, 2); | |||
| CheckCNodeInputSize(cnode, 1); | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); | |||
| auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); | |||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | |||
| @@ -76,8 +76,8 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod | |||
| auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract()); | |||
| MS_EXCEPTION_IF_NULL(bn_abstract_tuple); | |||
| if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " | |||
| if (bn_abstract_tuple->elements().size() != kBnOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is " | |||
| << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn); | |||
| } | |||
| bn_training_update_v3->set_abstract(bn->abstract()); | |||
| @@ -34,10 +34,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu | |||
| CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { | |||
| MS_EXCEPTION_IF_NULL(origin_cnode); | |||
| if (origin_cnode->inputs().size() < kSplitInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " | |||
| << kSplitInputNum - 1 << " trace: " << trace::DumpSourceLines(origin_cnode); | |||
| } | |||
| CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum); | |||
| return CreateSplitVNode(func_graph, origin_cnode->input(1)); | |||
| } | |||
| @@ -32,10 +32,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu | |||
| CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { | |||
| MS_EXCEPTION_IF_NULL(origin_cnode); | |||
| if (origin_cnode->inputs().size() < kSplitInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " | |||
| << kSplitInputNum - 1; | |||
| } | |||
| CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum); | |||
| return CreateSplitVNode(func_graph, origin_cnode->input(1)); | |||
| } | |||
| @@ -146,7 +146,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| AnfAlgo::CopyNodeAttrs(cnode, new_cnode); | |||
| CheckCNodeInputSize(new_cnode, kTopkInputNum); | |||
| CheckCNodeInputSize(new_cnode, kTopkInputTensorNum); | |||
| // Convert the tensor input to scalar and convert it to attr | |||
| auto input_k = new_cnode->input(kTopkIndexK + 1); | |||
| MS_EXCEPTION_IF_NULL(input_k); | |||
| @@ -31,7 +31,7 @@ const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const A | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kTransOpInputTensorNum); | |||
| if (IsFormatInvaild(node)) { | |||
| TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info())); | |||
| return DoSplit(func_graph, node); | |||
| @@ -77,8 +77,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s | |||
| bool CheckInputs(const CNodePtr &origin_node) { | |||
| MS_EXCEPTION_IF_NULL(origin_node); | |||
| if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { | |||
| MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum | |||
| if (AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) { | |||
| MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum | |||
| << ". CNode= " << origin_node->DebugString(); | |||
| return false; | |||
| } | |||
| @@ -62,8 +62,8 @@ bool CheckIndex(const AnfNodePtr &index_node) { | |||
| bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(batchnorm); | |||
| if (batchnorm->size() < kBatchNormInputNum + 1) { | |||
| MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum; | |||
| if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) { | |||
| MS_LOG(DEBUG) << "BatchNorm's input less than " << kBnInputTensorNum; | |||
| return false; | |||
| } | |||
| if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { | |||
| @@ -87,7 +87,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto tuple_getitem = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); | |||
| CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum); | |||
| AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_node); | |||
| if (!CheckIndex(index_node)) { | |||
| @@ -61,8 +61,8 @@ bool CheckIndex(const AnfNodePtr &index_node) { | |||
| bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(batchnormgrad); | |||
| if (batchnormgrad->size() < kBatchNormInputNum + 1) { | |||
| MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; | |||
| if (AnfAlgo::GetInputTensorNum(batchnormgrad) < kBNGradInputTensorNum) { | |||
| MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBNGradInputTensorNum; | |||
| return false; | |||
| } | |||
| if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) { | |||
| @@ -86,7 +86,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto tuple_getitem = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); | |||
| CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum); | |||
| AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_node); | |||
| if (!CheckIndex(index_node)) { | |||
| @@ -79,7 +79,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf | |||
| return nullptr; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(minimum); | |||
| if (minimum->inputs().size() != kMinimumInputNum) { | |||
| if (AnfAlgo::GetInputTensorNum(minimum) != kMinimumInputTensorNum) { | |||
| return nullptr; | |||
| } | |||
| @@ -30,9 +30,7 @@ const size_t kReluV2OutputNum = 2; | |||
| CNodePtr GetRelu(const CNodePtr &relu_grad) { | |||
| MS_EXCEPTION_IF_NULL(relu_grad); | |||
| if (relu_grad->size() != kReluGradInputNum) { | |||
| MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); | |||
| } | |||
| CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum); | |||
| auto relu_anf = relu_grad->input(2); | |||
| MS_EXCEPTION_IF_NULL(relu_anf); | |||
| return relu_anf->cast<CNodePtr>(); | |||
| @@ -41,9 +39,7 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) { | |||
| CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(relu); | |||
| if (relu->size() != kReluInputNum) { | |||
| MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); | |||
| } | |||
| CheckCNodeInputSize(relu, kReluInputTensorNum); | |||
| auto prim = std::make_shared<Primitive>(kReluV2OpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)}; | |||
| @@ -53,32 +53,9 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect | |||
| } | |||
| } // namespace | |||
| const BaseRef FusedBatchNormFusion::DefinePattern() const { | |||
| std::shared_ptr<Var> Xs = std::make_shared<SeqVar>(); | |||
| VarPtr index0 = std::make_shared<CondVar>(IsC); | |||
| VarPtr index1 = std::make_shared<CondVar>(IsC); | |||
| VarPtr index2 = std::make_shared<CondVar>(IsC); | |||
| VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); | |||
| VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); | |||
| VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); | |||
| VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); | |||
| VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); | |||
| VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); | |||
| VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); | |||
| VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); | |||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); | |||
| return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); | |||
| } | |||
| ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto iter_constant_input0 = (*equiv).find(constant_input0_var_); | |||
| if (iter_constant_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; | |||
| } | |||
| auto constant_input = utils::cast<AnfNodePtr>(iter_constant_input0->second); | |||
| auto constant_input = GetAnfNodeByVar(equiv, constant_input0_var_); | |||
| MS_EXCEPTION_IF_NULL(constant_input); | |||
| if (!constant_input->isa<ValueNode>()) { | |||
| return nullptr; | |||
| @@ -113,31 +90,15 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| // Set input to create node | |||
| auto iter_data_input0 = (*equiv).find(data_input0_var_); | |||
| if (iter_data_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched." | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| std::vector<AnfNodePtr> bn_training_reduce_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), | |||
| utils::cast<AnfNodePtr>(iter_data_input0->second)}; | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)}; | |||
| auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); | |||
| MS_EXCEPTION_IF_NULL(bn_training_reduce); | |||
| bn_training_reduce->set_scope(node->scope()); | |||
| // Set abstract | |||
| auto iter_data_input1 = (*equiv).find(data_input1_var_); | |||
| if (iter_data_input1 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched." | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second); | |||
| auto data_input1 = GetAnfNodeByVar(equiv, data_input1_var_); | |||
| MS_EXCEPTION_IF_NULL(data_input1); | |||
| auto iter_data_input2 = (*equiv).find(data_input2_var_); | |||
| if (iter_data_input2 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched." | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second); | |||
| auto data_input2 = GetAnfNodeByVar(equiv, data_input2_var_); | |||
| MS_EXCEPTION_IF_NULL(data_input2); | |||
| AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()}; | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| @@ -150,39 +111,15 @@ void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv, | |||
| std::vector<AnfNodePtr> *bn_training_update_inputs) const { | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| MS_EXCEPTION_IF_NULL(bn_training_update_inputs); | |||
| auto iter_data_input0 = (*equiv).find(data_input0_var_); | |||
| if (iter_data_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; | |||
| } | |||
| auto iter_data_input1 = (*equiv).find(data_input1_var_); | |||
| if (iter_data_input1 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; | |||
| } | |||
| auto iter_data_input2 = (*equiv).find(data_input2_var_); | |||
| if (iter_data_input2 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; | |||
| } | |||
| auto iter_variable_input0 = (*equiv).find(variable_input0_var_); | |||
| if (iter_variable_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; | |||
| } | |||
| auto iter_variable_input1 = (*equiv).find(variable_input1_var_); | |||
| if (iter_variable_input1 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; | |||
| } | |||
| if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum | |||
| << ", but it is " << bn_training_reduce_outputs.size(); | |||
| } | |||
| *bn_training_update_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)), | |||
| utils::cast<AnfNodePtr>(iter_data_input0->second), | |||
| utils::cast<AnfNodePtr>(GetAnfNodeByVar(equiv, data_input0_var_)), | |||
| bn_training_reduce_outputs[0], | |||
| bn_training_reduce_outputs[1], | |||
| utils::cast<AnfNodePtr>(iter_data_input1->second), | |||
| utils::cast<AnfNodePtr>(iter_data_input2->second), | |||
| utils::cast<AnfNodePtr>(iter_variable_input0->second), | |||
| utils::cast<AnfNodePtr>(iter_variable_input1->second), | |||
| GetAnfNodeByVar(equiv, data_input1_var_), | |||
| GetAnfNodeByVar(equiv, data_input2_var_), | |||
| GetAnfNodeByVar(equiv, variable_input0_var_), | |||
| GetAnfNodeByVar(equiv, variable_input1_var_), | |||
| }; | |||
| } | |||
| @@ -197,19 +134,9 @@ void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv | |||
| MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is " | |||
| << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn); | |||
| } | |||
| auto iter_variable_input0 = (*equiv).find(variable_input0_var_); | |||
| if (iter_variable_input0 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched." | |||
| << " trace: " << trace::DumpSourceLines(bn); | |||
| } | |||
| auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second); | |||
| auto variable_input0 = GetAnfNodeByVar(equiv, variable_input0_var_); | |||
| auto variable_input1 = GetAnfNodeByVar(equiv, variable_input1_var_); | |||
| MS_EXCEPTION_IF_NULL(variable_input0); | |||
| auto iter_variable_input1 = (*equiv).find(variable_input1_var_); | |||
| if (iter_variable_input1 == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched." | |||
| << " trace: " << trace::DumpSourceLines(bn); | |||
| } | |||
| auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second); | |||
| MS_EXCEPTION_IF_NULL(variable_input1); | |||
| *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(), | |||
| bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]}; | |||
| @@ -227,13 +154,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( | |||
| auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs); | |||
| MS_EXCEPTION_IF_NULL(bn_training_update); | |||
| // Set abstract | |||
| auto iter_batch_norm = (*equiv).find(batch_norm_var_); | |||
| if (iter_batch_norm == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched." | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second); | |||
| MS_EXCEPTION_IF_NULL(bn); | |||
| AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_); | |||
| AbstractBasePtrList abstract_list; | |||
| GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list); | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| @@ -249,6 +170,23 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( | |||
| return bn_training_update; | |||
| } | |||
| void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_); | |||
| MS_EXCEPTION_IF_NULL(assign_sub1); | |||
| for (const auto &node_index : manager->node_users()[assign_sub1]) { | |||
| const AnfNodePtr &output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { | |||
| (void)manager->Replace(output, GetAnfNodeByVar(equiv, monad0_var_)); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -271,14 +209,8 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c | |||
| << bn_training_update_outputs.size() << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| // Replace old bn outputs with new outputs | |||
| auto iter_batch_norm = (*equiv).find(batch_norm_var_); | |||
| if (iter_batch_norm == (*equiv).end()) { | |||
| MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched." | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second); | |||
| std::vector<AnfNodePtr> bn_outputs; | |||
| GetBNOutput(func_graph, bn, &bn_outputs); | |||
| GetBNOutput(func_graph, GetAnfNodeByVar(equiv, batch_norm_var_), &bn_outputs); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| for (const auto &output : bn_outputs) { | |||
| @@ -297,7 +229,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c | |||
| (void)manager->Replace(output, bn_training_update_outputs[index]); | |||
| } | |||
| } | |||
| return bn_training_update_outputs[0]; | |||
| (void)manager->Replace(node, bn_training_update_outputs[0]); | |||
| EliminateMonadNodes(func_graph, equiv); | |||
| return nullptr; | |||
| } | |||
| const BaseRef FusedBatchNormFusion::DefinePattern() const { | |||
| std::shared_ptr<Var> Xs = std::make_shared<SeqVar>(); | |||
| VarPtr index0 = std::make_shared<CondVar>(IsC); | |||
| VarPtr index1 = std::make_shared<CondVar>(IsC); | |||
| VarPtr index2 = std::make_shared<CondVar>(IsC); | |||
| VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); | |||
| VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); | |||
| VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); | |||
| VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); | |||
| VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); | |||
| VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); | |||
| VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_}); | |||
| VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_}); | |||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); | |||
| return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); | |||
| } | |||
| const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { | |||
| @@ -317,8 +270,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); | |||
| VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); | |||
| VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); | |||
| VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); | |||
| VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); | |||
| VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, cast2, monad0_var_}); | |||
| VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, cast3, monad1_var_}); | |||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); | |||
| return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); | |||
| } | |||
| @@ -340,8 +293,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const { | |||
| VectorRef cast1 = VectorRef({prim::kPrimCast, sub1}); | |||
| VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_}); | |||
| VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_}); | |||
| VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); | |||
| VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); | |||
| VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_}); | |||
| VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_}); | |||
| VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); | |||
| return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); | |||
| } | |||
| @@ -27,15 +27,20 @@ namespace opt { | |||
| class FusedBatchNormFusion : public PatternProcessPass { | |||
| public: | |||
| explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) | |||
| : PatternProcessPass(name, multigraph), | |||
| data_input0_var_(std::make_shared<Var>()), | |||
| data_input1_var_(std::make_shared<Var>()), | |||
| data_input2_var_(std::make_shared<Var>()), | |||
| variable_input0_var_(std::make_shared<Var>()), | |||
| variable_input1_var_(std::make_shared<Var>()), | |||
| constant_input0_var_(std::make_shared<Var>()), | |||
| constant_input1_var_(std::make_shared<Var>()), | |||
| batch_norm_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()))) {} | |||
| : PatternProcessPass(name, multigraph) { | |||
| data_input0_var_ = std::make_shared<Var>(); | |||
| data_input1_var_ = std::make_shared<Var>(); | |||
| data_input2_var_ = std::make_shared<Var>(); | |||
| variable_input0_var_ = std::make_shared<Var>(); | |||
| variable_input1_var_ = std::make_shared<Var>(); | |||
| constant_input0_var_ = std::make_shared<Var>(); | |||
| constant_input1_var_ = std::make_shared<Var>(); | |||
| batch_norm_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name())); | |||
| assign_sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name())); | |||
| assign_sub1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name())); | |||
| monad0_var_ = std::make_shared<Var>(); | |||
| monad1_var_ = std::make_shared<Var>(); | |||
| } | |||
| ~FusedBatchNormFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| @@ -50,6 +55,7 @@ class FusedBatchNormFusion : public PatternProcessPass { | |||
| AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, | |||
| const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const; | |||
| ValuePtr GetFactor(const EquivPtr &equiv) const; | |||
| void EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; | |||
| VarPtr data_input0_var_; | |||
| VarPtr data_input1_var_; | |||
| @@ -59,6 +65,10 @@ class FusedBatchNormFusion : public PatternProcessPass { | |||
| VarPtr constant_input0_var_; | |||
| VarPtr constant_input1_var_; | |||
| VarPtr batch_norm_var_; | |||
| VarPtr assign_sub0_var_; | |||
| VarPtr assign_sub1_var_; | |||
| VarPtr monad0_var_; | |||
| VarPtr monad1_var_; | |||
| }; | |||
| class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion { | |||
| @@ -30,33 +30,21 @@ std::tuple<AnfNodePtr, AnfNodePtr, AnfNodePtr, AnfNodePtr> GetSharedNodes(const | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto add3 = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(add3); | |||
| if (add3->inputs().size() < kAddInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| CheckCNodeInputSize(add3, kAddInputTensorNum); | |||
| auto real_div2_anf = add3->input(1); | |||
| MS_EXCEPTION_IF_NULL(real_div2_anf); | |||
| auto real_div2 = real_div2_anf->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_div2); | |||
| if (real_div2->inputs().size() < kRealDivInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| CheckCNodeInputSize(real_div2, kRealDivInputTensorNum); | |||
| auto sqrt0_anf = real_div2->input(2); | |||
| MS_EXCEPTION_IF_NULL(sqrt0_anf); | |||
| auto sqrt0 = sqrt0_anf->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sqrt0); | |||
| if (sqrt0->inputs().size() < kRsqrtInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| CheckCNodeInputSize(sqrt0, kSqrtInputTensorNum); | |||
| auto add2_anf = sqrt0->input(1); | |||
| MS_EXCEPTION_IF_NULL(add2_anf); | |||
| auto add2 = add2_anf->cast<CNodePtr>(); | |||
| if (add2->inputs().size() < kAddInputNum) { | |||
| MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| CheckCNodeInputSize(add2, kAddInputTensorNum); | |||
| return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2)); | |||
| } | |||
| @@ -66,7 +54,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN | |||
| return false; | |||
| } | |||
| auto add5 = node->cast<CNodePtr>(); | |||
| if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) { | |||
| if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add5) != kAddInputTensorNum) { | |||
| return false; | |||
| } | |||
| auto real_div4_anf = add5->input(1); | |||
| @@ -74,7 +62,8 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN | |||
| return false; | |||
| } | |||
| auto real_div4 = real_div4_anf->cast<CNodePtr>(); | |||
| if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) { | |||
| if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || | |||
| AnfAlgo::GetInputTensorNum(real_div4) != kRealDivInputTensorNum) { | |||
| return false; | |||
| } | |||
| auto add4_anf = real_div4->input(2); | |||
| @@ -82,7 +71,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN | |||
| return false; | |||
| } | |||
| auto add4 = add4_anf->cast<CNodePtr>(); | |||
| if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) { | |||
| if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add4) != kAddInputTensorNum) { | |||
| return false; | |||
| } | |||
| auto sqrt1_anf = add4->input(1); | |||
| @@ -90,7 +79,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN | |||
| return false; | |||
| } | |||
| auto sqrt1 = sqrt1_anf->cast<CNodePtr>(); | |||
| if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) { | |||
| if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || AnfAlgo::GetInputTensorNum(sqrt1) != kSqrtInputTensorNum) { | |||
| return false; | |||
| } | |||
| return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 && | |||
| @@ -104,14 +93,8 @@ std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Nodes(const AnfNodePtr &real_div0_ | |||
| auto real_div1 = real_div1_anf->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_div0); | |||
| MS_EXCEPTION_IF_NULL(real_div1); | |||
| if (real_div0->inputs().size() != kRealDivInputNum) { | |||
| MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(real_div0_anf); | |||
| } | |||
| if (real_div1->inputs().size() != kRealDivInputNum) { | |||
| MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(real_div1_anf); | |||
| } | |||
| CheckCNodeInputSize(real_div0, kRealDivInputTensorNum); | |||
| CheckCNodeInputSize(real_div1, kRealDivInputTensorNum); | |||
| return std::make_tuple(real_div0->input(1), real_div1->input(1)); | |||
| } | |||
| } // namespace | |||
| @@ -77,9 +77,9 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode | |||
| MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr"; | |||
| return false; | |||
| } | |||
| if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) { | |||
| if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormBetaGammaBackpropInputTensorNum) { | |||
| MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to " | |||
| << kLayerNormBetaGammaBackpropInputNum; | |||
| << kLayerNormBetaGammaBackpropInputTensorNum; | |||
| return false; | |||
| } | |||
| if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) { | |||
| @@ -87,7 +87,8 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode | |||
| << kLayerNormBetaGammaBackpropOutputNum; | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) { | |||
| MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16"; | |||
| return false; | |||
| @@ -148,15 +149,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f | |||
| // The cast_nodes size has been checked above. | |||
| MS_EXCEPTION_IF_NULL(cast_nodes[0]); | |||
| MS_EXCEPTION_IF_NULL(cast_nodes[1]); | |||
| if (cast_nodes[0]->inputs().size() != kCastInputNum) { | |||
| MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| CheckCNodeInputSize(cast_nodes[0], kCastInputTensorNum); | |||
| CheckCNodeInputSize(cast_nodes[1], kCastInputTensorNum); | |||
| (void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1)); | |||
| if (cast_nodes[1]->inputs().size() != kCastInputNum) { | |||
| MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| (void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1)); | |||
| return nullptr; | |||
| } | |||
| @@ -31,6 +31,20 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto matmul = GetAnfNodeByVar(equiv, matmul_var_); | |||
| if (matmul == nullptr || !matmul->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Get CNode MatMul failed!" | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| // If there is a side-effect operator in the fusion, do not merge | |||
| MonadState state_matmul = GetMonadState(matmul); | |||
| MonadState state_node = GetMonadState(node, matmul); | |||
| if (!IsStateEquivalent(state_matmul, state_node)) { | |||
| return node; | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))); | |||
| inputs.emplace_back(GetAnfNodeByVar(equiv, x0_)); | |||
| @@ -41,11 +55,6 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A | |||
| new_node->set_scope(node->scope()); | |||
| new_node->set_abstract(node->abstract()); | |||
| auto matmul = GetAnfNodeByVar(equiv, matmul_var_); | |||
| if (matmul == nullptr || !matmul->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Get CNode MatMul failed!" | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| AnfAlgo::CopyNodeAttrs(matmul, new_node); | |||
| return new_node; | |||
| } | |||
| @@ -43,7 +43,9 @@ const BaseRef MomentumLossscaleFusion::DefinePattern() const { | |||
| VarPtr X1 = std::make_shared<Var>(); | |||
| VarPtr X2 = std::make_shared<Var>(); | |||
| VarPtr X4 = std::make_shared<Var>(); | |||
| return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4}); | |||
| // UpdateState node | |||
| VarPtr X5 = std::make_shared<Var>(); | |||
| return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4, X5}); | |||
| } | |||
| const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| @@ -52,14 +54,15 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| CheckCNodeInputSize(cnode, kApplyMomentumInputNum); | |||
| CheckCNodeInputSize(cnode, kApplyMomentumInputTensorNum); | |||
| AnfNodePtr mul = cnode->input(4); | |||
| MS_EXCEPTION_IF_NULL(mul); | |||
| auto mul_cnode = mul->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(mul_cnode); | |||
| CheckCNodeInputSize(mul_cnode, kMulInputNum); | |||
| CheckCNodeInputSize(mul_cnode, kMulInputTensorNum); | |||
| size_t value_node_index = 0; | |||
| for (size_t i = 1; i < kMulInputNum; ++i) { | |||
| // All real inputs include 1prim + x*TensorInput | |||
| for (size_t i = 1; i < kMulInputTensorNum + 1; ++i) { | |||
| if (CheckValueNodeInputOfMul(mul_cnode->input(i))) { | |||
| value_node_index = i; | |||
| break; | |||
| @@ -70,12 +73,16 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph | |||
| return nullptr; | |||
| } | |||
| auto new_prim = std::make_shared<Primitive>(kFusedMulApplyMomentumOpName); | |||
| auto depend_prim = NewValueNode(prim::kPrimDepend); | |||
| auto depend = func_graph->NewCNode({depend_prim, cnode->input(5), cnode->input(6)}); // depend on monad | |||
| depend->set_abstract(cnode->input(5)->abstract()); | |||
| depend->set_scope(cnode->input(5)->scope()); | |||
| std::vector<AnfNodePtr> new_node_inputs{NewValueNode(new_prim), | |||
| cnode->input(1), | |||
| cnode->input(2), | |||
| cnode->input(3), | |||
| mul_cnode->input(kMulInputNum - value_node_index), | |||
| cnode->input(5), | |||
| mul_cnode->input(kMulInputTensorNum + 1 - value_node_index), | |||
| depend, | |||
| mul_cnode->input(value_node_index)}; | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| @@ -67,7 +67,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP | |||
| return nullptr; | |||
| } | |||
| auto add = node->cast<CNodePtr>(); | |||
| if (add == nullptr || add->inputs().size() != kAddInputNum) { | |||
| if (add == nullptr || AnfAlgo::GetInputTensorNum(add) != kAddInputTensorNum) { | |||
| return nullptr; | |||
| } | |||
| CNodePtr mul = nullptr; | |||
| @@ -31,7 +31,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const | |||
| MS_EXCEPTION_IF_NULL(addn); | |||
| auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; | |||
| inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); | |||
| inputs.push_back(mul->input(kMulInputTensorNum + 1 - lossscale_input_index)); | |||
| inputs.push_back(addn->input(2)); | |||
| // scalar input should be 3rd input | |||
| inputs.push_back(mul->input(lossscale_input_index)); | |||
| @@ -60,7 +60,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode | |||
| } | |||
| auto addn = node->cast<CNodePtr>(); | |||
| if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { | |||
| if (addn == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto mul_anf = addn->input(1); | |||
| @@ -68,7 +68,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode | |||
| return nullptr; | |||
| } | |||
| auto mul = mul_anf->cast<CNodePtr>(); | |||
| if (mul == nullptr || mul->inputs().size() != kMulInputNum) { | |||
| if (mul == nullptr || AnfAlgo::GetInputTensorNum(mul) != kMulInputTensorNum) { | |||
| return nullptr; | |||
| } | |||
| if (IsUsedByOthers(graph, mul)) { | |||
| @@ -98,7 +98,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(DEBUG) << "Skip trans op"; | |||
| continue; | |||
| } | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t input_index = 0; input_index < input_num; input_index++) { | |||
| std::vector<CNodePtr> trans_road; | |||
| bool first_flag = true; | |||
| auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road); | |||
| @@ -26,8 +26,10 @@ void DoRefresh(const CNodePtr &cnode) { | |||
| if (cnode == nullptr) { | |||
| MS_LOG(EXCEPTION) << "node is nullptr"; | |||
| } | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { | |||
| auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t input_index = 0; input_index < input_num; input_index++) { | |||
| auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first; | |||
| MS_EXCEPTION_IF_NULL(input_kernel_node); | |||
| if (input_kernel_node->isa<Parameter>()) { | |||
| std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| @@ -34,13 +34,14 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); | |||
| auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(out_reshape); | |||
| // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly | |||
| if (IsUsedByOthers(func_graph, out_reshape)) { | |||
| return nullptr; | |||
| } | |||
| auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum); | |||
| auto in_reshape = | |||
| CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(in_reshape); | |||
| if (IsUsedByOthers(func_graph, in_reshape)) { | |||
| return nullptr; | |||
| @@ -46,9 +46,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(transpose_cnode); | |||
| auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); | |||
| auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(reshape_cnode); | |||
| if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) { | |||
| return nullptr; | |||
| @@ -33,10 +33,7 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(square); | |||
| MS_EXCEPTION_IF_NULL(sum); | |||
| if (square->inputs().size() != kSquareNodeInputNum) { | |||
| MS_LOG(EXCEPTION) << "Square node has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(square); | |||
| } | |||
| CheckCNodeInputSize(square, kSquareNodeInputTensorNum); | |||
| auto prim = std::make_shared<Primitive>(kSquareSumV1OpName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> square_sumv1_inputs = {NewValueNode(prim), square->input(1)}; | |||
| @@ -60,10 +57,7 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(square); | |||
| MS_EXCEPTION_IF_NULL(sum); | |||
| if (square->inputs().size() != kSquareNodeInputNum) { | |||
| MS_LOG(EXCEPTION) << "Square node has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(square); | |||
| } | |||
| CheckCNodeInputSize(square, kSquareNodeInputTensorNum); | |||
| auto prim = std::make_shared<Primitive>(kSquareSumV2OpName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> square_sumv2_inputs = {NewValueNode(prim), square->input(1)}; | |||
| @@ -84,10 +78,7 @@ std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto sum = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sum); | |||
| if (sum->inputs().size() != kSumNodeInputNum) { | |||
| MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size" | |||
| << " trace: " << trace::DumpSourceLines(sum); | |||
| } | |||
| CheckCNodeInputSize(sum, kSumNodeInputTensorNum); | |||
| auto square_anf = sum->input(1); | |||
| MS_EXCEPTION_IF_NULL(square_anf); | |||
| auto square = square_anf->cast<CNodePtr>(); | |||
| @@ -46,9 +46,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); | |||
| auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(reshape_cnode); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(transpose_cnode); | |||
| if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) { | |||
| return nullptr; | |||
| @@ -33,9 +33,9 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum); | |||
| auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(transdata_cnode); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kTransOpInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(transpose_cnode); | |||
| auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode); | |||
| auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode); | |||
| @@ -136,10 +136,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons | |||
| CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(conv2d); | |||
| if (conv2d->inputs().size() != kConvInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " | |||
| << conv2d->inputs().size() - 1; | |||
| } | |||
| CheckCNodeInputSize(conv2d, kConvInputTensorNum); | |||
| std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)), | |||
| conv2d->input(1), transpose}; | |||
| auto depth_conv = graph->NewCNode(depth_conv_inputs); | |||
| @@ -270,11 +267,7 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf | |||
| if (!NeedUpdate(conv2d, input_shape, output_shape)) { | |||
| return nullptr; | |||
| } | |||
| if (conv2d->inputs().size() != kConvInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " | |||
| << conv2d->inputs().size() - 1; | |||
| } | |||
| CheckCNodeInputSize(conv2d, kConvInputTensorNum); | |||
| auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true); | |||
| auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose); | |||
| SetConv2DAttrs(conv2d, depth_conv); | |||
| @@ -70,7 +70,8 @@ const BaseRef FtrlUnifyOutput::DefinePattern() const { | |||
| VarPtr l1 = std::make_shared<Var>(); | |||
| VarPtr l2 = std::make_shared<Var>(); | |||
| VarPtr lr_power = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power}); | |||
| VarPtr u = std::make_shared<SeqVar>(); | |||
| VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power, u}); | |||
| return pattern; | |||
| } | |||
| @@ -84,7 +85,8 @@ const BaseRef MomentumUnifyOutput::DefinePattern() const { | |||
| VarPtr lr = std::make_shared<Var>(); | |||
| VarPtr grad = std::make_shared<Var>(); | |||
| VarPtr momentum = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum}); | |||
| VarPtr u = std::make_shared<SeqVar>(); | |||
| VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum, u}); | |||
| return pattern; | |||
| } | |||
| @@ -114,7 +116,8 @@ const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const { | |||
| VarPtr rho = std::make_shared<Var>(); | |||
| VarPtr momentum = std::make_shared<Var>(); | |||
| VarPtr epsilon = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon}); | |||
| VarPtr u = std::make_shared<SeqVar>(); | |||
| VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon, u}); | |||
| return pattern; | |||
| } | |||
| @@ -109,12 +109,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node); | |||
| MS_EXCEPTION_IF_NULL(one_hot_node); | |||
| if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)), | |||
| sparse_softmax_node->input(1), one_hot_node}; | |||
| auto softmax_node = graph->NewCNode(inputs); | |||
| @@ -162,10 +157,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node); | |||
| MS_EXCEPTION_IF_NULL(softmax_output_node); | |||
| if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| auto axis_value = GetAxis(softmax_output_node); | |||
| auto axis_node = GetAxisNode(softmax_output_node); | |||
| @@ -200,9 +192,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft | |||
| CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(real_div_node); | |||
| if (real_div_node->size() != kRealDivInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum; | |||
| } | |||
| CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum); | |||
| int64_t axis = -1; | |||
| auto axis_node = NewValueNode(axis); | |||
| @@ -230,9 +220,8 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no | |||
| CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(real_div_node); | |||
| if (real_div_node->size() != kRealDivInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum; | |||
| } | |||
| CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum); | |||
| int64_t axis = -1; | |||
| auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName); | |||
| std::vector<std::string> input_names = {"x"}; | |||
| @@ -257,13 +246,8 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node); | |||
| MS_EXCEPTION_IF_NULL(mul_node); | |||
| if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| if (mul_node->size() != kMulInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| CheckCNodeInputSize(mul_node, kMulInputTensorNum); | |||
| auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1); | |||
| std::vector<int64_t> multiple_value; | |||
| @@ -310,12 +294,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node); | |||
| MS_EXCEPTION_IF_NULL(tile_node); | |||
| if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1); | |||
| if (labels_shape.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "label's shape should be 1-D."; | |||
| @@ -343,9 +322,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax | |||
| CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| if (depend_node->size() != kDependInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op Depend's input not equal " << kDependInputNum; | |||
| } | |||
| CheckCNodeInputSize(depend_node, kDependInputTensorNum); | |||
| auto sparse_node = depend_node->input(index); | |||
| MS_EXCEPTION_IF_NULL(sparse_node); | |||
| return sparse_node->cast<CNodePtr>(); | |||
| @@ -353,9 +330,7 @@ CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) { | |||
| CNodePtr GetDependNode(const CNodePtr &mul_node) { | |||
| MS_EXCEPTION_IF_NULL(mul_node); | |||
| if (mul_node->size() != kMulInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; | |||
| } | |||
| CheckCNodeInputSize(mul_node, kMulInputTensorNum); | |||
| auto depend_node = mul_node->input(1); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| return depend_node->cast<CNodePtr>(); | |||
| @@ -413,10 +388,7 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F | |||
| auto sparse_softmax_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node); | |||
| if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) && | |||
| AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) { | |||
| return nullptr; | |||
| @@ -451,17 +423,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con | |||
| auto mul_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(mul_node); | |||
| if (mul_node->size() != kMulInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; | |||
| } | |||
| CheckCNodeInputSize(mul_node, kMulInputTensorNum); | |||
| auto depend_node = GetDependNode(mul_node); | |||
| auto sparse_softmax_node = GetSparseNode(depend_node, 2); | |||
| auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1); | |||
| if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| CNodePtr softmax_node; | |||
| auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); | |||
| @@ -538,10 +505,8 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process | |||
| auto sparse_softmax_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node); | |||
| if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) && | |||
| AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) { | |||
| return nullptr; | |||
| @@ -573,17 +538,12 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro | |||
| auto mul_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(mul_node); | |||
| if (mul_node->size() != kMulInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; | |||
| } | |||
| CheckCNodeInputSize(mul_node, kMulInputTensorNum); | |||
| auto sparse_softmax_node = mul_node->input(1); | |||
| auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad); | |||
| if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { | |||
| MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " | |||
| << kSparseSoftmaxCrossEntropyWithLogitsInputNum; | |||
| } | |||
| CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); | |||
| CNodePtr softmax_node; | |||
| auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); | |||
| @@ -124,18 +124,16 @@ CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_si | |||
| MS_LOG(EXCEPTION) << "The node is expected to be a cnode"; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() != input_size) { | |||
| auto op_name = AnfAlgo::GetCNodeName(cnode); | |||
| MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs."; | |||
| } | |||
| CheckCNodeInputSize(cnode, input_size); | |||
| return cnode; | |||
| } | |||
| void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) { | |||
| void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() != input_size) { | |||
| MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size; | |||
| auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| if (real_input_tensor_num != input_tensor_size) { | |||
| MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num | |||
| << "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size; | |||
| } | |||
| } | |||
| @@ -149,17 +147,15 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y | |||
| const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); | |||
| auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum); | |||
| MS_EXCEPTION_IF_NULL(transop_cnode); | |||
| auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); | |||
| auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); | |||
| MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); | |||
| MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1)); | |||
| auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1); | |||
| auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum); | |||
| auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum); | |||
| auto transed_node = prev_transop_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(transed_node); | |||
| std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node, | |||
| depend_cnode->input(kDependInputNum - 1)}; | |||
| depend_cnode->input(kDependAttachNodeIndex)}; | |||
| AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs); | |||
| MS_EXCEPTION_IF_NULL(replace_depend); | |||
| auto transed_abstract = transed_node->abstract(); | |||
| @@ -422,13 +418,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con | |||
| } | |||
| auto output_info_list = iter->second; | |||
| for (const auto &output_info : output_info_list) { | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && | |||
| output_info.second == kDependAttachNodeIndex) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) { | |||
| continue; | |||
| } | |||
| output_node_list->push_back(output_info); | |||
| } | |||
| return output_node_list; | |||
| @@ -537,6 +533,9 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i | |||
| bool need_update = false; | |||
| for (size_t i = 0; i < inputs.size() - 1; ++i) { | |||
| auto input_node = inputs[i + 1]; | |||
| if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) { | |||
| input_node = AnfAlgo::VisitKernel(input_node, 0).first; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) { | |||
| auto value_node = input_node->cast<ValueNodePtr>(); | |||
| @@ -548,7 +547,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i | |||
| primitive->set_attr(input_names_vec[i], value_node->value()); | |||
| need_update = true; | |||
| } else { | |||
| new_inputs.push_back(input_node); | |||
| new_inputs.push_back(inputs[i + 1]); | |||
| } | |||
| } | |||
| if (need_update) { | |||
| @@ -785,7 +784,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); | |||
| // set value node initial device data type = infer data type | |||
| std::vector<TypeId> types; | |||
| for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(value_node); | |||
| for (size_t index = 0; index < output_num; ++index) { | |||
| types.push_back(kTypeUnknown); | |||
| } | |||
| kernel_build_info_builder->SetOutputsDeviceType(types); | |||
| @@ -29,36 +29,34 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr size_t kTransOpInputNum = 2; | |||
| constexpr size_t kCastInputNum = 2; | |||
| constexpr size_t kDependInputNum = 3; | |||
| constexpr size_t kReluInputNum = 2; | |||
| constexpr size_t kReluGradInputNum = 3; | |||
| constexpr size_t kAddInputNum = 3; | |||
| constexpr size_t kAddNInputNum = 3; | |||
| constexpr size_t kTupleGetitemInputNum = 3; | |||
| constexpr size_t kConvInputNum = 3; | |||
| constexpr size_t kRealDivInputNum = 3; | |||
| constexpr size_t kSqrtInputNum = 2; | |||
| constexpr size_t kMulInputNum = 3; | |||
| constexpr size_t kRsqrtInputNum = 2; | |||
| constexpr size_t kSubInputNum = 3; | |||
| constexpr size_t kAssignSubInputNum = 3; | |||
| constexpr size_t kDropoutInputNum = 2; | |||
| constexpr size_t kTransOpInputTensorNum = 1; | |||
| constexpr size_t kCastInputTensorNum = 1; | |||
| constexpr size_t kDependInputTensorNum = 2; | |||
| constexpr size_t kReluInputTensorNum = 1; | |||
| constexpr size_t kReluGradInputTensorNum = 2; | |||
| constexpr size_t kAddInputTensorNum = 2; | |||
| constexpr size_t kTupleGetItemInputTensorNum = 2; | |||
| constexpr size_t kConvInputTensorNum = 2; | |||
| constexpr size_t kRealDivInputTensorNum = 2; | |||
| constexpr size_t kSqrtInputTensorNum = 1; | |||
| constexpr size_t kMatMulInputTensorNum = 2; | |||
| constexpr size_t kMulInputTensorNum = 2; | |||
| constexpr size_t kSubInputTensorNum = 2; | |||
| constexpr size_t kAssignSubInputTensorNum = 2; | |||
| constexpr size_t kDropoutInputTensorNum = 1; | |||
| constexpr size_t kAssignInputTensorNum = 2; | |||
| constexpr size_t kConvBn1OutputNum = 3; | |||
| constexpr size_t kBn2ReluOutputNum = 4; | |||
| constexpr size_t kBnInputNum = 6; | |||
| constexpr size_t kBnInputTensorNum = 5; | |||
| constexpr size_t kBnOutputNum = 5; | |||
| constexpr size_t kBatchNormInputNum = 5; | |||
| constexpr size_t kBatchNormOutputNum = 5; | |||
| constexpr size_t kBN1OutputNum = 2; | |||
| constexpr size_t kBN2OutputNum = 3; | |||
| constexpr size_t kBN3OutputNum = 1; | |||
| constexpr size_t kBNGradInputNum = 6; | |||
| constexpr size_t kBNGradInputTensorNum = 5; | |||
| constexpr size_t kBNGradOutputNum = 3; | |||
| constexpr size_t kBNGrad1OutputNum = 3; | |||
| @@ -72,10 +70,10 @@ constexpr size_t kBNTrainingUpdateV3OutputNum = 5; | |||
| constexpr size_t kBNTrainingUpdateGradOutputNum = 2; | |||
| constexpr size_t kSingleOutputNum = 1; | |||
| constexpr size_t kSumNodeInputNum = 2; | |||
| constexpr size_t kSquareNodeInputNum = 2; | |||
| constexpr size_t kSumNodeInputTensorNum = 1; | |||
| constexpr size_t kSquareNodeInputTensorNum = 1; | |||
| constexpr size_t kSquareSumv2OutputNum = 2; | |||
| constexpr size_t kMinimumInputNum = 3; | |||
| constexpr size_t kMinimumInputTensorNum = 2; | |||
| constexpr size_t kLambNextMVWithDecayInputNum = 7; | |||
| constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5; | |||
| @@ -85,26 +83,25 @@ constexpr size_t kLambNextRightOutputNum = 2; | |||
| constexpr size_t kLambUpdateWithLrV2InputNum = 8; | |||
| constexpr size_t kLambNextMVRuleInputNum = 14; | |||
| constexpr size_t kLambNextMVRuleOutputNum = 4; | |||
| constexpr size_t kBackendReshapeInputNum = 2; | |||
| constexpr size_t kBackendTransposeInputNum = 2; | |||
| constexpr size_t kBackendReshapeInputTensorNum = 1; | |||
| constexpr size_t kBackendTransposeInputTensorNum = 1; | |||
| constexpr size_t kAdamApplyOneWithDecayOutputNum = 3; | |||
| constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5; | |||
| constexpr size_t kLayerNormBetaGammaBackpropInputTensorNum = 4; | |||
| constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2; | |||
| constexpr size_t kLayerNormGradInputNum = 6; | |||
| constexpr size_t kLayerNormGradInputTensorNum = 5; | |||
| constexpr size_t kAdamApplyOneOutputNum = 3; | |||
| constexpr size_t kBackendTransDataInputNum = 2; | |||
| constexpr size_t kApplyMomentumInputNum = 6; | |||
| constexpr size_t kBiasAddInputNum = 3; | |||
| constexpr size_t kTopkInputNum = 3; | |||
| constexpr size_t kLarsV2InputNum = 5; | |||
| constexpr size_t kApplyMomentumInputTensorNum = 5; | |||
| constexpr size_t kBiasAddInputTensorNum = 2; | |||
| constexpr size_t kTopkInputTensorNum = 2; | |||
| constexpr size_t kLarsV2InputTensorNum = 4; | |||
| constexpr size_t kFusedMulApplyMomentumOutputNum = 2; | |||
| constexpr size_t kSplitInputNum = 2; | |||
| constexpr size_t kGatherV2DynInputNum = 3; | |||
| constexpr size_t kUnsortedSegmentSumInputNum = 2; | |||
| constexpr size_t kSplitInputTensorNum = 1; | |||
| constexpr size_t kGatherV2DynInputTensorNum = 3; | |||
| constexpr size_t kUnsortedSegmentSumInputTensorNum = 2; | |||
| constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2; | |||
| constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3; | |||
| constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum = 2; | |||
| constexpr size_t kOneHotOutputNum = 1; | |||
| constexpr size_t kOneHotInputNum = 5; | |||
| constexpr size_t kOneHotInputTensorNum = 4; | |||
| enum FusedBatchNormInput { | |||
| kX = 1, | |||
| @@ -137,7 +134,7 @@ bool Visited(const BaseRef &n); | |||
| // check if the input node is CNode, then check it's input_size, return CNodePtr if check success. | |||
| CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size); | |||
| void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size); | |||
| void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_num); | |||
| bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y); | |||
| @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| std::vector<TypeId> outputs_type; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| @@ -51,19 +53,30 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| } // namespace | |||
| const BaseRef AdamFusion::DefinePattern() const { | |||
| VectorRef next_m = VectorRef( | |||
| {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||
| VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); | |||
| VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), | |||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||
| VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); | |||
| VectorRef next_v = | |||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), | |||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | |||
| VectorRef update = | |||
| VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); | |||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); | |||
| VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); | |||
| VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); | |||
| VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); | |||
| next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_m}); | |||
| VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_v}); | |||
| return next_param; | |||
| } | |||
| @@ -81,6 +94,7 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr | |||
| auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]); | |||
| auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]); | |||
| auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]); | |||
| MS_EXCEPTION_IF_NULL(beta1_input); | |||
| MS_EXCEPTION_IF_NULL(one_sub_beta1_input); | |||
| MS_EXCEPTION_IF_NULL(beta2_input); | |||
| @@ -91,13 +105,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr | |||
| MS_EXCEPTION_IF_NULL(m_input); | |||
| MS_EXCEPTION_IF_NULL(v_input); | |||
| MS_EXCEPTION_IF_NULL(gradient_input); | |||
| MS_EXCEPTION_IF_NULL(u_input); | |||
| // Use depend(param, u) to maintain the execution order of FusedAdam and the previous operators. | |||
| auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name()); | |||
| MS_EXCEPTION_IF_NULL(prim_depend); | |||
| std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input}; | |||
| auto param = graph->NewCNode(param_inputs); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| param->set_abstract(param_input->abstract()); | |||
| // Fused into a FusedAdam operator. | |||
| auto prim = std::make_shared<Primitive>(kFusedAdamName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = { | |||
| NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, | |||
| eps_input, lr_input, param_input, m_input, v_input, | |||
| gradient_input}; | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), | |||
| beta1_input, | |||
| one_sub_beta1_input, | |||
| beta2_input, | |||
| one_sub_beta2_input, | |||
| eps_input, | |||
| lr_input, | |||
| param, | |||
| m_input, | |||
| v_input, | |||
| gradient_input}; | |||
| auto adam = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(adam); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| @@ -107,6 +138,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr | |||
| auto build_info = GenerateKernelBuildInfo(adam); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); | |||
| // Replace the parameters of the last UpdateState to maintain | |||
| // the execution order of FusedAdam and the following operators. | |||
| // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} | |||
| auto n = node->cast<CNodePtr>()->input(2); | |||
| auto fg = n->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mgr = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mgr); | |||
| auto &node_users = mgr->node_users(); | |||
| auto iter = node_users.find(n); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||
| (user.first)->cast<CNodePtr>()->set_input(1, u_input); | |||
| (user.first)->cast<CNodePtr>()->set_input(2, adam); | |||
| break; | |||
| } | |||
| } | |||
| return adam; | |||
| } | |||
| } // namespace opt | |||
| @@ -34,6 +34,8 @@ class AdamFusion : public PatternProcessPass { | |||
| m_ = std::make_shared<Var>(); | |||
| v_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| u_ = std::make_shared<Var>(); | |||
| u2_ = std::make_shared<Var>(); | |||
| } | |||
| ~AdamFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| @@ -50,6 +52,8 @@ class AdamFusion : public PatternProcessPass { | |||
| VarPtr m_; | |||
| VarPtr v_; | |||
| VarPtr gradient_; | |||
| VarPtr u_; | |||
| VarPtr u2_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| std::vector<TypeId> outputs_type; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| @@ -51,11 +53,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| } // namespace | |||
| const BaseRef AdamWeightDecayFusion::DefinePattern() const { | |||
| VectorRef next_m = VectorRef( | |||
| {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||
| VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); | |||
| VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), | |||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | |||
| VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); | |||
| VectorRef next_v = | |||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}), | |||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), | |||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | |||
| VectorRef update = | |||
| VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | |||
| VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); | |||
| @@ -63,9 +68,16 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const { | |||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); | |||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); | |||
| VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); | |||
| VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); | |||
| VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); | |||
| next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_m}); | |||
| VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state}); | |||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_v}); | |||
| return next_param; | |||
| } | |||
| @@ -85,6 +97,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const | |||
| auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]); | |||
| auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]); | |||
| auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]); | |||
| MS_EXCEPTION_IF_NULL(beta1_input); | |||
| MS_EXCEPTION_IF_NULL(one_sub_beta1_input); | |||
| MS_EXCEPTION_IF_NULL(beta2_input); | |||
| @@ -96,13 +109,31 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const | |||
| MS_EXCEPTION_IF_NULL(m_input); | |||
| MS_EXCEPTION_IF_NULL(v_input); | |||
| MS_EXCEPTION_IF_NULL(gradient_input); | |||
| MS_EXCEPTION_IF_NULL(u_input); | |||
| // Use depend(param, u) to maintain the execution order of FusedAdamWeightDecay and the previous operators. | |||
| auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name()); | |||
| MS_EXCEPTION_IF_NULL(prim_depend); | |||
| std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input}; | |||
| auto param = graph->NewCNode(param_inputs); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| param->set_abstract(param_input->abstract()); | |||
| // Fused into a FusedAdamWeightDecay operator. | |||
| auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = { | |||
| NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, | |||
| eps_input, lr_input, param_input, m_input, v_input, | |||
| gradient_input, weight_decay_input}; | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), | |||
| beta1_input, | |||
| one_sub_beta1_input, | |||
| beta2_input, | |||
| one_sub_beta2_input, | |||
| eps_input, | |||
| lr_input, | |||
| param, | |||
| m_input, | |||
| v_input, | |||
| gradient_input, | |||
| weight_decay_input}; | |||
| auto adam_weight_decay = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(adam_weight_decay); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| @@ -112,6 +143,30 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const | |||
| auto build_info = GenerateKernelBuildInfo(adam_weight_decay); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); | |||
| // Replace the parameters of the last UpdateState to maintain | |||
| // the execution order of FusedAdamWeightDecay and the following operators. | |||
| // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} | |||
| auto n = node->cast<CNodePtr>()->input(2); | |||
| auto fg = n->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mgr = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mgr); | |||
| auto &node_users = mgr->node_users(); | |||
| auto iter = node_users.find(n); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||
| (user.first)->cast<CNodePtr>()->set_input(1, u_input); | |||
| (user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay); | |||
| break; | |||
| } | |||
| } | |||
| return adam_weight_decay; | |||
| } | |||
| } // namespace opt | |||
| @@ -35,6 +35,8 @@ class AdamWeightDecayFusion : public PatternProcessPass { | |||
| m_ = std::make_shared<Var>(); | |||
| v_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| u_ = std::make_shared<Var>(); | |||
| u2_ = std::make_shared<Var>(); | |||
| } | |||
| ~AdamWeightDecayFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| @@ -52,6 +54,8 @@ class AdamWeightDecayFusion : public PatternProcessPass { | |||
| VarPtr m_; | |||
| VarPtr v_; | |||
| VarPtr gradient_; | |||
| VarPtr u_; | |||
| VarPtr u2_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| std::vector<TypeId> outputs_type; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| std::vector<TypeId> outputs_type; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| @@ -78,7 +80,8 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo | |||
| std::vector<TypeId> types; | |||
| std::vector<std::vector<size_t>> shapes; | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| types.push_back(AnfAlgo::GetOutputInferDataType(node, i)); | |||
| shapes.push_back(AnfAlgo::GetOutputInferShape(node, i)); | |||
| } | |||
| @@ -51,7 +51,7 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) { | |||
| const BaseRef ApplyMomentumScaleFusion::DefinePattern() const { | |||
| VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_}); | |||
| VectorRef apply_momentum = | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_}); | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_}); | |||
| return apply_momentum; | |||
| } | |||
| @@ -66,17 +66,19 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co | |||
| auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]); | |||
| auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]); | |||
| auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| MS_EXCEPTION_IF_NULL(variable); | |||
| MS_EXCEPTION_IF_NULL(accumulation); | |||
| MS_EXCEPTION_IF_NULL(learning_rate); | |||
| MS_EXCEPTION_IF_NULL(gradient); | |||
| MS_EXCEPTION_IF_NULL(momentum); | |||
| MS_EXCEPTION_IF_NULL(monad_state); | |||
| auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale, variable, accumulation, | |||
| learning_rate, gradient, momentum}; | |||
| learning_rate, gradient, momentum, monad_state}; | |||
| auto replace_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| @@ -31,6 +31,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass { | |||
| learning_rate_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| momentum_ = std::make_shared<Var>(); | |||
| monad_state_ = std::make_shared<Var>(); | |||
| } | |||
| ~ApplyMomentumScaleFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| @@ -45,6 +46,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass { | |||
| VarPtr learning_rate_; | |||
| VarPtr gradient_; | |||
| VarPtr momentum_; | |||
| VarPtr monad_state_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -49,10 +49,11 @@ bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) { | |||
| } | |||
| const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const { | |||
| VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_}); | |||
| VectorRef weight_decay = | |||
| VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_}); | |||
| VectorRef apply_momentum = | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_}); | |||
| VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), gradient_}); | |||
| VectorRef apply_momentum = VectorRef( | |||
| {prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_, monad_state_}); | |||
| return apply_momentum; | |||
| } | |||
| @@ -67,17 +68,19 @@ const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &gra | |||
| auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]); | |||
| auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]); | |||
| auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]); | |||
| MS_EXCEPTION_IF_NULL(weight_decay); | |||
| MS_EXCEPTION_IF_NULL(variable); | |||
| MS_EXCEPTION_IF_NULL(accumulation); | |||
| MS_EXCEPTION_IF_NULL(learning_rate); | |||
| MS_EXCEPTION_IF_NULL(gradient); | |||
| MS_EXCEPTION_IF_NULL(momentum); | |||
| MS_EXCEPTION_IF_NULL(monad_state); | |||
| auto prim = std::make_shared<Primitive>(kFusedWeightApplyMomentum); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, variable, accumulation, | |||
| learning_rate, gradient, momentum}; | |||
| learning_rate, gradient, momentum, monad_state}; | |||
| auto replace_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| @@ -25,12 +25,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ApplyMomentumWeightDecayFusion(bool multigraph = true) | |||
| : PatternProcessPass("momentum_weightdecay_fusion", multigraph) { | |||
| monad_ = std::make_shared<Var>(); | |||
| weight_decay_ = std::make_shared<Var>(); | |||
| variable_ = std::make_shared<Var>(); | |||
| accumulation_ = std::make_shared<Var>(); | |||
| learning_rate_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| momentum_ = std::make_shared<Var>(); | |||
| monad_state_ = std::make_shared<Var>(); | |||
| } | |||
| ~ApplyMomentumWeightDecayFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| @@ -39,12 +41,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass { | |||
| private: | |||
| static bool IsScalar(const BaseRef &n); | |||
| VarPtr monad_; | |||
| VarPtr weight_decay_; | |||
| VarPtr variable_; | |||
| VarPtr accumulation_; | |||
| VarPtr learning_rate_; | |||
| VarPtr gradient_; | |||
| VarPtr momentum_; | |||
| VarPtr monad_state_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -49,11 +49,12 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) { | |||
| } | |||
| const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const { | |||
| VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_}); | |||
| VectorRef weight = VectorRef( | |||
| {prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})}); | |||
| {prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})}); | |||
| VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_}); | |||
| VectorRef apply_momentum = | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_}); | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_}); | |||
| return apply_momentum; | |||
| } | |||
| @@ -69,6 +70,8 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr | |||
| auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]); | |||
| auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]); | |||
| auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]); | |||
| MS_EXCEPTION_IF_NULL(weight_decay); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| MS_EXCEPTION_IF_NULL(variable); | |||
| @@ -76,11 +79,12 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr | |||
| MS_EXCEPTION_IF_NULL(learning_rate); | |||
| MS_EXCEPTION_IF_NULL(gradient); | |||
| MS_EXCEPTION_IF_NULL(momentum); | |||
| MS_EXCEPTION_IF_NULL(monad_state); | |||
| auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, | |||
| accumulation, learning_rate, gradient, momentum}; | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation, | |||
| learning_rate, gradient, momentum, monad_state}; | |||
| auto replace_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| @@ -25,6 +25,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) | |||
| : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { | |||
| monad_ = std::make_shared<Var>(); | |||
| weight_decay_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<CondVar>(IsScalar); | |||
| variable_ = std::make_shared<Var>(); | |||
| @@ -32,6 +33,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | |||
| learning_rate_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| momentum_ = std::make_shared<Var>(); | |||
| monad_state_ = std::make_shared<Var>(); | |||
| } | |||
| ~ApplyMomentumWeightDecayScaleFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| @@ -40,6 +42,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | |||
| private: | |||
| static bool IsScalar(const BaseRef &n); | |||
| VarPtr monad_; | |||
| VarPtr weight_decay_; | |||
| VarPtr scale_; | |||
| VarPtr variable_; | |||
| @@ -47,6 +50,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | |||
| VarPtr learning_rate_; | |||
| VarPtr gradient_; | |||
| VarPtr momentum_; | |||
| VarPtr monad_state_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -37,11 +37,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> | |||
| for (size_t idx = 0; idx < node_list.size(); ++idx) { | |||
| auto cnode = utils::cast<CNodePtr>(node_list[idx]); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_device_format.push_back(kOpFormat_DEFAULT); | |||
| inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index)); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_device_format.push_back(kOpFormat_DEFAULT); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||
| @@ -57,16 +59,39 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> | |||
| bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) { | |||
| std::vector<AnfNodePtr> cast_32to16_list; | |||
| std::vector<AnfNodePtr> cast_16to32_list; | |||
| AnfNodePtr cast_32to16_load_monad = nullptr; | |||
| AnfNodePtr cast_16to32_load_monad = nullptr; | |||
| constexpr size_t second_input_index = 2; | |||
| for (auto &cast_node : node_list) { | |||
| // currently, we only deal with the construct : [Param->Cast->] to avoid being a cycle. | |||
| if (cast_node != nullptr && cast_node->isa<CNode>() && AnfAlgo::GetCNodeName(cast_node) == "Cast" && | |||
| (AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0))->isa<Parameter>()) { | |||
| auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0); | |||
| auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0); | |||
| if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) { | |||
| cast_32to16_list.push_back(cast_node); | |||
| } else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) { | |||
| cast_16to32_list.push_back(cast_node); | |||
| // { prim::kPrimCast, { prim::kPrimLoad, Parameter, U }} | |||
| if (IsPrimitiveCNode(cast_node, prim::kPrimCast)) { | |||
| auto input0 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0); | |||
| if (input0->isa<Parameter>() || (IsPrimitiveCNode(input0, prim::kPrimLoad) && | |||
| (AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input0), 0))->isa<Parameter>())) { | |||
| auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0); | |||
| auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0); | |||
| if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) { | |||
| cast_32to16_list.push_back(cast_node); | |||
| if (IsPrimitiveCNode(input0, prim::kPrimLoad)) { | |||
| auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index); | |||
| if (cast_32to16_load_monad == nullptr) { | |||
| cast_32to16_load_monad = monad; | |||
| } else if (cast_32to16_load_monad != monad) { | |||
| return false; | |||
| } | |||
| } | |||
| } else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) { | |||
| cast_16to32_list.push_back(cast_node); | |||
| if (IsPrimitiveCNode(input0, prim::kPrimLoad)) { | |||
| auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index); | |||
| if (cast_16to32_load_monad == nullptr) { | |||
| cast_16to32_load_monad = monad; | |||
| } else if (cast_16to32_load_monad != monad) { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -36,11 +36,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> | |||
| for (size_t idx = 0; idx < node_list.size(); ++idx) { | |||
| auto cnode = utils::cast<CNodePtr>(node_list[idx]); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_device_format.push_back(kOpFormat_DEFAULT); | |||
| inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index)); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_device_format.push_back(kOpFormat_DEFAULT); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||
| @@ -53,6 +53,8 @@ std::set<string> kSkipOpNames = { | |||
| std::map<string, uint32_t> kAggregatesOpNames = { | |||
| {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}}; | |||
| constexpr size_t inplace_node_size = 2; | |||
| template <typename T> | |||
| void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node); | |||
| @@ -60,40 +62,103 @@ void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) { | |||
| primitive->AddAttr(key, MakeValue(value)); | |||
| } | |||
| void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) { | |||
| std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_node) { | |||
| if (inplace_node.size() != inplace_node_size) { | |||
| return {0, false}; | |||
| } | |||
| auto first_node = inplace_node[0].node; | |||
| auto second_node = inplace_node[1].node; | |||
| if (AnfAlgo::GetCNodeName(first_node) != kConv2DBackpropInputOpName || | |||
| AnfAlgo::GetCNodeName(second_node) != kConv2DBackpropInputOpName) { | |||
| return {0, false}; | |||
| } | |||
| auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node); | |||
| auto first_node_channel = first_node_prim.get()->GetAttr("out_channel"); | |||
| MS_EXCEPTION_IF_NULL(first_node_channel); | |||
| size_t first_channel = first_node_channel->cast<Int64ImmPtr>()->value(); | |||
| auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node); | |||
| auto second_node_channel = second_node_prim.get()->GetAttr("out_channel"); | |||
| MS_EXCEPTION_IF_NULL(second_node_channel); | |||
| size_t second_channel = second_node_channel->cast<Int64ImmPtr>()->value(); | |||
| size_t cover_index = (first_channel >= second_channel) ? 0 : 1; | |||
| return {cover_index, true}; | |||
| } | |||
| void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) { | |||
| auto build_info = AnfAlgo::GetSelectKernelBuildInfo(src); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, dst.get()); | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(src); | |||
| std::vector<TypeId> types; | |||
| std::vector<std::vector<size_t>> shapes; | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| types.emplace_back(AnfAlgo::GetOutputInferDataType(src, i)); | |||
| shapes.emplace_back(AnfAlgo::GetOutputInferShape(src, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, dst.get()); | |||
| } | |||
| void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, const FuncGraphPtr &graph) { | |||
| if (inplace_node->size() == inplace_node_size) { | |||
| auto first_cnode = (*inplace_node)[0].node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(first_cnode); | |||
| auto first_node_input = first_cnode->input(1); | |||
| auto second_cnode = (*inplace_node)[1].node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(second_cnode); | |||
| auto second_node_input = second_cnode->input(1); | |||
| // if two inplace nodes have same input, will be have loop after insert depend | |||
| // so copy a new input for one of inplace node | |||
| if (first_node_input == second_node_input) { | |||
| auto cnode = first_node_input->cast<CNodePtr>(); | |||
| auto new_input = graph->NewCNode(cnode->inputs()); | |||
| new_input->set_abstract(first_node_input->abstract()); | |||
| CopyKernelInfo(first_node_input, new_input); | |||
| auto new_inplace_node = graph->NewCNode({first_cnode->input(0), new_input, first_cnode->input(2)}); | |||
| new_inplace_node->set_abstract(first_cnode->abstract()); | |||
| CopyKernelInfo(first_cnode, new_inplace_node); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(first_cnode, new_inplace_node); | |||
| (*inplace_node)[0].node = new_inplace_node; | |||
| } | |||
| } | |||
| } | |||
| void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node, | |||
| const FuncGraphPtr &graph) { | |||
| SetPrimAttr(aggregate_node.node, "aggregate", true); | |||
| SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index); | |||
| SetPrimAttr(skip_node, "skip", true); | |||
| static uint32_t group = 0; | |||
| auto [cover_index, order_required] = GetCoverIndex(*inplace_node); | |||
| if (order_required) { | |||
| CheckInplaceNodeInputs(inplace_node, graph); | |||
| } | |||
| for (size_t i = 0; i < inplace_node->size(); i++) { | |||
| auto algo = (i == 0) ? "cover" : "accumulation"; | |||
| SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo); | |||
| SetPrimAttr((*inplace_node)[i].node, "inplace_group", group); | |||
| SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index); | |||
| auto algo = (i == cover_index) ? "cover" : "accumulation"; | |||
| auto node = (*inplace_node)[i].node; | |||
| SetPrimAttr(node, "inplace_algo", algo); | |||
| SetPrimAttr(node, "inplace_group", group); | |||
| SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index); | |||
| // for Conv2DBackpropInputOp, need insert depend node to keep order, set the larger channel to cover | |||
| if (order_required && i != cover_index) { | |||
| auto acc_node = node; | |||
| auto cover_node = (*inplace_node)[cover_index].node; | |||
| auto acc_node_input = acc_node->cast<CNodePtr>()->input(1); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), | |||
| acc_node_input, cover_node}; | |||
| auto depend_node = graph->NewCNode(inputs); | |||
| depend_node->set_abstract(acc_node_input->abstract()); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(acc_node_input, depend_node); | |||
| } | |||
| } | |||
| group++; | |||
| } | |||
| void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes, | |||
| const AnfNodePtr aggregate_node) { | |||
| std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), | |||
| inplace_nodes[0].node, inplace_nodes[1].node}; | |||
| auto control_depend_node = graph->NewCNode(inputs1); | |||
| std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), | |||
| aggregate_node, control_depend_node}; | |||
| auto depend_node = graph->NewCNode(inputs2); | |||
| auto users = GetRealNodeUsedList(graph, aggregate_node); | |||
| if (users->size() == 0) { | |||
| MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString(); | |||
| } | |||
| auto mount_node = users->at(0).first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(mount_node); | |||
| mount_node->set_input(kFirstDataInputIndex, depend_node); | |||
| } | |||
| bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node, | |||
| std::vector<AnfNodeIndex> *inplace) { | |||
| MS_EXCEPTION_IF_NULL(skip_node); | |||
| @@ -117,7 +182,8 @@ bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeInde | |||
| auto cnode = (*skip_node)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i); | |||
| if (!inplace_node->isa<CNode>()) { | |||
| return false; | |||
| @@ -187,9 +253,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) { | |||
| << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString() | |||
| << std::endl; | |||
| // 2. Set Node attr | |||
| SetNodeAttr(aggregate_node, skip_node, &inplace_node); | |||
| // 3. Set dependence for inplace nodes | |||
| InsertControlDependToGraph(graph, inplace_node, aggregate_node.node); | |||
| SetNodeAttr(aggregate_node, skip_node, &inplace_node, graph); | |||
| } | |||
| return true; | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/gpu/kernel_info_setter.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef PostBatchNormAddReluFusion::DefinePattern() const { | |||
| VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); | |||
| VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item}); | |||
| VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); | |||
| return relu; | |||
| } | |||
| const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | |||
| MS_EXCEPTION_IF_NULL(tensor_add); | |||
| auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1); | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm_ex); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format"); | |||
| MS_EXCEPTION_IF_NULL(format_attr); | |||
| auto format = GetValue<std::string>(format_attr); | |||
| if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| return nullptr; | |||
| } | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); | |||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||
| return nullptr; | |||
| } | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4); | |||
| auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| MS_EXCEPTION_IF_NULL(bias); | |||
| MS_EXCEPTION_IF_NULL(mean); | |||
| MS_EXCEPTION_IF_NULL(var); | |||
| MS_EXCEPTION_IF_NULL(z); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z}; | |||
| auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu); | |||
| std::vector<TypeId> outputs_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu); | |||
| device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu); | |||
| return tuple_get_item; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class PostBatchNormAddReluFusion : public PatternProcessPass { | |||
| public: | |||
| explicit PostBatchNormAddReluFusion(bool multigraph = true) | |||
| : PatternProcessPass("post_batch_norm_add_relu_fusion", multigraph) { | |||
| x_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<Var>(); | |||
| bias_ = std::make_shared<Var>(); | |||
| mean_ = std::make_shared<Var>(); | |||
| var_ = std::make_shared<Var>(); | |||
| index_ = std::make_shared<Var>(); | |||
| z_ = std::make_shared<Var>(); | |||
| } | |||
| ~PostBatchNormAddReluFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr x_; | |||
| VarPtr scale_; | |||
| VarPtr bias_; | |||
| VarPtr mean_; | |||
| VarPtr var_; | |||
| VarPtr index_; | |||
| VarPtr z_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_ | |||
| @@ -32,9 +32,7 @@ const size_t kReluV2OutputNum = 2; | |||
| CNodePtr GetRelu(const CNodePtr &relu_grad) { | |||
| MS_EXCEPTION_IF_NULL(relu_grad); | |||
| if (relu_grad->size() != kReluGradInputNum) { | |||
| MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); | |||
| } | |||
| CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum); | |||
| auto relu_anf = relu_grad->input(2); | |||
| MS_EXCEPTION_IF_NULL(relu_anf); | |||
| return relu_anf->cast<CNodePtr>(); | |||
| @@ -47,11 +45,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| std::vector<TypeId> outputs_type; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||
| } | |||
| @@ -65,9 +65,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(relu); | |||
| if (relu->size() != kReluInputNum) { | |||
| MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); | |||
| } | |||
| CheckCNodeInputSize(relu, kReluInputTensorNum); | |||
| auto prim = std::make_shared<Primitive>(kReluV2OpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)}; | |||
| @@ -106,7 +104,8 @@ CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, | |||
| std::vector<TypeId> types; | |||
| std::vector<std::vector<size_t>> shapes; | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i)); | |||
| shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i)); | |||
| } | |||
| @@ -305,52 +305,14 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo | |||
| user_cnode->set_input(index, depend_cnode); | |||
| } | |||
| AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, | |||
| const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) { | |||
| // Create control depend, first input is composite op, second is user | |||
| AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; | |||
| auto control_depend_cnode = main_graph->NewCNode(cd_inputs); | |||
| main_graph->AddNode(control_depend_cnode); | |||
| // Create depend node to hold new control depend node. | |||
| AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode}; | |||
| auto depend_cnode = main_graph->NewCNode(d_inputs); | |||
| depend_cnode->set_abstract(patron_node->abstract()); | |||
| main_graph->AddNode(depend_cnode); | |||
| return depend_cnode; | |||
| } | |||
| std::tuple<AnfNodePtr, AnfNodePtr, int> AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) { | |||
| auto mng = main_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(main_graph, true); | |||
| main_graph->set_manager(mng); | |||
| } | |||
| AnfNodePtr patron_node; | |||
| auto return_cnode = main_graph->get_return()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(return_cnode); | |||
| auto output_node = return_cnode->input(kFirstDataInputIndex); | |||
| if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) { | |||
| auto output_cnode = output_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| patron_node = output_cnode->input(kFirstDataInputIndex); | |||
| } else { | |||
| patron_node = output_node; | |||
| } | |||
| auto &user_nodes = mng->node_users()[patron_node]; | |||
| auto user = user_nodes.begin(); | |||
| return std::make_tuple(patron_node, user->first, user->second); | |||
| } | |||
| void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, | |||
| int index) { | |||
| auto patron_user_cnode = patron_user->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(patron_user_cnode); | |||
| patron_user_cnode->set_input(index, patron_node); | |||
| CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) { | |||
| // Insert update_state_node, need mount a monad node. | |||
| auto u = NewValueNode(kUMonad); | |||
| u->set_abstract(kUMonad->ToAbstract()); | |||
| AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, composite_node}; | |||
| auto update_state_cnode = main_graph->NewCNode(update_state_inputs); | |||
| main_graph->AddNode(update_state_cnode); | |||
| return update_state_cnode; | |||
| } | |||
| CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) { | |||
| @@ -474,24 +436,21 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs | |||
| } | |||
| void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) { | |||
| const AnfNodePtr &broadcast_to_node, | |||
| const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) { | |||
| // 1. find users, change getitem index if needed. | |||
| std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes = | |||
| FindOriginCNodeUsers(main_graph, composite_node, mng, true); | |||
| for (const auto &[user_node, index] : reduce_user_nodes) { | |||
| // 2. set ac output as user's input. | |||
| // 3. Make sure modified composite node running first. | |||
| // * To not change the origin node's dependency relation, add ControlDepend and Depend node. | |||
| // * For Return node and output node, ControlDepend node will change the order of these two node, which will may | |||
| // main graph running failed. So only add Depend node to meet the need of execute order. | |||
| if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) { | |||
| AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index); | |||
| } else { | |||
| auto user_cnode = user_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||
| user_cnode->set_input(index, broadcast_to_node); | |||
| to_process_order_.emplace_back(composite_node, user_node); | |||
| } | |||
| // 2. Make sure modified composite node running first, So firstly, create load_node, then add edge to connect | |||
| // update_state_node, broadcat_node and load_node to keep order. | |||
| AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node}; | |||
| auto load_node = main_graph->NewCNode(load_inputs); | |||
| main_graph->AddNode(load_node); | |||
| auto user_cnode = user_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||
| user_cnode->set_input(index, load_node); | |||
| to_process_order_.emplace_back(composite_node, user_node); | |||
| } | |||
| } | |||
| @@ -509,8 +468,11 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c | |||
| // Note: if it's single output, this will increase total memory because of a fake out. | |||
| ProcessOriginCNode(origin_composite_node, broadcast_to_node, mng); | |||
| // Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user. | |||
| ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng); | |||
| // Insert update_state_node to keep execution order. | |||
| auto update_state_node = InsertUpdateState(main_graph, origin_composite_node); | |||
| // Replace origin ReduceSum's user with atomic clean output | |||
| ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng); | |||
| MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope() | |||
| << ", clean node: " << broadcast_to_node->fullname_with_scope(); | |||
| } | |||
| @@ -554,14 +516,6 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| if (changed) { | |||
| if (!to_process_order_.empty()) { | |||
| auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph); | |||
| for (const auto &[prior, behind] : to_process_order_) { | |||
| patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node); | |||
| } | |||
| PostprocessForLastPatron(patron_node, patron_user, user_index); | |||
| } | |||
| mng->RemoveRoots(); | |||
| mng->KeepRoots({func_graph}); | |||
| } | |||
| @@ -37,9 +37,10 @@ class AtomicCleanInsertter : public Pass { | |||
| virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); | |||
| virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, | |||
| const FuncGraphManagerPtr &mng); | |||
| void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); | |||
| void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &user_node, int index); | |||
| void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); | |||
| CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node); | |||
| CNodePtr atomic_add_node_{nullptr}; | |||
| private: | |||
| @@ -48,11 +49,8 @@ class AtomicCleanInsertter : public Pass { | |||
| CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); | |||
| void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); | |||
| void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, | |||
| const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng); | |||
| std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph); | |||
| AnfNodePtr AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, | |||
| const AnfNodePtr &behind_node, const AnfNodePtr &patron_node); | |||
| void PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, int index); | |||
| const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node, | |||
| const FuncGraphManagerPtr &mng); | |||
| std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph, | |||
| const AnfNodePtr &composite_node, | |||
| const FuncGraphManagerPtr &mng, bool correct_index); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -149,9 +149,18 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||
| } | |||
| auto fuse_nodes = FindFuseCNodes(node, depend_prior); | |||
| if (fuse_nodes.empty() || (fuse_nodes.size() == 1 && AnfAlgo::IsGraphKernel(fuse_nodes[0]))) { | |||
| if (fuse_nodes.empty()) { | |||
| continue; | |||
| } | |||
| if (fuse_nodes.size() == 1) { | |||
| // Do not fuse a single GraphKernel again. | |||
| // Do not fuse a single Assign. | |||
| if (AnfAlgo::IsGraphKernel(fuse_nodes[0]) || IsPrimitiveCNode(fuse_nodes[0], prim::kPrimAssign)) { | |||
| continue; | |||
| } | |||
| } | |||
| changed = true; | |||
| fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); | |||
| AnfNodePtr fused_new_node; | |||
| @@ -109,7 +109,10 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) { | |||
| // 1. Try to remove redundant depend. | |||
| bool changed = false; | |||
| auto nodes = TopoSort(func_graph->get_return()); | |||
| std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) { | |||
| std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void { | |||
| if (HasAbstractMonad(node)) { | |||
| return; | |||
| } | |||
| if (RemoveRedundantDepend(node, mng)) { | |||
| changed = true; | |||
| } | |||
| @@ -126,7 +129,8 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) { | |||
| // Find depend and its free nodes. | |||
| for (const auto &node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimDepend) || | |||
| HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) { | |||
| continue; | |||
| } | |||
| @@ -177,6 +177,7 @@ bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | |||
| std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>(); | |||
| pass->Run(func_graph); | |||
| } | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| @@ -494,8 +494,8 @@ std::vector<PrimitivePtr> GetFusibleOpList() { | |||
| prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | |||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, | |||
| prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | |||
| prim::kPrimCast, prim::kPrimExpandDims}; | |||
| prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | |||
| prim::kPrimAssign, prim::kPrimExpandDims}; | |||
| #else | |||
| std::vector<PrimitivePtr> fusible_basic_ops; | |||
| #endif | |||
| @@ -629,7 +629,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||
| } | |||
| GetValidKernelNodes(); | |||
| // call CostModel to get a split plan. | |||
| if (!SplitByCostModel() || split_plan_.size() != need_inline_.size()) { | |||
| if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) { | |||
| split_plan_.clear(); | |||
| need_inline_.clear(); | |||
| return; | |||
| @@ -103,28 +103,23 @@ bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr ¶m_user) | |||
| return result; | |||
| } | |||
| AnfNodePtr AddControlDepend(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr ¶m_user) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), getitem, param_user}; | |||
| auto cd_node = func_graph->NewCNode(cd_inputs); | |||
| func_graph->AddNode(cd_node); | |||
| return cd_node; | |||
| } | |||
| void LinkControlDepends(const FuncGraphPtr &func_graph, const AnfNodePtrList &cd_nodes) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto output_tuple = func_graph->output()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_tuple); | |||
| auto cur_node = output_tuple->input(1); | |||
| for (const auto &cd : cd_nodes) { | |||
| AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), cur_node, cd}; | |||
| auto depend_node = func_graph->NewCNode(depend_inputs); | |||
| depend_node->set_abstract(depend_inputs[1]->abstract()); | |||
| cur_node = depend_node; | |||
| } | |||
| mng->Replace(output_tuple->input(1), cur_node); | |||
| void KeepExecOrder(const FuncGraphPtr &func_graph, const AnfNodePtr &gk_node, const AnfNodePtr &par_user_node, | |||
| const FuncGraphManagerPtr &mng) { | |||
| // Insert update_state_node, need mount a monad node. | |||
| auto u = NewValueNode(kUMonad); | |||
| u->set_abstract(kUMonad->ToAbstract()); | |||
| AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, gk_node}; | |||
| auto update_state_node = func_graph->NewCNode(update_state_inputs); | |||
| update_state_node->set_abstract(gk_node->abstract()); | |||
| func_graph->AddNode(update_state_node); | |||
| // Insert load_node | |||
| AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), par_user_node, update_state_node}; | |||
| auto load_node = func_graph->NewCNode(load_inputs); | |||
| load_node->set_abstract(par_user_node->abstract()); | |||
| func_graph->AddNode(load_node); | |||
| mng->Replace(gk_node, par_user_node); | |||
| } | |||
| int64_t GetitemIndex(const AnfNodePtr &getitem) { | |||
| @@ -133,11 +128,10 @@ int64_t GetitemIndex(const AnfNodePtr &getitem) { | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, | |||
| const AnfNodePtr &assign_to, int64_t removed_index) { | |||
| void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, const AnfNodePtr &assign_to, | |||
| int64_t removed_index) { | |||
| auto mng = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| AnfNodePtrList cd_nodes; | |||
| for (const auto &getitem_iter : mng->node_users()[cnode]) { | |||
| auto getitem = getitem_iter.first; | |||
| if (GetitemIndex(getitem) != removed_index) continue; | |||
| @@ -152,13 +146,10 @@ AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const An | |||
| if (!AnfAlgo::IsRealKernel(getitem_user) || HasPathToParamUser(cnode, getitem_user)) { | |||
| continue; | |||
| } | |||
| // keep execution order: cnode -> getitem_user | |||
| auto cd_node = AddControlDepend(func_graph, getitem, getitem_user); | |||
| cd_nodes.push_back(cd_node); | |||
| KeepExecOrder(func_graph, cnode, getitem_user, mng); | |||
| } | |||
| break; | |||
| } | |||
| return cd_nodes; | |||
| } | |||
| bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) { | |||
| @@ -166,7 +157,6 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| bool changed = false; | |||
| AnfNodePtrList control_depend_nodes; | |||
| for (const auto &n : todos) { | |||
| if (!AnfAlgo::IsGraphKernel(n)) continue; | |||
| auto cnode = n->cast<CNodePtr>(); | |||
| @@ -174,11 +164,9 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) { | |||
| if (replaceable_nodes.empty()) continue; | |||
| changed = true; | |||
| for (const auto &iter : replaceable_nodes) { | |||
| auto cd_nodes = UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first); | |||
| control_depend_nodes.insert(control_depend_nodes.end(), cd_nodes.begin(), cd_nodes.end()); | |||
| UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first); | |||
| } | |||
| } | |||
| LinkControlDepends(func_graph, control_depend_nodes); | |||
| return changed; | |||
| } | |||
| @@ -97,7 +97,8 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn, | |||
| void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | |||
| for (auto &[node, node_rel] : (*node_rels)) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimDepend) || | |||
| HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) { | |||
| continue; | |||
| } | |||
| @@ -118,96 +119,6 @@ void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | |||
| ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels); | |||
| } | |||
| std::tuple<std::pair<AnfNodePtr, AnfNodePtr>, std::pair<AnfNodePtrList, AnfNodePtrList>> FindRelationOfControlDepend( | |||
| const AnfNodePtr &node, OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto prior_node = cnode->input(kControlDependPriorIndex); | |||
| auto behind_node = cnode->input(kControlDependBehindIndex); | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(behind_node); | |||
| OrderedSet<AnfNodePtr> prior_nodes; | |||
| prior_nodes.insert(prior_node); | |||
| OrderedSet<AnfNodePtr> behind_nodes; | |||
| behind_nodes.insert(behind_node); | |||
| int64_t depend_mode = 0; | |||
| if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { | |||
| depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode); | |||
| } | |||
| if (prior_node->isa<Parameter>() && depend_mode == 1) { | |||
| prior_nodes = (*node_rels)[prior_node].nexts; | |||
| } | |||
| if (behind_node->isa<Parameter>()) { | |||
| behind_nodes = depend_mode == 1 ? (*node_rels)[behind_node].nexts : OrderedSet<AnfNodePtr>(); | |||
| } | |||
| // Get real nodes. | |||
| AnfNodePtrList real_prior_nodes; | |||
| std::set<AnfNodePtr> prior_visited; | |||
| for (const auto &tmp : prior_nodes) { | |||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | |||
| } | |||
| AnfNodePtrList real_behind_nodes; | |||
| std::set<AnfNodePtr> behind_visited; | |||
| for (const auto &tmp : behind_nodes) { | |||
| AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited); | |||
| } | |||
| return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes)); | |||
| } | |||
| void ReLinkNodesOfControlDependByRelation(const std::unordered_map<AnfNodePtr, AnfNodePtrList> &control_depend_info, | |||
| OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | |||
| // Relink and its log. | |||
| for (const auto &m : control_depend_info) { | |||
| const auto &prior = m.second[0]; | |||
| const auto &behind = m.second[1]; | |||
| (*node_rels)[prior].nexts.insert(behind); | |||
| (*node_rels)[behind].pres.insert(prior); | |||
| MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope() | |||
| << " -> " << behind->fullname_with_scope(); | |||
| } | |||
| } | |||
| void ProcessControlDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | |||
| std::unordered_map<AnfNodePtr, AnfNodePtrList> control_depend_info; | |||
| AnfNodePtrList latter_to_be_erased; | |||
| // Collect ControlDepend node and its input and output nodes. | |||
| for (auto &[node, node_rel] : (*node_rels)) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels); | |||
| auto &[prior_node, behind_node] = direct_relation; | |||
| auto &[real_prior_nodes, real_behind_nodes] = real_relations; | |||
| (*node_rels)[prior_node].nexts.erase(node); | |||
| (*node_rels)[behind_node].nexts.erase(node); | |||
| node_rel.pres.erase(prior_node); | |||
| node_rel.pres.erase(behind_node); | |||
| for (auto &first_node : real_prior_nodes) { | |||
| for (auto &second_node : real_behind_nodes) { | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| control_depend_info.insert({node, {first_node, second_node}}); | |||
| } | |||
| } | |||
| latter_to_be_erased.push_back(node); | |||
| } | |||
| // Delete ControlDepend node before relink its relation. | |||
| for (const auto &node : latter_to_be_erased) { | |||
| node_rels->erase(node); | |||
| } | |||
| // Rebuild relation between prior and behind node. | |||
| ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels); | |||
| } | |||
| void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { | |||
| AnfNodePtrList latter_to_be_erased; | |||
| for (auto &[node, node_rel] : (*node_rels)) { | |||
| @@ -538,7 +449,6 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An | |||
| } | |||
| ProcessDependCNode(&node_rels); | |||
| ProcessControlDependCNode(&node_rels); | |||
| ProcessThroughPassCNode( | |||
| [](const AnfNodePtr &node) { | |||
| return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem}); | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/split_assign.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "base/core_ops.h" | |||
| #include "utils/utils.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef SplitAssign::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<Var>(); | |||
| VarPtr Us = std::make_shared<Var>(); | |||
| VarPtr UMonad = std::make_shared<Var>(); | |||
| return VectorRef({prim::kPrimAssign, Xs, Us, UMonad}); | |||
| } | |||
| const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| CheckCNodeInputSize(cnode, kAssignInputTensorNum); | |||
| // Get original assign op's abstract and inputs | |||
| AbstractBasePtr original_abstract = cnode->abstract()->Clone(); | |||
| auto original_inputs = cnode->inputs(); | |||
| // Create depend node | |||
| AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]}; | |||
| auto depend_cnode = func_graph->NewCNode(depend_inputs); | |||
| depend_cnode->set_abstract(original_inputs[1]->abstract()); | |||
| depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| // Create new assign node, delete U from inputs. | |||
| AnfNodePtrList new_assign_inputs = {NewValueNode(prim::kPrimAssign), depend_cnode, original_inputs[2]}; | |||
| auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs); | |||
| new_assign_cnode->set_abstract(original_abstract); | |||
| new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr()); | |||
| return new_assign_cnode; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class SplitAssign : public PatternProcessPass { | |||
| public: | |||
| explicit SplitAssign(bool multigraph = true) : PatternProcessPass("split_assign", multigraph) {} | |||
| ~SplitAssign() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ | |||
| @@ -41,13 +41,15 @@ const BaseRef SubstituteDropout::DefinePattern() const { | |||
| void SetNewKernelInfo(const CNodePtr &kernel_node) { | |||
| std::vector<std::string> inputs_format; | |||
| std::vector<TypeId> inputs_type; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); | |||
| } | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||
| } | |||
| @@ -69,15 +71,13 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() < kDropoutInputNum) { | |||
| MS_LOG(EXCEPTION) << "Dropout's input num is wrong"; | |||
| } | |||
| CheckCNodeInputSize(cnode, kDropoutInputTensorNum); | |||
| AbstractBasePtr old_abstract = cnode->abstract()->Clone(); | |||
| auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0); | |||
| ShapeVector shape_i64; | |||
| std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); }); | |||
| // The primitive should use a clone, otherwise the attr seed will be overrode. | |||
| // The primitive should use a clone, otherwise the attr seed will be overridden. | |||
| AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())}; | |||
| auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())), | |||
| static_cast<void *>(&shape[0]), kNumberTypeInt64); | |||
| @@ -249,7 +249,8 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { | |||
| if (node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "The node pointer is a nullptr."; | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| // Get ref count for cnode, except monad cnode. | |||
| if (node->isa<CNode>() && !HasAbstractMonad(node)) { | |||
| auto ak_node = node->cast<CNodePtr>(); | |||
| auto key = ak_node.get(); | |||
| MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx)); | |||
| @@ -314,7 +315,8 @@ void MemReuseUtil::SetKernelDefInputs() { | |||
| MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init."; | |||
| } | |||
| auto kernel_def = iter->second; | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto ref_ptr = GetKernelInputRef(kernel, i); | |||
| if (ref_ptr != nullptr) { | |||
| // set the inputs of this kernel_def | |||
| @@ -214,7 +214,8 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph | |||
| // set real graph output node to be special who's refcount equal kMaxRefCount | |||
| for (const auto &output : graph->outputs()) { | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(output); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (output->isa<CNode>()) { | |||
| auto cnode = output->cast<CNodePtr>(); | |||
| auto input_node = cnode->input(i + 1); | |||
| @@ -364,7 +365,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { | |||
| const auto &cnodes = graph->execution_order(); | |||
| for (const auto &node : cnodes) { | |||
| std::vector<const void *> curr_ous; | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| auto it = AnfAlgo::GetOutputAddr(node, i); | |||
| MS_EXCEPTION_IF_NULL(it); | |||
| auto ptr = it->GetPtr(); | |||
| @@ -374,7 +376,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { | |||
| } | |||
| (void)node_ous_.insert(std::make_pair(node.get(), curr_ous)); | |||
| std::vector<const void *> curr_ins; | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| if (i + 1 >= node->inputs().size()) { | |||
| MS_LOG(EXCEPTION) << "Input index: " << i | |||
| << " is larger than input number: " << AnfAlgo::GetInputTensorNum(node); | |||
| @@ -37,7 +37,8 @@ bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, s | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | |||
| for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); | |||
| for (size_t output_idx = 0; output_idx < output_num; ++output_idx) { | |||
| TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx}; | |||
| ordered_tensors_.push_back(tensor_info); | |||
| } | |||
| @@ -51,12 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co | |||
| rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); | |||
| inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); | |||
| } | |||
| for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) { | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); | |||
| std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index); | |||
| @@ -170,6 +172,117 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic | |||
| return CheckSegments(segments, communication_op_node_size, segment_index); | |||
| } | |||
| // Hard coded Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent | |||
| // cycle after AllReduce fused. It's a workaround. | |||
| // case 1: | |||
| // cnode_load = Load(%para2, cnode_u) | |||
| // %100 = UpdateState(cnode_u, cnode_load) | |||
| // ... | |||
| // %109 = AssignAdd(%para485, Tensor(34), %100) | |||
| // %110 = UpdateState(%100, xxx) | |||
| // will convert to: | |||
| // cnode_load = Load(%para2, U) | |||
| // ... | |||
| // %109 = AssignAdd(%para485, Tensor(34), cnode_u) | |||
| // %110 = UpdateState(cnode_u, xxx) | |||
| // | |||
| // case 2: | |||
| // cnode_load = Load(%para2, cnode_u) | |||
| // %99 = make_tuple(yyy, ..., cnode_load, ...) | |||
| // %100 = UpdateState(cnode_u, %99) | |||
| // ... | |||
| // %109 = AssignAdd(%para485, Tensor(34), %100) | |||
| // %110 = UpdateState(%100, xxx) | |||
| // will convert to: | |||
| // cnode_load = Load(%para2, U) | |||
| // %99 = make_tuple(yyy, ...) | |||
| // %100 = UpdateState(cnode_u, %99) | |||
| // ... | |||
| // %109 = AssignAdd(%para485, Tensor(34), %100) | |||
| // %110 = UpdateState(%100, xxx) | |||
| // | |||
| // case 3: | |||
| // cnode_load = Load(%para2, cnode_u) | |||
| // %99 = make_tuple(cnode_load) | |||
| // %100 = UpdateState(cnode_u, %99) | |||
| // ... | |||
| // %109 = AssignAdd(%para485, Tensor(34), %100) | |||
| // %110 = UpdateState(%100, xxx) | |||
| // will convert to: | |||
| // cnode_load = Load(%para2, U) | |||
| // ... | |||
| // %109 = AssignAdd(%para485, Tensor(34), cnode_u) | |||
| // %110 = UpdateState(cnode_u, xxx) | |||
| static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) { | |||
| auto cnode_load = BroadFirstSearchFirstOf({cnode}, [](const CNodePtr &search_cnode) { | |||
| if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) { | |||
| return false; | |||
| } | |||
| if (search_cnode->inputs().size() != 3) { | |||
| MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString(); | |||
| } | |||
| return search_cnode->input(2)->isa<CNode>(); | |||
| }); | |||
| if (cnode_load != nullptr) { | |||
| const auto &const_u_monad = NewValueNode(kUMonad); | |||
| const auto &cnode_u = cnode_load->input(2); | |||
| MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString(); | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()); | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); | |||
| auto manager = cnode->func_graph()->manager(); | |||
| manager->SetEdge(cnode_load, 2, const_u_monad); | |||
| // Update the u_monad input of UpdateState from CNode U same as Load to constant U. | |||
| CNodePtr cnode_update_state = nullptr; | |||
| CNodePtr cnode_make_tuple = nullptr; | |||
| const auto &cnode_load_users = manager->node_users()[cnode_load]; | |||
| for (auto &load_user : cnode_load_users) { | |||
| if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { | |||
| const auto &cnode_make_tuple_users = manager->node_users()[load_user.first]; | |||
| for (auto &make_tuple_user : cnode_make_tuple_users) { | |||
| if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) { | |||
| const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>(); | |||
| if (cnode_user->input(1) == cnode_u) { | |||
| cnode_update_state = cnode_user; | |||
| cnode_make_tuple = load_user.first->cast<CNodePtr>(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (cnode_update_state != nullptr) { | |||
| break; | |||
| } | |||
| } | |||
| if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { | |||
| const auto &cnode_user = load_user.first->cast<CNodePtr>(); | |||
| if (cnode_user->input(1) == cnode_u) { | |||
| cnode_update_state = cnode_user; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (cnode_update_state != nullptr) { | |||
| if (cnode_make_tuple == nullptr || cnode_make_tuple->inputs().size() == 2) { | |||
| // case 1 and case 3: Replace cnode_update_state to cnode_u; | |||
| MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString() | |||
| << " ::TO:: " << cnode_u->DebugString(); | |||
| manager->Replace(cnode_update_state, cnode_u); | |||
| } else if (cnode_make_tuple->inputs().size() > 2) { | |||
| // case 2: remove cnode_load from cnode_make_tuple; | |||
| MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString(); | |||
| const auto &make_tuple_inputs = cnode_make_tuple->inputs(); | |||
| AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1); | |||
| std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(), | |||
| [cnode_load](const auto &inp) { return inp != cnode_load; }); | |||
| auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs); | |||
| manager->Replace(cnode_make_tuple, new_cnode_make_tuple); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString() | |||
| << " as make_tuple CNode cannot match " << cnode_make_tuple->DebugString(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, | |||
| const CommunicationOpInfo &communication_op_info, | |||
| size_t start_index, size_t end_index) const { | |||
| @@ -184,6 +297,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr | |||
| for (size_t idx = start_index; idx <= end_index; ++idx) { | |||
| auto cnode = communication_op_info.communication_op_nodes[idx]; | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (idx != start_index) { | |||
| AdjustAllReduceInputWithLoad(cnode); | |||
| } | |||
| fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| } | |||
| CheckInputs(fusion_inputs); | |||
| @@ -107,9 +107,7 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { | |||
| auto mng = sub_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| std::vector<AnfNodePtr> todo; | |||
| std::vector<std::pair<AnfNodePtr, size_t>> graph_rets; | |||
| kernel::GetValidKernelNodes(sub_graph, &todo); | |||
| kernel::GetGraphRealOutput(sub_graph, &graph_rets); | |||
| for (auto &t : todo) { | |||
| auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>()); | |||
| @@ -37,7 +37,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt | |||
| std::vector<AnfNodePtr> plant_inputs; | |||
| std::vector<int64_t> dyn_input_sizes; | |||
| plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode_ptr); | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) { | |||
| @@ -45,7 +46,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt | |||
| dyn_input_sizes.push_back(input_size); | |||
| auto make_tuple = input_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) { | |||
| size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple); | |||
| for (size_t j = 0; j < tuple_input_num; ++j) { | |||
| auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j); | |||
| MS_EXCEPTION_IF_NULL(dyn_input_node); | |||
| if (IsValueNode<tensor::Tensor>(dyn_input_node)) { | |||
| @@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||
| return nullptr; | |||
| } | |||
| } | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) { | |||
| return nullptr; | |||
| } | |||
| bool cnode_input_changed = false; | |||