Merge pull request !3129 from hewei/decouple_ir_frontendtags/v0.6.0-beta
| @@ -27,6 +27,7 @@ | |||||
| #include "runtime/device/kernel_info.h" | #include "runtime/device/kernel_info.h" | ||||
| #include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "frontend/parallel/ops_info/operator_info.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| const std::string ToShortString(const TypeId &typeId) { | const std::string ToShortString(const TypeId &typeId) { | ||||
| @@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo | |||||
| return; | return; | ||||
| } | } | ||||
| auto operator_info = node->operator_info(); | |||||
| auto operator_info = node->GetUserData<parallel::OperatorInfo>(); | |||||
| if (operator_info == nullptr) { | if (operator_info == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||||
| if (graph_obj == nullptr || node == nullptr) { | if (graph_obj == nullptr || node == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto distributed_operation_info = node->operator_info(); | |||||
| auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>(); | |||||
| if (distributed_operation_info != nullptr) { | if (distributed_operation_info != nullptr) { | ||||
| auto strategyPtr = distributed_operation_info->strategy(); | auto strategyPtr = distributed_operation_info->strategy(); | ||||
| if (strategyPtr != nullptr) { | if (strategyPtr != nullptr) { | ||||
| @@ -1,293 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "frontend/operator/ops.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| namespace mindspore { | |||||
| // namespace to support primitive operators | |||||
| namespace prim { | |||||
| // Arithmetic | |||||
| const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add"); | |||||
| const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub"); | |||||
| const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul"); | |||||
| const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div"); | |||||
| const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv"); | |||||
| const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod"); | |||||
| const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow"); | |||||
| const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc"); | |||||
| const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor"); | |||||
| const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd"); | |||||
| const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub"); | |||||
| const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp"); | |||||
| const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log"); | |||||
| const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin"); | |||||
| const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos"); | |||||
| const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan"); | |||||
| // Comparisons | |||||
| const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq"); | |||||
| const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt"); | |||||
| const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt"); | |||||
| const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne"); | |||||
| const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le"); | |||||
| const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge"); | |||||
| const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); | |||||
| const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | |||||
| const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | |||||
| const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | |||||
| const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater"); | |||||
| const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); | |||||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||||
| const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); | |||||
| // Type introspection | |||||
| const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||||
| const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||||
| // Statements | |||||
| const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch"); | |||||
| const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer"); | |||||
| const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||||
| const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | |||||
| const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | |||||
| const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub"); | |||||
| const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select"); | |||||
| const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call"); | |||||
| const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute"); | |||||
| const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot"); | |||||
| const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col"); | |||||
| const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im"); | |||||
| const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1"); | |||||
| const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1"); | |||||
| const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||||
| const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||||
| const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||||
| const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||||
| const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto"); | |||||
| const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch"); | |||||
| const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | |||||
| // Structure | |||||
| const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||||
| const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||||
| const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple"); | |||||
| const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||||
| const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict"); | |||||
| const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg"); | |||||
| const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg"); | |||||
| const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice"); | |||||
| const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||||
| const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem"); | |||||
| const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem"); | |||||
| const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem"); | |||||
| const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem"); | |||||
| const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||||
| const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem"); | |||||
| const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||||
| const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||||
| const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||||
| const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr"); | |||||
| const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len"); | |||||
| const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||||
| const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||||
| const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len"); | |||||
| const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||||
| const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||||
| const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||||
| const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape"); | |||||
| const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||||
| const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||||
| const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||||
| const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||||
| const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index"); | |||||
| const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | |||||
| const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||||
| const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||||
| const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||||
| const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||||
| // Arrays | |||||
| const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | |||||
| const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | |||||
| const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | |||||
| const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||||
| const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | |||||
| const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | |||||
| const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||||
| const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||||
| const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||||
| const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||||
| const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | |||||
| const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | |||||
| const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | |||||
| const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||||
| const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | |||||
| const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | |||||
| const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | |||||
| const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||||
| const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||||
| const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||||
| const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||||
| const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | |||||
| const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); | |||||
| const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); | |||||
| const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); | |||||
| const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue"); | |||||
| // Maths | |||||
| const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | |||||
| const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | |||||
| const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | |||||
| const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | |||||
| const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | |||||
| const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | |||||
| const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | |||||
| const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | |||||
| const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | |||||
| const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | |||||
| const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | |||||
| const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | |||||
| const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||||
| const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | |||||
| const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | |||||
| const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | |||||
| const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | |||||
| const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | |||||
| const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | |||||
| const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | |||||
| const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | |||||
| const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||||
| const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||||
| const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||||
| const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||||
| const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||||
| // NN | |||||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||||
| const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||||
| const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||||
| const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||||
| const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||||
| const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | |||||
| const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | |||||
| const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | |||||
| const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||||
| const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||||
| const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||||
| const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||||
| const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||||
| const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||||
| const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||||
| const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||||
| const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||||
| const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | |||||
| const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | |||||
| const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | |||||
| const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | |||||
| const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | |||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | |||||
| const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | |||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | |||||
| const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | |||||
| const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | |||||
| const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | |||||
| std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | |||||
| const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | |||||
| const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | |||||
| const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | |||||
| const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | |||||
| const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | |||||
| const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | |||||
| const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | |||||
| const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | |||||
| const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | |||||
| const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | |||||
| const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | |||||
| const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | |||||
| const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | |||||
| const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | |||||
| const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||||
| const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | |||||
| const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | |||||
| const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | |||||
| const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | |||||
| // Other miscellaneous | |||||
| const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity"); | |||||
| const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial"); | |||||
| const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | |||||
| const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem"); | |||||
| const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem"); | |||||
| const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||||
| const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||||
| const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||||
| const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||||
| const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||||
| const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||||
| const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | |||||
| const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||||
| const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | |||||
| const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||||
| const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | |||||
| const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||||
| const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend"); | |||||
| const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||||
| const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||||
| const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); | |||||
| const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | |||||
| const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | |||||
| const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||||
| const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||||
| const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||||
| const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||||
| const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||||
| // Comm ops | |||||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||||
| const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||||
| const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||||
| const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||||
| // Debug ops | |||||
| const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||||
| const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | |||||
| const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||||
| const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | |||||
| const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug"); | |||||
| // IndexedSlices | |||||
| const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices"); | |||||
| const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); | |||||
| const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); | |||||
| const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); | |||||
| // SparseTensor | |||||
| const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor"); | |||||
| const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues"); | |||||
| const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | |||||
| const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | |||||
| } // namespace prim | |||||
| } // namespace mindspore | |||||
| @@ -22,6 +22,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "base/core_ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support primitive operators | // namespace to support primitive operators | ||||
| @@ -31,273 +32,158 @@ ValuePtr GetPythonOps(const std::string &op_name, | |||||
| bool use_signature = false); | bool use_signature = false); | ||||
| // Arithmetic | // Arithmetic | ||||
| extern const PrimitivePtr kPrimScalarAdd; | |||||
| extern const PrimitivePtr kPrimScalarSub; | |||||
| extern const PrimitivePtr kPrimScalarMul; | |||||
| extern const PrimitivePtr kPrimScalarDiv; | |||||
| extern const PrimitivePtr kPrimScalarFloordiv; | |||||
| extern const PrimitivePtr kPrimScalarMod; | |||||
| extern const PrimitivePtr kPrimScalarPow; | |||||
| extern const PrimitivePtr kPrimScalarTrunc; | |||||
| extern const PrimitivePtr kPrimScalarFloor; | |||||
| extern const PrimitivePtr kPrimScalarUadd; | |||||
| extern const PrimitivePtr kPrimScalarUsub; | |||||
| extern const PrimitivePtr kPrimScalarExp; | |||||
| extern const PrimitivePtr kPrimScalarLog; | |||||
| extern const PrimitivePtr kPrimScalarSin; | |||||
| extern const PrimitivePtr kPrimScalarCos; | |||||
| extern const PrimitivePtr kPrimScalarTan; | |||||
| inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add"); | |||||
| inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub"); | |||||
| inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul"); | |||||
| inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div"); | |||||
| inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv"); | |||||
| inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod"); | |||||
| inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow"); | |||||
| inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc"); | |||||
| inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor"); | |||||
| inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd"); | |||||
| inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub"); | |||||
| inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp"); | |||||
| inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log"); | |||||
| inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin"); | |||||
| inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos"); | |||||
| inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan"); | |||||
| // Comparisons | // Comparisons | ||||
| extern const PrimitivePtr kPrimScalarEq; | |||||
| extern const PrimitivePtr kPrimScalarLt; | |||||
| extern const PrimitivePtr kPrimScalarGt; | |||||
| extern const PrimitivePtr kPrimScalarNe; | |||||
| extern const PrimitivePtr kPrimScalarLe; | |||||
| extern const PrimitivePtr kPrimScalarGe; | |||||
| extern const PrimitivePtr kPrimBoolNot; | |||||
| extern const PrimitivePtr kPrimBoolAnd; | |||||
| extern const PrimitivePtr kPrimBoolOr; | |||||
| extern const PrimitivePtr kPrimBoolEq; | |||||
| extern const PrimitivePtr kPrimGreater; | |||||
| extern const PrimitivePtr kPrimGreaterEqual; | |||||
| extern const PrimitivePtr kPrimLess; | |||||
| extern const PrimitivePtr kPrimLessEqual; | |||||
| extern const PrimitivePtr kPrimEqual; | |||||
| extern const PrimitivePtr kPrimNotEqual; | |||||
| inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq"); | |||||
| inline const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt"); | |||||
| inline const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt"); | |||||
| inline const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne"); | |||||
| inline const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le"); | |||||
| inline const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge"); | |||||
| inline const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); | |||||
| inline const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | |||||
| inline const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | |||||
| inline const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | |||||
| inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater"); | |||||
| inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); | |||||
| inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||||
| inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||||
| inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||||
| inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); | |||||
| // Type introspection | // Type introspection | ||||
| extern const PrimitivePtr kPrimTypeOf; | |||||
| extern const PrimitivePtr kPrimHasType; | |||||
| inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||||
| inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||||
| // Statements | |||||
| extern const PrimitivePtr kPrimSwitch; | |||||
| extern const PrimitivePtr kPrimSwitchLayer; | |||||
| extern const PrimitivePtr kPrimReturn; | |||||
| extern const PrimitivePtr kPrimAssign; | |||||
| extern const PrimitivePtr kPrimAssignAdd; | |||||
| extern const PrimitivePtr kPrimAssignSub; | |||||
| extern const PrimitivePtr kPrimSelect; | |||||
| extern const PrimitivePtr kPrimCall; | |||||
| inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute"); | |||||
| inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot"); | |||||
| inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col"); | |||||
| inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im"); | |||||
| inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1"); | |||||
| inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1"); | |||||
| extern const PrimitivePtr kPrimDistribute; | |||||
| extern const PrimitivePtr kPrimDot; | |||||
| extern const PrimitivePtr kPrimIm2Col; | |||||
| extern const PrimitivePtr kPrimCol2Im; | |||||
| extern const PrimitivePtr kPrimIm2ColV1; | |||||
| extern const PrimitivePtr kPrimCol2ImV1; | |||||
| inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||||
| inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||||
| inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||||
| extern const PrimitivePtr kPrimResolve; | |||||
| extern const PrimitivePtr kPrimEmbed; | |||||
| extern const PrimitivePtr kPrimRefToEmbed; | |||||
| extern const PrimitivePtr kPrimCreateInstance; | |||||
| extern const PrimitivePtr kPrimLabelGoto; | |||||
| extern const PrimitivePtr kPrimLabelSwitch; | |||||
| extern const PrimitivePtr kPrimLabelSet; | |||||
| // Structure | |||||
| extern const PrimitivePtr kPrimStringEqual; | |||||
| extern const PrimitivePtr kPrimStringConcat; | |||||
| extern const PrimitivePtr kPrimMakeTuple; | |||||
| extern const PrimitivePtr kPrimMakeList; | |||||
| extern const PrimitivePtr kPrimMakeDict; | |||||
| extern const PrimitivePtr kPrimMakeKeywordArg; | |||||
| extern const PrimitivePtr kPrimExtractKeywordArg; | |||||
| extern const PrimitivePtr kPrimMakeSlice; | |||||
| extern const PrimitivePtr kPrimMakeRecord; | |||||
| extern const PrimitivePtr kPrimTupleGetItem; | |||||
| extern const PrimitivePtr kPrimListGetItem; | |||||
| extern const PrimitivePtr kPrimArrayGetItem; | |||||
| extern const PrimitivePtr kPrimTupleSetItem; | |||||
| extern const PrimitivePtr kPrimListSetItem; | |||||
| extern const PrimitivePtr kPrimArraySetItem; | |||||
| extern const PrimitivePtr kPrimDictGetItem; | |||||
| extern const PrimitivePtr kPrimDictSetItem; | |||||
| extern const PrimitivePtr kPrimListAppend; | |||||
| extern const PrimitivePtr kPrimGetAttr; | |||||
| extern const PrimitivePtr kPrimTupleLen; | |||||
| extern const PrimitivePtr kPrimDictLen; | |||||
| extern const PrimitivePtr kPrimListLen; | |||||
| extern const PrimitivePtr kPrimArrayLen; | |||||
| extern const PrimitivePtr kPrimListMap; | |||||
| extern const PrimitivePtr kPrimListReduce; | |||||
| extern const PrimitivePtr kPrimTupleReversed; | |||||
| extern const PrimitivePtr kPrimTileShape; | |||||
| extern const PrimitivePtr kPrimReducedShape; | |||||
| extern const PrimitivePtr kPrimTupleDiv; | |||||
| extern const PrimitivePtr kPrimTupleToArray; | |||||
| extern const PrimitivePtr kPrimShapeMul; | |||||
| extern const PrimitivePtr kPrimGenerateShapeIndex; | |||||
| extern const PrimitivePtr kPrimGenerateInverseIndex; | |||||
| extern const PrimitivePtr kPrimTupleEqual; | |||||
| extern const PrimitivePtr kPrimListEqual; | |||||
| extern const PrimitivePtr kPrimMakeRange; | |||||
| extern const PrimitivePtr kPrimStopGradient; | |||||
| inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto"); | |||||
| inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch"); | |||||
| inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | |||||
| // Arrays | // Arrays | ||||
| extern const PrimitivePtr kPrimScalarToArray; | |||||
| extern const PrimitivePtr kPrimArrayToScalar; | |||||
| extern const PrimitivePtr kPrimBroadcastShape; | |||||
| extern const PrimitivePtr kPrimArrayMap; | |||||
| extern const PrimitivePtr kPrimArrayReduce; | |||||
| extern const PrimitivePtr kPrimShape; | |||||
| extern const PrimitivePtr kPrimCast; | |||||
| extern const PrimitivePtr kPrimConcat; | |||||
| extern const PrimitivePtr kPrimSqueeze; | |||||
| extern const PrimitivePtr kPrimTranspose; | |||||
| extern const PrimitivePtr kPrimGatherV2; | |||||
| extern const PrimitivePtr kPrimEmbeddingLookup; | |||||
| extern const PrimitivePtr kPrimEmbeddingLookupCommGrad; | |||||
| extern const PrimitivePtr kPrimSize; | |||||
| extern const PrimitivePtr kPrimArgMax; | |||||
| extern const PrimitivePtr kPrimPack; | |||||
| extern const PrimitivePtr kPrimUnpack; | |||||
| extern const PrimitivePtr kPrimUnsortedSegmentMin; | |||||
| extern const PrimitivePtr kPrimUnsortedSegmentSum; | |||||
| extern const PrimitivePtr kPrimConcatOffset; | |||||
| extern const PrimitivePtr kPrimReshape; | |||||
| extern const PrimitivePtr kPrimTile; | |||||
| extern const PrimitivePtr kPrimAddN; | |||||
| extern const PrimitivePtr KPrimTransData; | |||||
| extern const PrimitivePtr kPrimNMSWithMask; | |||||
| extern const PrimitivePtr kPrimPad; | |||||
| extern const PrimitivePtr kPrimArgMaxWithValue; | |||||
| extern const PrimitivePtr kPrimRealDiv; | |||||
| extern const PrimitivePtr kPrimSqrt; | |||||
| extern const PrimitivePtr kPrimReciprocal; | |||||
| extern const PrimitivePtr kPrimExpandDims; | |||||
| // Maths | |||||
| extern const PrimitivePtr kPrimTensorAdd; | |||||
| extern const PrimitivePtr kPrimMatMul; | |||||
| extern const PrimitivePtr kPrimBatchMatMul; | |||||
| extern const PrimitivePtr kPrimMaximumGrad; | |||||
| extern const PrimitivePtr kPrimMinimumGrad; | |||||
| extern const PrimitivePtr kPrimReduceMean; | |||||
| extern const PrimitivePtr kPrimReduceSum; | |||||
| extern const PrimitivePtr kPrimReduceAll; | |||||
| extern const PrimitivePtr kPrimReduceMax; | |||||
| extern const PrimitivePtr kPrimReduceMin; | |||||
| extern const PrimitivePtr kPrimNeg; | |||||
| extern const PrimitivePtr kPrimSub; | |||||
| extern const PrimitivePtr kPrimMul; | |||||
| extern const PrimitivePtr kPrimRealDiv; | |||||
| extern const PrimitivePtr kPrimMinimum; | |||||
| extern const PrimitivePtr kPrimMaximum; | |||||
| extern const PrimitivePtr kPrimSquare; | |||||
| extern const PrimitivePtr kPrimSqrt; | |||||
| extern const PrimitivePtr kPrimEqual; | |||||
| extern const PrimitivePtr kPrimLess; | |||||
| extern const PrimitivePtr kPrimLessEqual; | |||||
| extern const PrimitivePtr kPrimCumSum; | |||||
| extern const PrimitivePtr kPrimCumProd; | |||||
| extern const PrimitivePtr kPrimSubscalar; | |||||
| extern const PrimitivePtr kPrimInplaceAdd; | |||||
| extern const PrimitivePtr kPrimInplaceSub; | |||||
| extern const PrimitivePtr kPrimPow; | |||||
| inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | |||||
| inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | |||||
| inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | |||||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||||
| inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | |||||
| inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | |||||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||||
| inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | |||||
| inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | |||||
| inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | |||||
| inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||||
| inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | |||||
| inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | |||||
| inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | |||||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||||
| inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||||
| inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | |||||
| inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); | |||||
| inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); | |||||
| inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); | |||||
| inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue"); | |||||
| // NN | // NN | ||||
| extern const PrimitivePtr kPrimFlatten; | |||||
| extern const PrimitivePtr kPrimSoftmax; | |||||
| extern const PrimitivePtr kPrimLogSoftmax; | |||||
| extern const PrimitivePtr kPrimLogSoftmaxGrad; | |||||
| extern const PrimitivePtr kPrimApplyCenteredRMSProp; | |||||
| extern const PrimitivePtr kPrimTanh; | |||||
| extern const PrimitivePtr kPrimTanhGrad; | |||||
| extern const PrimitivePtr kPrimPooling; | |||||
| extern const PrimitivePtr kPrimPoolingGrad; | |||||
| extern const PrimitivePtr kPrimFusedBatchNorm; | |||||
| extern const PrimitivePtr kPrimBatchNorm; | |||||
| extern const PrimitivePtr kPrimBatchNormGrad; | |||||
| extern const PrimitivePtr kPrimConv2D; | |||||
| extern const PrimitivePtr kPrimMaxPool; | |||||
| extern const PrimitivePtr kPrimMaxPoolGrad; | |||||
| extern const PrimitivePtr kPrimAvgPoolGrad; | |||||
| extern const PrimitivePtr kPrimFusedBatchNormGrad; | |||||
| extern const PrimitivePtr kPrimReluGrad; | |||||
| extern const PrimitivePtr kPrimConv2DBackpropInput; | |||||
| extern const PrimitivePtr kPrimConv2DBackpropFilter; | |||||
| extern const PrimitivePtr kPrimDepthwiseConv2dNative; | |||||
| extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter; | |||||
| extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput; | |||||
| extern const PrimitivePtr kPrimBiasAddGrad; | |||||
| extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits; | |||||
| extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits; | |||||
| extern const PrimitivePtr kPrimMomentum; | |||||
| extern const PrimitivePtr kPrimApplyMomentum; | |||||
| extern const PrimitivePtr kPrimLayerNorm; | |||||
| extern const PrimitivePtr kPrimLayerNormGrad; | |||||
| extern const PrimitivePtr kPrimLayerNormXBackprop; | |||||
| extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; | |||||
| extern const PrimitivePtr kPrimDropoutGenMask; | |||||
| extern const PrimitivePtr kPrimDropoutDoMask; | |||||
| extern const PrimitivePtr kPrimOneHot; | |||||
| extern const PrimitivePtr kPrimGelu; | |||||
| extern const PrimitivePtr kPrimGeluGrad; | |||||
| extern const PrimitivePtr kPrimRelu; | |||||
| extern const PrimitivePtr kPrimReluV2; | |||||
| extern const PrimitivePtr kPrimActivation; | |||||
| extern const PrimitivePtr kPrimZerosLike; | |||||
| extern const PrimitivePtr kPrimFakeBprop; | |||||
| extern const PrimitivePtr kPrimBpropCut; | |||||
| extern const PrimitivePtr kPrimFakeQuantPerLayer; | |||||
| extern const PrimitivePtr kPrimFakeQuantPerChannel; | |||||
| extern const PrimitivePtr kPrimApplyRMSProp; | |||||
| // Other Miscellaneous | |||||
| extern const PrimitivePtr kPrimIdentity; | |||||
| extern const PrimitivePtr kPrimPartial; | |||||
| extern const PrimitivePtr kPrimJ; | |||||
| extern const PrimitivePtr kPrimEnvSetItem; | |||||
| extern const PrimitivePtr kPrimEnvGetItem; | |||||
| extern const PrimitivePtr kPrimEnvAdd; | |||||
| extern const PrimitivePtr kPrimMakeRefKey; | |||||
| extern const PrimitivePtr kPrimMakeRef; | |||||
| extern const PrimitivePtr kPrimGetRefKey; | |||||
| extern const PrimitivePtr kPrimGetRefValue; | |||||
| extern const PrimitivePtr kPrimGetRefOrigin; | |||||
| extern const PrimitivePtr kPrimInsertGradientOf; | |||||
| extern const PrimitivePtr kPrimHookBackward; | |||||
| extern const PrimitivePtr kPrimPrintShapeType; | |||||
| extern const PrimitivePtr kPrimPrint; | |||||
| extern const PrimitivePtr kPrimSameTypeShape; | |||||
| extern const PrimitivePtr kPrimCheckBprop; | |||||
| extern const PrimitivePtr kPrimDepend; | |||||
| extern const PrimitivePtr kPrimStateSetItem; | |||||
| extern const PrimitivePtr kPrimScalarSummary; | |||||
| extern const PrimitivePtr kPrimImageSummary; | |||||
| extern const PrimitivePtr kPrimTensorSummary; | |||||
| extern const PrimitivePtr kPrimHistogramSummary; | |||||
| extern const PrimitivePtr kPrimBroadcastGradientArgs; | |||||
| extern const PrimitivePtr kPrimControlDepend; | |||||
| extern const PrimitivePtr kPrimIs_; | |||||
| extern const PrimitivePtr kPrimIsNot; | |||||
| extern const PrimitivePtr kPrimInDict; | |||||
| extern const PrimitivePtr kPrimNotInDict; | |||||
| extern const PrimitivePtr kPrimMixedPrecisionCast; | |||||
| extern const PrimitivePtr kPrimIsConsant; | |||||
| extern const PrimitivePtr kPrimEquivFormat; | |||||
| extern const PrimitivePtr kPrimDebug; | |||||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||||
| inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||||
| inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||||
| inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||||
| inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||||
| inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | |||||
| inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | |||||
| inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | |||||
| inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||||
| inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||||
| inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||||
| inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||||
| inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | |||||
| inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | |||||
| inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | |||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | |||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | |||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | |||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | |||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | |||||
| inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | |||||
| inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = | |||||
| std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | |||||
| inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | |||||
| std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | |||||
| inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | |||||
| inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | |||||
| inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | |||||
| inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | |||||
| inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | |||||
| inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | |||||
| inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | |||||
| inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | |||||
| inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | |||||
| inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | |||||
| inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | |||||
| inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | |||||
| inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | |||||
| inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | |||||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||||
| inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | |||||
| inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | |||||
| inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | |||||
| inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | |||||
| // Comm ops | // Comm ops | ||||
| extern const PrimitivePtr kPrimAllReduce; | |||||
| extern const PrimitivePtr kPrimMirror; | |||||
| extern const PrimitivePtr kPrimVirtualDiv; | |||||
| extern const PrimitivePtr kPrimVirtualDataset; | |||||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||||
| // IndexedSlices | // IndexedSlices | ||||
| extern const PrimitivePtr kPrimMakeIndexedSlices; | |||||
| extern const PrimitivePtr kPrimIndexedSlicesGetValues; | |||||
| extern const PrimitivePtr kPrimIndexedSlicesGetIndices; | |||||
| extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; | |||||
| inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices"); | |||||
| inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); | |||||
| inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); | |||||
| inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); | |||||
| inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices"); | |||||
| // SparseTensor | // SparseTensor | ||||
| extern const PrimitivePtr kPrimMakeSparseTensor; | |||||
| extern const PrimitivePtr kPrimSparseTensorGetValues; | |||||
| extern const PrimitivePtr kPrimSparseTensorGetIndices; | |||||
| extern const PrimitivePtr kPrimSparseTensorGetDenseShape; | |||||
| inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor"); | |||||
| inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues"); | |||||
| inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | |||||
| inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | |||||
| // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll | // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll | ||||
| const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; | const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; | ||||
| @@ -305,22 +191,6 @@ const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; | |||||
| // will be sunk(i.e. not unrolled) | // will be sunk(i.e. not unrolled) | ||||
| const int MAX_FOR_LOOP_COUNT = 600; | const int MAX_FOR_LOOP_COUNT = 600; | ||||
| class DoSignaturePrimitive : public Primitive { | |||||
| public: | |||||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||||
| : Primitive("S-Prim-" + name), function_(function) {} | |||||
| ~DoSignaturePrimitive() override = default; | |||||
| MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) | |||||
| const ValuePtr function() const { return function_; } | |||||
| private: | |||||
| ValuePtr function_; | |||||
| }; | |||||
| using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||||
| class UnpackGraphPrimitive : public Primitive { | class UnpackGraphPrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | ||||
| @@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { | |||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| (void)cnode_set.emplace(cnode); | (void)cnode_set.emplace(cnode); | ||||
| } else { | } else { | ||||
| auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | ||||
| @@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi | |||||
| return cnode_dist; | return cnode_dist; | ||||
| } | } | ||||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) | MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) | ||||
| << " operator_info: " << (cnode->operator_info() != nullptr); | |||||
| << " operator_info: " << (operator_info != nullptr); | |||||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||||
| auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); | |||||
| if (IsParallelCareNode(cnode) && (operator_info != nullptr)) { | |||||
| auto cost = operator_info->GetForwardMemoryCostFromCNode(); | |||||
| MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; | MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; | ||||
| if (allreduce_graph_.NodeInGraph(cnode)) { | if (allreduce_graph_.NodeInGraph(cnode)) { | ||||
| @@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { | |||||
| } | } | ||||
| auto para_ptr = node_ptr->cast<ParameterPtr>(); | auto para_ptr = node_ptr->cast<ParameterPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(para_ptr); | MS_EXCEPTION_IF_NULL(para_ptr); | ||||
| auto layout_ptr = para_ptr->tensor_layout(); | |||||
| auto layout_ptr = para_ptr->GetUserData<TensorLayout>(); | |||||
| if (layout_ptr == nullptr) { | if (layout_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "layout_ptr is nullptr!"; | MS_LOG(ERROR) << "layout_ptr is nullptr!"; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||||
| for (auto para : graph_params) { | for (auto para : graph_params) { | ||||
| std::string name = std::static_pointer_cast<Parameter>(para)->name(); | std::string name = std::static_pointer_cast<Parameter>(para)->name(); | ||||
| std::shared_ptr<parallel::TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout(); | |||||
| auto tensor_layout = para->GetUserData<parallel::TensorLayout>(); | |||||
| if (tensor_layout == nullptr) { | if (tensor_layout == nullptr) { | ||||
| MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; | MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; | ||||
| } else { | } else { | ||||
| @@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { | |||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto distributed_operation_info = cnode->operator_info(); | |||||
| auto distributed_operation_info = cnode->GetUserData<OperatorInfo>(); | |||||
| if (distributed_operation_info != nullptr) { | if (distributed_operation_info != nullptr) { | ||||
| auto strategyPtr = distributed_operation_info->strategy(); | auto strategyPtr = distributed_operation_info->strategy(); | ||||
| if (strategyPtr != nullptr) { | if (strategyPtr != nullptr) { | ||||
| @@ -163,6 +163,9 @@ class OperatorInfo { | |||||
| const std::string &type() const { return type_; } | const std::string &type() const { return type_; } | ||||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | ||||
| // Key for user data. | |||||
| constexpr static char key[] = "OpInfo"; | |||||
| protected: | protected: | ||||
| // needed by rec_parser | // needed by rec_parser | ||||
| std::string type_; | std::string type_; | ||||
| @@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | ||||
| entire_costgraph->AddOperator(operator_info); | entire_costgraph->AddOperator(operator_info); | ||||
| (void)cnode->set_operator_info(operator_info); | |||||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | ||||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | ||||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | ||||
| @@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | ||||
| entire_costgraph->AddOperator(operator_info); | entire_costgraph->AddOperator(operator_info); | ||||
| (void)cnode->set_operator_info(operator_info); | |||||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | ||||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | ||||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | ||||
| @@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||||
| MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() | MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() | ||||
| << " does not match the Prim: " << prim->name(); | << " does not match the Prim: " << prim->name(); | ||||
| } | } | ||||
| (void)cnode->set_operator_info(current_op_ptr); | |||||
| cnode->SetUserData<OperatorInfo>(current_op_ptr); | |||||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | ||||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | ||||
| << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | ||||
| @@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | ||||
| size_t edge_count = 0; | size_t edge_count = 0; | ||||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| auto prev_cnode = inputs[i]->cast<CNodePtr>(); | auto prev_cnode = inputs[i]->cast<CNodePtr>(); | ||||
| bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | ||||
| @@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | ||||
| while (bool_result) { | while (bool_result) { | ||||
| if (IsAutoParallelCareNode(prev_cnode)) { | if (IsAutoParallelCareNode(prev_cnode)) { | ||||
| std::string edge_name = | |||||
| prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); | |||||
| auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>(); | |||||
| std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); | |||||
| // If the edge between these two operators already has been added, then the edge will not be added again. | // If the edge between these two operators already has been added, then the edge will not be added again. | ||||
| if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { | if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { | ||||
| break; | break; | ||||
| @@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| if (follow_strategy) { | if (follow_strategy) { | ||||
| // Redistribution in not allowed on the edge. | // Redistribution in not allowed on the edge. | ||||
| // Elementwise operators have the same strategy as their previous operators. | // Elementwise operators have the same strategy as their previous operators. | ||||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(), | |||||
| output_index, i - 1, false, true); | |||||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true); | |||||
| } else { | } else { | ||||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(), | |||||
| output_index, i - 1, false); | |||||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false); | |||||
| } | } | ||||
| // Init costs for this edge | // Init costs for this edge | ||||
| if (edge_ptr->InitEdgeCost() != SUCCESS) { | if (edge_ptr->InitEdgeCost() != SUCCESS) { | ||||
| MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | ||||
| } | } | ||||
| cnode->operator_info()->AddPrevEdge(edge_ptr); | |||||
| prev_cnode->operator_info()->AddSuccEdge(edge_ptr); | |||||
| entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); | |||||
| MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " | |||||
| << cnode->operator_info()->name(); | |||||
| node_op_info->AddPrevEdge(edge_ptr); | |||||
| prev_op_info->AddSuccEdge(edge_ptr); | |||||
| entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); | |||||
| MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " | |||||
| << node_op_info->name(); | |||||
| edge_count++; | edge_count++; | ||||
| break; | break; | ||||
| @@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | ||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); | |||||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); | |||||
| } | } | ||||
| MS_LOG(INFO) << "Constructing edges for cost graph ends."; | MS_LOG(INFO) << "Constructing edges for cost graph ends."; | ||||
| @@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| for (auto &target : target_set) { | for (auto &target : target_set) { | ||||
| auto target_cnode = target.first->cast<CNodePtr>(); | auto target_cnode = target.first->cast<CNodePtr>(); | ||||
| auto input_index = target.second; | auto input_index = target.second; | ||||
| (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); | |||||
| (void)target_without_duplicate.insert(std::to_string(input_index) + | |||||
| target_cnode->GetUserData<OperatorInfo>()->name()); | |||||
| } | } | ||||
| if (target_without_duplicate.size() <= 1) { | if (target_without_duplicate.size() <= 1) { | ||||
| continue; | continue; | ||||
| @@ -830,24 +831,24 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| auto target_cnode = target.first->cast<CNodePtr>(); | auto target_cnode = target.first->cast<CNodePtr>(); | ||||
| auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); | auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); | ||||
| auto input_index = target.second; | auto input_index = target.second; | ||||
| auto target_op_info = target_cnode->GetUserData<OperatorInfo>(); | |||||
| std::string edge_name = | |||||
| std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); | |||||
| std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); | |||||
| // If the edge between these two operators already has been added, then the edge will not be added again. | // If the edge between these two operators already has been added, then the edge will not be added again. | ||||
| if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { | if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>( | |||||
| edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); | |||||
| std::shared_ptr<Edge> edge_ptr = | |||||
| std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true); | |||||
| if (edge_ptr->InitEdgeCost() != SUCCESS) { | if (edge_ptr->InitEdgeCost() != SUCCESS) { | ||||
| MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | ||||
| } | } | ||||
| target_cnode->operator_info()->AddPrevEdge(edge_ptr); | |||||
| target_op_info->AddPrevEdge(edge_ptr); | |||||
| tmp_identity_ptr->AddSuccEdge(edge_ptr); | tmp_identity_ptr->AddSuccEdge(edge_ptr); | ||||
| entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr); | |||||
| entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr); | |||||
| MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " | MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " | ||||
| << target_cnode->operator_info()->name(); | |||||
| << target_op_info->name(); | |||||
| add_identity_edge = true; | add_identity_edge = true; | ||||
| } | } | ||||
| if (new_identity && add_identity_edge) { | if (new_identity && add_identity_edge) { | ||||
| @@ -861,20 +862,13 @@ bool FindReshape(const CNodePtr &cnode) { | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { | |||||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||||
| if (operator_info == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | |||||
| } | |||||
| if (prim->name() != RESHAPE) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| return (prim->name() == RESHAPE); | |||||
| } | } | ||||
| // find previous node, then obtain its strategy_cost_ vector to get its layout vector. | // find previous node, then obtain its strategy_cost_ vector to get its layout vector. | ||||
| @@ -890,8 +884,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||||
| *pre_operator_info = cnode->operator_info(); | |||||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||||
| if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | |||||
| *pre_operator_info = node_op_info; | |||||
| *out_index = 0; | *out_index = 0; | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -905,8 +900,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | ||||
| } | } | ||||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | ||||
| if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { | |||||
| *pre_operator_info = pre_cnode->operator_info(); | |||||
| auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>(); | |||||
| if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | |||||
| *pre_operator_info = pre_op_info; | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -945,14 +941,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { | |||||
| auto op_info = use_apply->GetUserData<OperatorInfo>(); | |||||
| if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | |||||
| MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | ||||
| *next_operator_info = use_apply->operator_info(); | |||||
| *next_operator_info = op_info; | |||||
| *in_index = node_pair.second - 1; | *in_index = node_pair.second - 1; | ||||
| return true; | return true; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | ||||
| << " " << (use_apply->operator_info() != nullptr); | |||||
| << " " << (op_info != nullptr); | |||||
| if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { | if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { | ||||
| return true; | return true; | ||||
| @@ -973,8 +970,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| int32_t out_index = 0; | int32_t out_index = 0; | ||||
| OperatorInfoPtr pre_operator_info; | OperatorInfoPtr pre_operator_info; | ||||
| std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | ||||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| if (pre_node->isa<Parameter>()) { | if (pre_node->isa<Parameter>()) { | ||||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | ||||
| reshape_info->SetCostForReshapeWithParameter(); | reshape_info->SetCostForReshapeWithParameter(); | ||||
| pre_operator_info = reshape_info; | pre_operator_info = reshape_info; | ||||
| @@ -995,7 +992,6 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| } | } | ||||
| // set input_layout and output_layout for reshape. | // set input_layout and output_layout for reshape. | ||||
| // init reshape and set cost for each input_layout and output_layout. | // init reshape and set cost for each input_layout and output_layout. | ||||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | ||||
| reshape_info->set_pre_operator_name(pre_operator_info->name()); | reshape_info->set_pre_operator_name(pre_operator_info->name()); | ||||
| reshape_info->set_pre_operator_index(out_index); | reshape_info->set_pre_operator_index(out_index); | ||||
| @@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { | |||||
| if (!IsParallelCareNode(node)) { | if (!IsParallelCareNode(node)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| OperatorInfoPtr distribute_operator = node->operator_info(); | |||||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||||
| if (distribute_operator == nullptr) { | if (distribute_operator == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; | MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; | ||||
| } | } | ||||
| @@ -415,7 +415,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||||
| if (prim->name() == GET_NEXT) { | if (prim->name() == GET_NEXT) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { | |||||
| if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -452,7 +452,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { | |||||
| if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) { | |||||
| Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, | Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, | ||||
| pre_node); | pre_node); | ||||
| } else { | } else { | ||||
| @@ -465,7 +465,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||||
| void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { | void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(next_node); | MS_EXCEPTION_IF_NULL(next_node); | ||||
| OperatorInfoPtr op_info = next_node->operator_info(); | |||||
| OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| // If the shape of tensor is [] or [1], no need to split it. | // If the shape of tensor is [] or [1], no need to split it. | ||||
| @@ -590,7 +590,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { | |||||
| void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | ||||
| // step1:get graph manager distribute_operator | // step1:get graph manager distribute_operator | ||||
| OperatorInfoPtr distribute_operator = node->operator_info(); | |||||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||||
| if (distribute_operator == nullptr) { | if (distribute_operator == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; | MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; | ||||
| } | } | ||||
| @@ -628,7 +628,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||||
| (void)prim->SetAttrs(attrs); | (void)prim->SetAttrs(attrs); | ||||
| } | } | ||||
| if (index == replace_op.size() - 1) { | if (index == replace_op.size() - 1) { | ||||
| (void)replace_node->set_operator_info(node->operator_info()); | |||||
| replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>()); | |||||
| } | } | ||||
| replace_node->set_in_forward_flag(true); | replace_node->set_in_forward_flag(true); | ||||
| replace_input[0]->set_scope(scope); | replace_input[0]->set_scope(scope); | ||||
| @@ -708,7 +708,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { | |||||
| auto pre_cnode = pre_node->cast<CNodePtr>(); | auto pre_cnode = pre_node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pre_cnode); | MS_EXCEPTION_IF_NULL(pre_cnode); | ||||
| auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { | |||||
| if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||||
| pre_node = pre_cnode->input(1); | pre_node = pre_cnode->input(1); | ||||
| } | } | ||||
| @@ -1204,7 +1204,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { | |||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| return node_pair; | return node_pair; | ||||
| } else if (FindParallelCareNode(node_pair.first).first != nullptr) { | } else if (FindParallelCareNode(node_pair.first).first != nullptr) { | ||||
| return FindParallelCareNode(node_pair.first); | return FindParallelCareNode(node_pair.first); | ||||
| @@ -1254,7 +1254,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||||
| MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); | MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); | ||||
| CNodePtr cnode = res.first->cast<CNodePtr>(); | CNodePtr cnode = res.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| OperatorInfoPtr distribute_operator = cnode->operator_info(); | |||||
| OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>(); | |||||
| if (distribute_operator == nullptr) { | if (distribute_operator == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; | MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; | ||||
| } | } | ||||
| @@ -1277,7 +1277,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||||
| TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | ||||
| ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(parameter_ptr); | MS_EXCEPTION_IF_NULL(parameter_ptr); | ||||
| parameter_ptr->set_tensor_layout(std::make_shared<TensorLayout>(tensor_layout)); | |||||
| parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||||
| } | } | ||||
| void CoverSliceShape(const FuncGraphPtr &root) { | void CoverSliceShape(const FuncGraphPtr &root) { | ||||
| @@ -1365,7 +1365,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||||
| if (found_be_cloned_parameter) { | if (found_be_cloned_parameter) { | ||||
| // set the shape and tensor layout for cloned parameter | // set the shape and tensor layout for cloned parameter | ||||
| cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); | |||||
| cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>()); | |||||
| MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | ||||
| MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | ||||
| auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | ||||
| @@ -1464,7 +1464,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| (*operator_).set_outputs_dtype(cnode->Type()); | (*operator_).set_outputs_dtype(cnode->Type()); | ||||
| (*operator_).set_cnode(cnode); | (*operator_).set_cnode(cnode); | ||||
| if (prim->name() == RESHAPE) { | if (prim->name() == RESHAPE) { | ||||
| (void)cnode->set_operator_info(operator_); | |||||
| cnode->SetUserData<OperatorInfo>(operator_); | |||||
| continue; | continue; | ||||
| } | } | ||||
| // load strategy checkpoint | // load strategy checkpoint | ||||
| @@ -1499,7 +1499,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| if (operator_->Init(strategyPtr) == FAILED) { | if (operator_->Init(strategyPtr) == FAILED) { | ||||
| MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | ||||
| } | } | ||||
| (void)cnode->set_operator_info(operator_); | |||||
| cnode->SetUserData<OperatorInfo>(operator_); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; | MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; | ||||
| } | } | ||||
| @@ -1542,13 +1542,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) { | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { | |||||
| if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) { | |||||
| MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); | MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); | ||||
| auto layout = GetInputLayoutFromCNode(node_pair); | auto layout = GetInputLayoutFromCNode(node_pair); | ||||
| return std::make_shared<TensorLayout>(layout); | return std::make_shared<TensorLayout>(layout); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | ||||
| << " " << (use_apply->operator_info() != nullptr); | |||||
| << " " << use_apply->HasUserData<OperatorInfo>(); | |||||
| auto layout_ptr = FindNextLayout(use_apply); | auto layout_ptr = FindNextLayout(use_apply); | ||||
| if (layout_ptr) { | if (layout_ptr) { | ||||
| @@ -1580,7 +1580,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); | auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); | ||||
| if (!layout_ptr) { | if (!layout_ptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | ||||
| @@ -1624,7 +1624,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); | auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); | ||||
| if (!layout_ptr) { | if (!layout_ptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | ||||
| @@ -1664,12 +1664,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ||||
| if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { | |||||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| if (operator_info == nullptr) { | if (operator_info == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | ||||
| } | } | ||||
| @@ -1714,7 +1714,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { | |||||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| // return -> cast | // return -> cast | ||||
| if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { | |||||
| if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pre_cnode); | MS_EXCEPTION_IF_NULL(pre_cnode); | ||||
| current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| @@ -1771,7 +1771,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| OperatorInfoPtr operator_info = loss_cnode->operator_info(); | |||||
| OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(operator_info); | MS_EXCEPTION_IF_NULL(operator_info); | ||||
| TensorInfo loss_grad_tensor_info; | TensorInfo loss_grad_tensor_info; | ||||
| size_t op_output_size = operator_info->outputs_tensor_info().size(); | size_t op_output_size = operator_info->outputs_tensor_info().size(); | ||||
| @@ -1809,7 +1809,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||||
| if (sens_tensor_node->isa<Parameter>()) { | if (sens_tensor_node->isa<Parameter>()) { | ||||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | ||||
| MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); | MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); | ||||
| sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| } | } | ||||
| MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; | MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; | ||||
| return; | return; | ||||
| @@ -1834,7 +1834,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||||
| cloned_abstract->set_shape(parallel_shape); | cloned_abstract->set_shape(parallel_shape); | ||||
| sens_tensor_node->set_abstract(cloned_abstract); | sens_tensor_node->set_abstract(cloned_abstract); | ||||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | ||||
| sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; | MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; | ||||
| @@ -2125,7 +2125,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| if (operator_info) { | if (operator_info) { | ||||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | ||||
| continue; | continue; | ||||
| @@ -83,6 +83,9 @@ class TensorLayout { | |||||
| TensorLayout SqueezeShape() const; | TensorLayout SqueezeShape() const; | ||||
| // Key for user data. | |||||
| constexpr static char key[] = "TLayout"; | |||||
| private: | private: | ||||
| std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement( | std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement( | ||||
| const Arrangement &expanded_shape) const; | const Arrangement &expanded_shape) const; | ||||
| @@ -0,0 +1,160 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPERATOR_OPS_H_ | |||||
| #define MINDSPORE_CORE_OPERATOR_OPS_H_ | |||||
| #include <iostream> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ir/anf.h" | |||||
| #include "ir/primitive.h" | |||||
| namespace mindspore { | |||||
| namespace prim { | |||||
| // Maths | |||||
| inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | |||||
| inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | |||||
| inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | |||||
| inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | |||||
| inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | |||||
| inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | |||||
| inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | |||||
| inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | |||||
| inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | |||||
| inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | |||||
| inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | |||||
| inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | |||||
| inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||||
| inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | |||||
| inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | |||||
| inline const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | |||||
| inline const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | |||||
| inline const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | |||||
| inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | |||||
| inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | |||||
| inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | |||||
| inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||||
| inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||||
| inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||||
| inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||||
| inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||||
| // Statements | |||||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||||
| inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch"); | |||||
| inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer"); | |||||
| inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | |||||
| inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | |||||
| inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub"); | |||||
| inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select"); | |||||
| inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call"); | |||||
| // Structures | |||||
| inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||||
| inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||||
| inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple"); | |||||
| inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict"); | |||||
| inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||||
| inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg"); | |||||
| inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice"); | |||||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||||
| inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem"); | |||||
| inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem"); | |||||
| inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem"); | |||||
| inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem"); | |||||
| inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||||
| inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem"); | |||||
| inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||||
| inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||||
| inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||||
| inline const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr"); | |||||
| inline const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len"); | |||||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||||
| inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||||
| inline const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len"); | |||||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||||
| inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||||
| inline const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape"); | |||||
| inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||||
| inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||||
| inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||||
| inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||||
| inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index"); | |||||
| inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | |||||
| inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||||
| inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||||
| inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||||
| inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||||
| inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg"); | |||||
| // Debug ops | |||||
| inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||||
| inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | |||||
| inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||||
| inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | |||||
| inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug"); | |||||
| // Other miscellaneous | |||||
| inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | |||||
| inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend"); | |||||
| inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial"); | |||||
| inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity"); | |||||
| inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem"); | |||||
| inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem"); | |||||
| inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||||
| inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||||
| inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||||
| inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | |||||
| inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||||
| inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | |||||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||||
| inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | |||||
| inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||||
| inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||||
| inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); | |||||
| inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | |||||
| inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | |||||
| inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||||
| inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||||
| inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||||
| inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||||
| class DoSignaturePrimitive : public Primitive { | |||||
| public: | |||||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||||
| : Primitive("S-Prim-" + name), function_(function) {} | |||||
| ~DoSignaturePrimitive() override = default; | |||||
| MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) | |||||
| const ValuePtr function() const { return function_; } | |||||
| private: | |||||
| ValuePtr function_; | |||||
| }; | |||||
| using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||||
| } // namespace prim | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPERATOR_OPS_H_ | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_USER_DATA_H_ | |||||
| #define MINDSPORE_CORE_USER_DATA_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| namespace mindspore { | |||||
| class UserData { | |||||
| public: | |||||
| template <typename T> | |||||
| void set(const std::string &key, const std::shared_ptr<T> &value) { | |||||
| if (value == nullptr) { | |||||
| data_.erase(key); | |||||
| } else { | |||||
| data_.insert_or_assign(key, value); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| std::shared_ptr<T> get(const std::string &key) const { | |||||
| auto iter = data_.find(key); | |||||
| if (iter == data_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::static_pointer_cast<T>(iter->second); | |||||
| } | |||||
| bool has(const std::string &key) const { return data_.find(key) != data_.end(); } | |||||
| private: | |||||
| std::map<std::string, std::shared_ptr<void>> data_; | |||||
| }; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_USER_DATA_H_ | |||||
| @@ -26,7 +26,6 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "base/user_data.h" | |||||
| #include "ir/kernel_info_dev.h" | #include "ir/kernel_info_dev.h" | ||||
| #include "ir/scope.h" | #include "ir/scope.h" | ||||
| #include "debug/info.h" | #include "debug/info.h" | ||||
| @@ -41,12 +42,6 @@ | |||||
| // ANode: Atomic Node | // ANode: Atomic Node | ||||
| // CNode: Complex Node | // CNode: Complex Node | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | |||||
| class TensorLayout; | |||||
| class OperatorInfo; | |||||
| } // namespace parallel | |||||
| using OperatorInfoPtr = std::shared_ptr<parallel::OperatorInfo>; | |||||
| namespace abstract { | namespace abstract { | ||||
| class BaseShape; | class BaseShape; | ||||
| class AbstractBase; | class AbstractBase; | ||||
| @@ -157,6 +152,31 @@ class AnfNode : public Base { | |||||
| } | } | ||||
| size_t seen_{0}; | size_t seen_{0}; | ||||
| template <typename T> | |||||
| void SetUserData(const std::string &key, const std::shared_ptr<T> &value) { | |||||
| user_data_.set<T>(key, value); | |||||
| } | |||||
| template <typename T> | |||||
| void SetUserData(const std::shared_ptr<T> &value) { | |||||
| user_data_.set<T>(T::key, value); | |||||
| } | |||||
| template <typename T> | |||||
| std::shared_ptr<T> GetUserData(const std::string &key) const { | |||||
| return user_data_.get<T>(key); | |||||
| } | |||||
| template <typename T> | |||||
| std::shared_ptr<T> GetUserData() const { | |||||
| return user_data_.get<T>(T::key); | |||||
| } | |||||
| bool HasUserData(const std::string &key) const { return user_data_.has(key); } | |||||
| template <typename T> | |||||
| bool HasUserData() const { return user_data_.has(T::key); } | |||||
| protected: | protected: | ||||
| // Hold a weak ref to Graph as Graph also hold ref to AnfNode. | // Hold a weak ref to Graph as Graph also hold ref to AnfNode. | ||||
| // Otherwise, func_graph_ and AnfNode will make a reference cycle. | // Otherwise, func_graph_ and AnfNode will make a reference cycle. | ||||
| @@ -170,6 +190,7 @@ class AnfNode : public Base { | |||||
| std::hash<const AnfNode *> hash_; | std::hash<const AnfNode *> hash_; | ||||
| ScopePtr scope_; | ScopePtr scope_; | ||||
| KernelInfoDevicePtr kernel_info_; | KernelInfoDevicePtr kernel_info_; | ||||
| UserData user_data_; | |||||
| }; | }; | ||||
| // CNode represents the complex node with a set of arguments. | // CNode represents the complex node with a set of arguments. | ||||
| @@ -212,9 +233,6 @@ class CNode : public AnfNode { | |||||
| std::string DebugString(int recursive_level = 1) const override; | std::string DebugString(int recursive_level = 1) const override; | ||||
| std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } | std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } | ||||
| OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info); | |||||
| OperatorInfoPtr operator_info() { return operator_info_; } | |||||
| void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } | void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } | ||||
| bool in_forward_flag() const { return in_forward_flag_; } | bool in_forward_flag() const { return in_forward_flag_; } | ||||
| @@ -224,7 +242,6 @@ class CNode : public AnfNode { | |||||
| std::vector<AnfNodePtr> inputs_; | std::vector<AnfNodePtr> inputs_; | ||||
| VarPtr func_graph_as_var_; | VarPtr func_graph_as_var_; | ||||
| bool stop_gradient_; | bool stop_gradient_; | ||||
| OperatorInfoPtr operator_info_ = nullptr; | |||||
| bool in_forward_flag_ = false; | bool in_forward_flag_ = false; | ||||
| }; | }; | ||||
| @@ -244,7 +261,7 @@ class ANode : public AnfNode { | |||||
| class Parameter : public ANode { | class Parameter : public ANode { | ||||
| public: | public: | ||||
| explicit Parameter(const FuncGraphPtr &func_graph) | explicit Parameter(const FuncGraphPtr &func_graph) | ||||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} | |||||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} | |||||
| ~Parameter() override = default; | ~Parameter() override = default; | ||||
| MS_DECLARE_PARENT(Parameter, ANode); | MS_DECLARE_PARENT(Parameter, ANode); | ||||
| @@ -261,11 +278,6 @@ class Parameter : public ANode { | |||||
| } | } | ||||
| ParamValuePtr default_param() const { return default_param_; } | ParamValuePtr default_param() const { return default_param_; } | ||||
| std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; } | |||||
| void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) { | |||||
| tensor_layout_ = tensor_layout; | |||||
| } | |||||
| bool operator==(const AnfNode &other) const override { | bool operator==(const AnfNode &other) const override { | ||||
| if (!other.isa<Parameter>()) { | if (!other.isa<Parameter>()) { | ||||
| return false; | return false; | ||||
| @@ -281,7 +293,6 @@ class Parameter : public ANode { | |||||
| std::string name_; | std::string name_; | ||||
| bool has_default_; | bool has_default_; | ||||
| ParamValuePtr default_param_; | ParamValuePtr default_param_; | ||||
| std::shared_ptr<parallel::TensorLayout> tensor_layout_; | |||||
| }; | }; | ||||
| using ParameterPtr = std::shared_ptr<Parameter>; | using ParameterPtr = std::shared_ptr<Parameter>; | ||||
| @@ -23,8 +23,7 @@ | |||||
| #include "ir/visitor.h" | #include "ir/visitor.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||||
| #include "base/core_ops.h" | |||||
| #include "debug/label.h" | #include "debug/label.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -37,18 +36,6 @@ std::string AnfNode::ToString() const { | |||||
| return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | ||||
| } | } | ||||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { | |||||
| if (operator_info_ != nullptr) { | |||||
| MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | |||||
| << ", using the new one: " << operator_info->name(); | |||||
| auto old_ptr = operator_info_; | |||||
| operator_info_ = operator_info; | |||||
| return old_ptr; | |||||
| } | |||||
| operator_info_ = operator_info; | |||||
| return nullptr; | |||||
| } | |||||
| std::string CNode::fullname_with_scope() { | std::string CNode::fullname_with_scope() { | ||||
| // if full name is set, return its name immediately | // if full name is set, return its name immediately | ||||
| if (!fullname_with_scope_.empty()) { | if (!fullname_with_scope_.empty()) { | ||||
| @@ -24,7 +24,6 @@ | |||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "utils/ordered_set.h" | #include "utils/ordered_set.h" | ||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "ir/param_value.h" | #include "ir/param_value.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/profile.h" | #include "utils/profile.h" | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/ordered_set.h" | #include "utils/ordered_set.h" | ||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "utils/profile.h" | #include "utils/profile.h" | ||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "base/core_ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -17,10 +17,8 @@ | |||||
| */ | */ | ||||
| #include "ir/meta_func_graph.h" | #include "ir/meta_func_graph.h" | ||||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -22,9 +22,9 @@ | |||||
| #include <tuple> | #include <tuple> | ||||
| #include <vector> | #include <vector> | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "base/core_ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| /// | /// | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "ir/dtype/type.h" | #include "ir/dtype/type.h" | ||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "frontend/parallel/ops_info/operator_info.h" | |||||
| #include "utils/base_ref_extends.h" | #include "utils/base_ref_extends.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -18,7 +18,6 @@ | |||||
| #include <mutex> | #include <mutex> | ||||
| #include <utility> | #include <utility> | ||||
| #include "ir/signature.h" | #include "ir/signature.h" | ||||
| #include "frontend/operator/ops.h" | |||||
| #include "./common.h" | #include "./common.h" | ||||
| #include "pipeline/jit/parse/python_adapter.h" | #include "pipeline/jit/parse/python_adapter.h" | ||||
| #include "pipeline/jit/parse/data_converter.h" | #include "pipeline/jit/parse/data_converter.h" | ||||
| @@ -28,7 +28,6 @@ | |||||
| #include <type_traits> | #include <type_traits> | ||||
| #include <typeinfo> | #include <typeinfo> | ||||
| #include "runtime/device/device_address.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { | |||||
| StrategyPtr strategyPtr; | StrategyPtr strategyPtr; | ||||
| std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); | std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); | ||||
| node->set_operator_info(matmul_info); | |||||
| node->SetUserData<OperatorInfo>(matmul_info); | |||||
| std::string name_expect = "MatMulInfo00"; | std::string name_expect = "MatMulInfo00"; | ||||
| std::string name_test = matmul_info->name(); | std::string name_test = matmul_info->name(); | ||||
| ASSERT_EQ(name_expect, name_test); | ASSERT_EQ(name_expect, name_test); | ||||
| @@ -525,8 +525,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) { | |||||
| std::vector<Shapes> shape = {inputs_shape, outputs_shape}; | std::vector<Shapes> shape = {inputs_shape, outputs_shape}; | ||||
| OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); | OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); | ||||
| matmul_info->Init(strategyPtr); | matmul_info->Init(strategyPtr); | ||||
| node->set_operator_info(matmul_info); | |||||
| OperatorInfoPtr distribute_operator_pre = node->operator_info(); | |||||
| node->SetUserData<OperatorInfo>(matmul_info); | |||||
| OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>(); | |||||
| TensorLayout tensorlayout_e; | TensorLayout tensorlayout_e; | ||||
| std::vector<int32_t> array = {64, 64}; | std::vector<int32_t> array = {64, 64}; | ||||
| TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); | TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); | ||||