Merge pull request !3129 from hewei/decouple_ir_frontendtags/v0.6.0-beta
| @@ -27,6 +27,7 @@ | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "frontend/parallel/ops_info/operator_info.h" | |||
| namespace mindspore { | |||
| const std::string ToShortString(const TypeId &typeId) { | |||
| @@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo | |||
| return; | |||
| } | |||
| auto operator_info = node->operator_info(); | |||
| auto operator_info = node->GetUserData<parallel::OperatorInfo>(); | |||
| if (operator_info == nullptr) { | |||
| return; | |||
| } | |||
| @@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||
| if (graph_obj == nullptr || node == nullptr) { | |||
| return; | |||
| } | |||
| auto distributed_operation_info = node->operator_info(); | |||
| auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>(); | |||
| if (distributed_operation_info != nullptr) { | |||
| auto strategyPtr = distributed_operation_info->strategy(); | |||
| if (strategyPtr != nullptr) { | |||
| @@ -1,293 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/operator/ops.h" | |||
| #include <memory> | |||
| #include <string> | |||
| namespace mindspore { | |||
| // namespace to support primitive operators | |||
| namespace prim { | |||
| // Arithmetic | |||
| const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add"); | |||
| const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub"); | |||
| const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul"); | |||
| const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div"); | |||
| const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv"); | |||
| const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod"); | |||
| const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow"); | |||
| const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc"); | |||
| const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor"); | |||
| const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd"); | |||
| const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub"); | |||
| const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp"); | |||
| const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log"); | |||
| const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin"); | |||
| const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos"); | |||
| const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan"); | |||
| // Comparisons | |||
| const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq"); | |||
| const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt"); | |||
| const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt"); | |||
| const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne"); | |||
| const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le"); | |||
| const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge"); | |||
| const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); | |||
| const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | |||
| const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | |||
| const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | |||
| const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater"); | |||
| const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); | |||
| const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||
| const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||
| const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||
| const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); | |||
| // Type introspection | |||
| const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||
| const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||
| // Statements | |||
| const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch"); | |||
| const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer"); | |||
| const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||
| const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | |||
| const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | |||
| const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub"); | |||
| const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select"); | |||
| const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call"); | |||
| const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute"); | |||
| const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot"); | |||
| const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col"); | |||
| const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im"); | |||
| const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1"); | |||
| const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1"); | |||
| const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||
| const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||
| const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||
| const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||
| const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto"); | |||
| const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch"); | |||
| const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | |||
| // Structure | |||
| const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||
| const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||
| const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple"); | |||
| const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||
| const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict"); | |||
| const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg"); | |||
| const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg"); | |||
| const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice"); | |||
| const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||
| const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem"); | |||
| const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem"); | |||
| const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem"); | |||
| const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem"); | |||
| const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||
| const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem"); | |||
| const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||
| const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||
| const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||
| const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr"); | |||
| const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len"); | |||
| const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||
| const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||
| const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len"); | |||
| const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||
| const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||
| const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||
| const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape"); | |||
| const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||
| const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||
| const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||
| const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||
| const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index"); | |||
| const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | |||
| const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||
| const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||
| const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||
| const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||
| // Arrays | |||
| const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | |||
| const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | |||
| const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | |||
| const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||
| const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | |||
| const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | |||
| const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||
| const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||
| const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||
| const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||
| const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | |||
| const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | |||
| const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | |||
| const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||
| const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | |||
| const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | |||
| const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | |||
| const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||
| const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||
| const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||
| const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||
| const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | |||
| const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); | |||
| const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); | |||
| const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); | |||
| const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue"); | |||
| // Maths | |||
| const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | |||
| const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | |||
| const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | |||
| const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | |||
| const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | |||
| const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | |||
| const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | |||
| const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | |||
| const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | |||
| const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | |||
| const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | |||
| const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | |||
| const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||
| const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | |||
| const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | |||
| const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | |||
| const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | |||
| const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | |||
| const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | |||
| const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | |||
| const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | |||
| const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||
| const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||
| const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||
| const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||
| const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||
| // NN | |||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||
| const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||
| const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||
| const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||
| const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | |||
| const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | |||
| const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | |||
| const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||
| const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||
| const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||
| const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||
| const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||
| const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||
| const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||
| const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||
| const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | |||
| const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | |||
| const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | |||
| const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | |||
| const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | |||
| const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | |||
| const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | |||
| const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | |||
| const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | |||
| std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | |||
| const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | |||
| const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | |||
| const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | |||
| const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | |||
| const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | |||
| const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | |||
| const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | |||
| const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | |||
| const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | |||
| const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | |||
| const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | |||
| const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | |||
| const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | |||
| const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | |||
| const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||
| const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | |||
| const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | |||
| const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | |||
| const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | |||
| // Other miscellaneous | |||
| const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity"); | |||
| const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial"); | |||
| const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | |||
| const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem"); | |||
| const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem"); | |||
| const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||
| const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||
| const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||
| const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||
| const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||
| const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | |||
| const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||
| const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | |||
| const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||
| const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | |||
| const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||
| const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend"); | |||
| const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||
| const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||
| const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); | |||
| const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | |||
| const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | |||
| const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||
| const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||
| const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||
| // Comm ops | |||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| // Debug ops | |||
| const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||
| const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | |||
| const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||
| const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | |||
| const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug"); | |||
| // IndexedSlices | |||
| const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices"); | |||
| const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); | |||
| const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); | |||
| const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); | |||
| // SparseTensor | |||
| const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor"); | |||
| const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues"); | |||
| const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | |||
| const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -22,6 +22,7 @@ | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| // namespace to support primitive operators | |||
| @@ -31,273 +32,158 @@ ValuePtr GetPythonOps(const std::string &op_name, | |||
| bool use_signature = false); | |||
| // Arithmetic | |||
| extern const PrimitivePtr kPrimScalarAdd; | |||
| extern const PrimitivePtr kPrimScalarSub; | |||
| extern const PrimitivePtr kPrimScalarMul; | |||
| extern const PrimitivePtr kPrimScalarDiv; | |||
| extern const PrimitivePtr kPrimScalarFloordiv; | |||
| extern const PrimitivePtr kPrimScalarMod; | |||
| extern const PrimitivePtr kPrimScalarPow; | |||
| extern const PrimitivePtr kPrimScalarTrunc; | |||
| extern const PrimitivePtr kPrimScalarFloor; | |||
| extern const PrimitivePtr kPrimScalarUadd; | |||
| extern const PrimitivePtr kPrimScalarUsub; | |||
| extern const PrimitivePtr kPrimScalarExp; | |||
| extern const PrimitivePtr kPrimScalarLog; | |||
| extern const PrimitivePtr kPrimScalarSin; | |||
| extern const PrimitivePtr kPrimScalarCos; | |||
| extern const PrimitivePtr kPrimScalarTan; | |||
| inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add"); | |||
| inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub"); | |||
| inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul"); | |||
| inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div"); | |||
| inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv"); | |||
| inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod"); | |||
| inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow"); | |||
| inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc"); | |||
| inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor"); | |||
| inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd"); | |||
| inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub"); | |||
| inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp"); | |||
| inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log"); | |||
| inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin"); | |||
| inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos"); | |||
| inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan"); | |||
| // Comparisons | |||
| extern const PrimitivePtr kPrimScalarEq; | |||
| extern const PrimitivePtr kPrimScalarLt; | |||
| extern const PrimitivePtr kPrimScalarGt; | |||
| extern const PrimitivePtr kPrimScalarNe; | |||
| extern const PrimitivePtr kPrimScalarLe; | |||
| extern const PrimitivePtr kPrimScalarGe; | |||
| extern const PrimitivePtr kPrimBoolNot; | |||
| extern const PrimitivePtr kPrimBoolAnd; | |||
| extern const PrimitivePtr kPrimBoolOr; | |||
| extern const PrimitivePtr kPrimBoolEq; | |||
| extern const PrimitivePtr kPrimGreater; | |||
| extern const PrimitivePtr kPrimGreaterEqual; | |||
| extern const PrimitivePtr kPrimLess; | |||
| extern const PrimitivePtr kPrimLessEqual; | |||
| extern const PrimitivePtr kPrimEqual; | |||
| extern const PrimitivePtr kPrimNotEqual; | |||
| inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq"); | |||
| inline const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt"); | |||
| inline const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt"); | |||
| inline const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne"); | |||
| inline const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le"); | |||
| inline const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge"); | |||
| inline const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); | |||
| inline const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); | |||
| inline const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); | |||
| inline const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); | |||
| inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater"); | |||
| inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); | |||
| inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); | |||
| inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||
| inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); | |||
| inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); | |||
| // Type introspection | |||
| extern const PrimitivePtr kPrimTypeOf; | |||
| extern const PrimitivePtr kPrimHasType; | |||
| inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||
| inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||
| // Statements | |||
| extern const PrimitivePtr kPrimSwitch; | |||
| extern const PrimitivePtr kPrimSwitchLayer; | |||
| extern const PrimitivePtr kPrimReturn; | |||
| extern const PrimitivePtr kPrimAssign; | |||
| extern const PrimitivePtr kPrimAssignAdd; | |||
| extern const PrimitivePtr kPrimAssignSub; | |||
| extern const PrimitivePtr kPrimSelect; | |||
| extern const PrimitivePtr kPrimCall; | |||
| inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute"); | |||
| inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot"); | |||
| inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col"); | |||
| inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im"); | |||
| inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1"); | |||
| inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1"); | |||
| extern const PrimitivePtr kPrimDistribute; | |||
| extern const PrimitivePtr kPrimDot; | |||
| extern const PrimitivePtr kPrimIm2Col; | |||
| extern const PrimitivePtr kPrimCol2Im; | |||
| extern const PrimitivePtr kPrimIm2ColV1; | |||
| extern const PrimitivePtr kPrimCol2ImV1; | |||
| inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||
| inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||
| inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||
| extern const PrimitivePtr kPrimResolve; | |||
| extern const PrimitivePtr kPrimEmbed; | |||
| extern const PrimitivePtr kPrimRefToEmbed; | |||
| extern const PrimitivePtr kPrimCreateInstance; | |||
| extern const PrimitivePtr kPrimLabelGoto; | |||
| extern const PrimitivePtr kPrimLabelSwitch; | |||
| extern const PrimitivePtr kPrimLabelSet; | |||
| // Structure | |||
| extern const PrimitivePtr kPrimStringEqual; | |||
| extern const PrimitivePtr kPrimStringConcat; | |||
| extern const PrimitivePtr kPrimMakeTuple; | |||
| extern const PrimitivePtr kPrimMakeList; | |||
| extern const PrimitivePtr kPrimMakeDict; | |||
| extern const PrimitivePtr kPrimMakeKeywordArg; | |||
| extern const PrimitivePtr kPrimExtractKeywordArg; | |||
| extern const PrimitivePtr kPrimMakeSlice; | |||
| extern const PrimitivePtr kPrimMakeRecord; | |||
| extern const PrimitivePtr kPrimTupleGetItem; | |||
| extern const PrimitivePtr kPrimListGetItem; | |||
| extern const PrimitivePtr kPrimArrayGetItem; | |||
| extern const PrimitivePtr kPrimTupleSetItem; | |||
| extern const PrimitivePtr kPrimListSetItem; | |||
| extern const PrimitivePtr kPrimArraySetItem; | |||
| extern const PrimitivePtr kPrimDictGetItem; | |||
| extern const PrimitivePtr kPrimDictSetItem; | |||
| extern const PrimitivePtr kPrimListAppend; | |||
| extern const PrimitivePtr kPrimGetAttr; | |||
| extern const PrimitivePtr kPrimTupleLen; | |||
| extern const PrimitivePtr kPrimDictLen; | |||
| extern const PrimitivePtr kPrimListLen; | |||
| extern const PrimitivePtr kPrimArrayLen; | |||
| extern const PrimitivePtr kPrimListMap; | |||
| extern const PrimitivePtr kPrimListReduce; | |||
| extern const PrimitivePtr kPrimTupleReversed; | |||
| extern const PrimitivePtr kPrimTileShape; | |||
| extern const PrimitivePtr kPrimReducedShape; | |||
| extern const PrimitivePtr kPrimTupleDiv; | |||
| extern const PrimitivePtr kPrimTupleToArray; | |||
| extern const PrimitivePtr kPrimShapeMul; | |||
| extern const PrimitivePtr kPrimGenerateShapeIndex; | |||
| extern const PrimitivePtr kPrimGenerateInverseIndex; | |||
| extern const PrimitivePtr kPrimTupleEqual; | |||
| extern const PrimitivePtr kPrimListEqual; | |||
| extern const PrimitivePtr kPrimMakeRange; | |||
| extern const PrimitivePtr kPrimStopGradient; | |||
| inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto"); | |||
| inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch"); | |||
| inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | |||
| // Arrays | |||
| extern const PrimitivePtr kPrimScalarToArray; | |||
| extern const PrimitivePtr kPrimArrayToScalar; | |||
| extern const PrimitivePtr kPrimBroadcastShape; | |||
| extern const PrimitivePtr kPrimArrayMap; | |||
| extern const PrimitivePtr kPrimArrayReduce; | |||
| extern const PrimitivePtr kPrimShape; | |||
| extern const PrimitivePtr kPrimCast; | |||
| extern const PrimitivePtr kPrimConcat; | |||
| extern const PrimitivePtr kPrimSqueeze; | |||
| extern const PrimitivePtr kPrimTranspose; | |||
| extern const PrimitivePtr kPrimGatherV2; | |||
| extern const PrimitivePtr kPrimEmbeddingLookup; | |||
| extern const PrimitivePtr kPrimEmbeddingLookupCommGrad; | |||
| extern const PrimitivePtr kPrimSize; | |||
| extern const PrimitivePtr kPrimArgMax; | |||
| extern const PrimitivePtr kPrimPack; | |||
| extern const PrimitivePtr kPrimUnpack; | |||
| extern const PrimitivePtr kPrimUnsortedSegmentMin; | |||
| extern const PrimitivePtr kPrimUnsortedSegmentSum; | |||
| extern const PrimitivePtr kPrimConcatOffset; | |||
| extern const PrimitivePtr kPrimReshape; | |||
| extern const PrimitivePtr kPrimTile; | |||
| extern const PrimitivePtr kPrimAddN; | |||
| extern const PrimitivePtr KPrimTransData; | |||
| extern const PrimitivePtr kPrimNMSWithMask; | |||
| extern const PrimitivePtr kPrimPad; | |||
| extern const PrimitivePtr kPrimArgMaxWithValue; | |||
| extern const PrimitivePtr kPrimRealDiv; | |||
| extern const PrimitivePtr kPrimSqrt; | |||
| extern const PrimitivePtr kPrimReciprocal; | |||
| extern const PrimitivePtr kPrimExpandDims; | |||
| // Maths | |||
| extern const PrimitivePtr kPrimTensorAdd; | |||
| extern const PrimitivePtr kPrimMatMul; | |||
| extern const PrimitivePtr kPrimBatchMatMul; | |||
| extern const PrimitivePtr kPrimMaximumGrad; | |||
| extern const PrimitivePtr kPrimMinimumGrad; | |||
| extern const PrimitivePtr kPrimReduceMean; | |||
| extern const PrimitivePtr kPrimReduceSum; | |||
| extern const PrimitivePtr kPrimReduceAll; | |||
| extern const PrimitivePtr kPrimReduceMax; | |||
| extern const PrimitivePtr kPrimReduceMin; | |||
| extern const PrimitivePtr kPrimNeg; | |||
| extern const PrimitivePtr kPrimSub; | |||
| extern const PrimitivePtr kPrimMul; | |||
| extern const PrimitivePtr kPrimRealDiv; | |||
| extern const PrimitivePtr kPrimMinimum; | |||
| extern const PrimitivePtr kPrimMaximum; | |||
| extern const PrimitivePtr kPrimSquare; | |||
| extern const PrimitivePtr kPrimSqrt; | |||
| extern const PrimitivePtr kPrimEqual; | |||
| extern const PrimitivePtr kPrimLess; | |||
| extern const PrimitivePtr kPrimLessEqual; | |||
| extern const PrimitivePtr kPrimCumSum; | |||
| extern const PrimitivePtr kPrimCumProd; | |||
| extern const PrimitivePtr kPrimSubscalar; | |||
| extern const PrimitivePtr kPrimInplaceAdd; | |||
| extern const PrimitivePtr kPrimInplaceSub; | |||
| extern const PrimitivePtr kPrimPow; | |||
| inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | |||
| inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | |||
| inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | |||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||
| inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | |||
| inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | |||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||
| inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | |||
| inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | |||
| inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | |||
| inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||
| inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | |||
| inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | |||
| inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | |||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||
| inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||
| inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | |||
| inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); | |||
| inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); | |||
| inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); | |||
| inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue"); | |||
| // NN | |||
| extern const PrimitivePtr kPrimFlatten; | |||
| extern const PrimitivePtr kPrimSoftmax; | |||
| extern const PrimitivePtr kPrimLogSoftmax; | |||
| extern const PrimitivePtr kPrimLogSoftmaxGrad; | |||
| extern const PrimitivePtr kPrimApplyCenteredRMSProp; | |||
| extern const PrimitivePtr kPrimTanh; | |||
| extern const PrimitivePtr kPrimTanhGrad; | |||
| extern const PrimitivePtr kPrimPooling; | |||
| extern const PrimitivePtr kPrimPoolingGrad; | |||
| extern const PrimitivePtr kPrimFusedBatchNorm; | |||
| extern const PrimitivePtr kPrimBatchNorm; | |||
| extern const PrimitivePtr kPrimBatchNormGrad; | |||
| extern const PrimitivePtr kPrimConv2D; | |||
| extern const PrimitivePtr kPrimMaxPool; | |||
| extern const PrimitivePtr kPrimMaxPoolGrad; | |||
| extern const PrimitivePtr kPrimAvgPoolGrad; | |||
| extern const PrimitivePtr kPrimFusedBatchNormGrad; | |||
| extern const PrimitivePtr kPrimReluGrad; | |||
| extern const PrimitivePtr kPrimConv2DBackpropInput; | |||
| extern const PrimitivePtr kPrimConv2DBackpropFilter; | |||
| extern const PrimitivePtr kPrimDepthwiseConv2dNative; | |||
| extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter; | |||
| extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput; | |||
| extern const PrimitivePtr kPrimBiasAddGrad; | |||
| extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits; | |||
| extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits; | |||
| extern const PrimitivePtr kPrimMomentum; | |||
| extern const PrimitivePtr kPrimApplyMomentum; | |||
| extern const PrimitivePtr kPrimLayerNorm; | |||
| extern const PrimitivePtr kPrimLayerNormGrad; | |||
| extern const PrimitivePtr kPrimLayerNormXBackprop; | |||
| extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; | |||
| extern const PrimitivePtr kPrimDropoutGenMask; | |||
| extern const PrimitivePtr kPrimDropoutDoMask; | |||
| extern const PrimitivePtr kPrimOneHot; | |||
| extern const PrimitivePtr kPrimGelu; | |||
| extern const PrimitivePtr kPrimGeluGrad; | |||
| extern const PrimitivePtr kPrimRelu; | |||
| extern const PrimitivePtr kPrimReluV2; | |||
| extern const PrimitivePtr kPrimActivation; | |||
| extern const PrimitivePtr kPrimZerosLike; | |||
| extern const PrimitivePtr kPrimFakeBprop; | |||
| extern const PrimitivePtr kPrimBpropCut; | |||
| extern const PrimitivePtr kPrimFakeQuantPerLayer; | |||
| extern const PrimitivePtr kPrimFakeQuantPerChannel; | |||
| extern const PrimitivePtr kPrimApplyRMSProp; | |||
| // Other Miscellaneous | |||
| extern const PrimitivePtr kPrimIdentity; | |||
| extern const PrimitivePtr kPrimPartial; | |||
| extern const PrimitivePtr kPrimJ; | |||
| extern const PrimitivePtr kPrimEnvSetItem; | |||
| extern const PrimitivePtr kPrimEnvGetItem; | |||
| extern const PrimitivePtr kPrimEnvAdd; | |||
| extern const PrimitivePtr kPrimMakeRefKey; | |||
| extern const PrimitivePtr kPrimMakeRef; | |||
| extern const PrimitivePtr kPrimGetRefKey; | |||
| extern const PrimitivePtr kPrimGetRefValue; | |||
| extern const PrimitivePtr kPrimGetRefOrigin; | |||
| extern const PrimitivePtr kPrimInsertGradientOf; | |||
| extern const PrimitivePtr kPrimHookBackward; | |||
| extern const PrimitivePtr kPrimPrintShapeType; | |||
| extern const PrimitivePtr kPrimPrint; | |||
| extern const PrimitivePtr kPrimSameTypeShape; | |||
| extern const PrimitivePtr kPrimCheckBprop; | |||
| extern const PrimitivePtr kPrimDepend; | |||
| extern const PrimitivePtr kPrimStateSetItem; | |||
| extern const PrimitivePtr kPrimScalarSummary; | |||
| extern const PrimitivePtr kPrimImageSummary; | |||
| extern const PrimitivePtr kPrimTensorSummary; | |||
| extern const PrimitivePtr kPrimHistogramSummary; | |||
| extern const PrimitivePtr kPrimBroadcastGradientArgs; | |||
| extern const PrimitivePtr kPrimControlDepend; | |||
| extern const PrimitivePtr kPrimIs_; | |||
| extern const PrimitivePtr kPrimIsNot; | |||
| extern const PrimitivePtr kPrimInDict; | |||
| extern const PrimitivePtr kPrimNotInDict; | |||
| extern const PrimitivePtr kPrimMixedPrecisionCast; | |||
| extern const PrimitivePtr kPrimIsConsant; | |||
| extern const PrimitivePtr kPrimEquivFormat; | |||
| extern const PrimitivePtr kPrimDebug; | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||
| inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||
| inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||
| inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||
| inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | |||
| inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | |||
| inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | |||
| inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||
| inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||
| inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||
| inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | |||
| inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | |||
| inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | |||
| inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | |||
| inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = | |||
| std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | |||
| inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | |||
| std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | |||
| inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | |||
| inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | |||
| inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | |||
| inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | |||
| inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | |||
| inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | |||
| inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | |||
| inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | |||
| inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | |||
| inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | |||
| inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | |||
| inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | |||
| inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | |||
| inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | |||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||
| inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | |||
| inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | |||
| inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | |||
| inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | |||
| // Comm ops | |||
| extern const PrimitivePtr kPrimAllReduce; | |||
| extern const PrimitivePtr kPrimMirror; | |||
| extern const PrimitivePtr kPrimVirtualDiv; | |||
| extern const PrimitivePtr kPrimVirtualDataset; | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| // IndexedSlices | |||
| extern const PrimitivePtr kPrimMakeIndexedSlices; | |||
| extern const PrimitivePtr kPrimIndexedSlicesGetValues; | |||
| extern const PrimitivePtr kPrimIndexedSlicesGetIndices; | |||
| extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; | |||
| inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices"); | |||
| inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); | |||
| inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); | |||
| inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); | |||
| inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices"); | |||
| // SparseTensor | |||
| extern const PrimitivePtr kPrimMakeSparseTensor; | |||
| extern const PrimitivePtr kPrimSparseTensorGetValues; | |||
| extern const PrimitivePtr kPrimSparseTensorGetIndices; | |||
| extern const PrimitivePtr kPrimSparseTensorGetDenseShape; | |||
| inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor"); | |||
| inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues"); | |||
| inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | |||
| inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | |||
| // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll | |||
| const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; | |||
| @@ -305,22 +191,6 @@ const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; | |||
| // will be sunk(i.e. not unrolled) | |||
| const int MAX_FOR_LOOP_COUNT = 600; | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||
| : Primitive("S-Prim-" + name), function_(function) {} | |||
| ~DoSignaturePrimitive() override = default; | |||
| MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) | |||
| const ValuePtr function() const { return function_; } | |||
| private: | |||
| ValuePtr function_; | |||
| }; | |||
| using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||
| class UnpackGraphPrimitive : public Primitive { | |||
| public: | |||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | |||
| @@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| (void)cnode_set.emplace(cnode); | |||
| } else { | |||
| auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | |||
| @@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi | |||
| return cnode_dist; | |||
| } | |||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) | |||
| << " operator_info: " << (cnode->operator_info() != nullptr); | |||
| << " operator_info: " << (operator_info != nullptr); | |||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||
| auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); | |||
| if (IsParallelCareNode(cnode) && (operator_info != nullptr)) { | |||
| auto cost = operator_info->GetForwardMemoryCostFromCNode(); | |||
| MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; | |||
| if (allreduce_graph_.NodeInGraph(cnode)) { | |||
| @@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { | |||
| } | |||
| auto para_ptr = node_ptr->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(para_ptr); | |||
| auto layout_ptr = para_ptr->tensor_layout(); | |||
| auto layout_ptr = para_ptr->GetUserData<TensorLayout>(); | |||
| if (layout_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "layout_ptr is nullptr!"; | |||
| return FAILED; | |||
| @@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||
| for (auto para : graph_params) { | |||
| std::string name = std::static_pointer_cast<Parameter>(para)->name(); | |||
| std::shared_ptr<parallel::TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout(); | |||
| auto tensor_layout = para->GetUserData<parallel::TensorLayout>(); | |||
| if (tensor_layout == nullptr) { | |||
| MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; | |||
| } else { | |||
| @@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto distributed_operation_info = cnode->operator_info(); | |||
| auto distributed_operation_info = cnode->GetUserData<OperatorInfo>(); | |||
| if (distributed_operation_info != nullptr) { | |||
| auto strategyPtr = distributed_operation_info->strategy(); | |||
| if (strategyPtr != nullptr) { | |||
| @@ -163,6 +163,9 @@ class OperatorInfo { | |||
| const std::string &type() const { return type_; } | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| // Key for user data. | |||
| constexpr static char key[] = "OpInfo"; | |||
| protected: | |||
| // needed by rec_parser | |||
| std::string type_; | |||
| @@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| (void)cnode->set_operator_info(operator_info); | |||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| @@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| (void)cnode->set_operator_info(operator_info); | |||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| @@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() | |||
| << " does not match the Prim: " << prim->name(); | |||
| } | |||
| (void)cnode->set_operator_info(current_op_ptr); | |||
| cnode->SetUserData<OperatorInfo>(current_op_ptr); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | |||
| @@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| size_t edge_count = 0; | |||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto prev_cnode = inputs[i]->cast<CNodePtr>(); | |||
| bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| @@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | |||
| while (bool_result) { | |||
| if (IsAutoParallelCareNode(prev_cnode)) { | |||
| std::string edge_name = | |||
| prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); | |||
| auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>(); | |||
| std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); | |||
| // If the edge between these two operators already has been added, then the edge will not be added again. | |||
| if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { | |||
| break; | |||
| @@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| if (follow_strategy) { | |||
| // Redistribution in not allowed on the edge. | |||
| // Elementwise operators have the same strategy as their previous operators. | |||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(), | |||
| output_index, i - 1, false, true); | |||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true); | |||
| } else { | |||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(), | |||
| output_index, i - 1, false); | |||
| edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false); | |||
| } | |||
| // Init costs for this edge | |||
| if (edge_ptr->InitEdgeCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | |||
| } | |||
| cnode->operator_info()->AddPrevEdge(edge_ptr); | |||
| prev_cnode->operator_info()->AddSuccEdge(edge_ptr); | |||
| entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); | |||
| MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " | |||
| << cnode->operator_info()->name(); | |||
| node_op_info->AddPrevEdge(edge_ptr); | |||
| prev_op_info->AddSuccEdge(edge_ptr); | |||
| entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); | |||
| MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " | |||
| << node_op_info->name(); | |||
| edge_count++; | |||
| break; | |||
| @@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); | |||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); | |||
| } | |||
| MS_LOG(INFO) << "Constructing edges for cost graph ends."; | |||
| @@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| for (auto &target : target_set) { | |||
| auto target_cnode = target.first->cast<CNodePtr>(); | |||
| auto input_index = target.second; | |||
| (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); | |||
| (void)target_without_duplicate.insert(std::to_string(input_index) + | |||
| target_cnode->GetUserData<OperatorInfo>()->name()); | |||
| } | |||
| if (target_without_duplicate.size() <= 1) { | |||
| continue; | |||
| @@ -830,24 +831,24 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| auto target_cnode = target.first->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); | |||
| auto input_index = target.second; | |||
| auto target_op_info = target_cnode->GetUserData<OperatorInfo>(); | |||
| std::string edge_name = | |||
| std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); | |||
| std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); | |||
| // If the edge between these two operators already has been added, then the edge will not be added again. | |||
| if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { | |||
| continue; | |||
| } | |||
| std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>( | |||
| edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); | |||
| std::shared_ptr<Edge> edge_ptr = | |||
| std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true); | |||
| if (edge_ptr->InitEdgeCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | |||
| } | |||
| target_cnode->operator_info()->AddPrevEdge(edge_ptr); | |||
| target_op_info->AddPrevEdge(edge_ptr); | |||
| tmp_identity_ptr->AddSuccEdge(edge_ptr); | |||
| entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr); | |||
| entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr); | |||
| MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " | |||
| << target_cnode->operator_info()->name(); | |||
| << target_op_info->name(); | |||
| add_identity_edge = true; | |||
| } | |||
| if (new_identity && add_identity_edge) { | |||
| @@ -861,20 +862,13 @@ bool FindReshape(const CNodePtr &cnode) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { | |||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||
| return false; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| if (operator_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | |||
| } | |||
| if (prim->name() != RESHAPE) { | |||
| return false; | |||
| } | |||
| return true; | |||
| return (prim->name() == RESHAPE); | |||
| } | |||
| // find previous node, then obtain its strategy_cost_ vector to get its layout vector. | |||
| @@ -890,8 +884,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||
| *pre_operator_info = cnode->operator_info(); | |||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||
| if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | |||
| *pre_operator_info = node_op_info; | |||
| *out_index = 0; | |||
| return true; | |||
| } | |||
| @@ -905,8 +900,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | |||
| } | |||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | |||
| if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { | |||
| *pre_operator_info = pre_cnode->operator_info(); | |||
| auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>(); | |||
| if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | |||
| *pre_operator_info = pre_op_info; | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -945,14 +941,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { | |||
| auto op_info = use_apply->GetUserData<OperatorInfo>(); | |||
| if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | |||
| MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | |||
| *next_operator_info = use_apply->operator_info(); | |||
| *next_operator_info = op_info; | |||
| *in_index = node_pair.second - 1; | |||
| return true; | |||
| } | |||
| MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | |||
| << " " << (use_apply->operator_info() != nullptr); | |||
| << " " << (op_info != nullptr); | |||
| if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { | |||
| return true; | |||
| @@ -973,8 +970,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| int32_t out_index = 0; | |||
| OperatorInfoPtr pre_operator_info; | |||
| std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | |||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| if (pre_node->isa<Parameter>()) { | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | |||
| reshape_info->SetCostForReshapeWithParameter(); | |||
| pre_operator_info = reshape_info; | |||
| @@ -995,7 +992,6 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| // set input_layout and output_layout for reshape. | |||
| // init reshape and set cost for each input_layout and output_layout. | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | |||
| reshape_info->set_pre_operator_name(pre_operator_info->name()); | |||
| reshape_info->set_pre_operator_index(out_index); | |||
| @@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { | |||
| if (!IsParallelCareNode(node)) { | |||
| return nullptr; | |||
| } | |||
| OperatorInfoPtr distribute_operator = node->operator_info(); | |||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; | |||
| } | |||
| @@ -415,7 +415,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||
| if (prim->name() == GET_NEXT) { | |||
| return true; | |||
| } | |||
| if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { | |||
| if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) { | |||
| return false; | |||
| } | |||
| @@ -452,7 +452,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { | |||
| if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) { | |||
| Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, | |||
| pre_node); | |||
| } else { | |||
| @@ -465,7 +465,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||
| void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(next_node); | |||
| OperatorInfoPtr op_info = next_node->operator_info(); | |||
| OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>(); | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| // If the shape of tensor is [] or [1], no need to split it. | |||
| @@ -590,7 +590,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { | |||
| void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||
| // step1:get graph manager distribute_operator | |||
| OperatorInfoPtr distribute_operator = node->operator_info(); | |||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; | |||
| } | |||
| @@ -628,7 +628,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||
| (void)prim->SetAttrs(attrs); | |||
| } | |||
| if (index == replace_op.size() - 1) { | |||
| (void)replace_node->set_operator_info(node->operator_info()); | |||
| replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>()); | |||
| } | |||
| replace_node->set_in_forward_flag(true); | |||
| replace_input[0]->set_scope(scope); | |||
| @@ -708,7 +708,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { | |||
| auto pre_cnode = pre_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(pre_cnode); | |||
| auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { | |||
| if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||
| pre_node = pre_cnode->input(1); | |||
| } | |||
| @@ -1204,7 +1204,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| return node_pair; | |||
| } else if (FindParallelCareNode(node_pair.first).first != nullptr) { | |||
| return FindParallelCareNode(node_pair.first); | |||
| @@ -1254,7 +1254,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||
| MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); | |||
| CNodePtr cnode = res.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| OperatorInfoPtr distribute_operator = cnode->operator_info(); | |||
| OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; | |||
| } | |||
| @@ -1277,7 +1277,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||
| TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | |||
| ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter_ptr); | |||
| parameter_ptr->set_tensor_layout(std::make_shared<TensorLayout>(tensor_layout)); | |||
| parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||
| } | |||
| void CoverSliceShape(const FuncGraphPtr &root) { | |||
| @@ -1365,7 +1365,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| if (found_be_cloned_parameter) { | |||
| // set the shape and tensor layout for cloned parameter | |||
| cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); | |||
| cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>()); | |||
| MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | |||
| MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | |||
| auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | |||
| @@ -1464,7 +1464,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| (*operator_).set_outputs_dtype(cnode->Type()); | |||
| (*operator_).set_cnode(cnode); | |||
| if (prim->name() == RESHAPE) { | |||
| (void)cnode->set_operator_info(operator_); | |||
| cnode->SetUserData<OperatorInfo>(operator_); | |||
| continue; | |||
| } | |||
| // load strategy checkpoint | |||
| @@ -1499,7 +1499,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| if (operator_->Init(strategyPtr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | |||
| } | |||
| (void)cnode->set_operator_info(operator_); | |||
| cnode->SetUserData<OperatorInfo>(operator_); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; | |||
| } | |||
| @@ -1542,13 +1542,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) { | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { | |||
| if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) { | |||
| MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); | |||
| auto layout = GetInputLayoutFromCNode(node_pair); | |||
| return std::make_shared<TensorLayout>(layout); | |||
| } | |||
| MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | |||
| << " " << (use_apply->operator_info() != nullptr); | |||
| << " " << use_apply->HasUserData<OperatorInfo>(); | |||
| auto layout_ptr = FindNextLayout(use_apply); | |||
| if (layout_ptr) { | |||
| @@ -1580,7 +1580,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return nullptr; | |||
| } | |||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); | |||
| if (!layout_ptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | |||
| @@ -1624,7 +1624,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return nullptr; | |||
| } | |||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); | |||
| if (!layout_ptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | |||
| @@ -1664,12 +1664,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { | |||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| if (operator_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | |||
| } | |||
| @@ -1714,7 +1714,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { | |||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| // return -> cast | |||
| if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { | |||
| if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(pre_cnode); | |||
| current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| @@ -1771,7 +1771,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { | |||
| return ret; | |||
| } | |||
| OperatorInfoPtr operator_info = loss_cnode->operator_info(); | |||
| OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>(); | |||
| MS_EXCEPTION_IF_NULL(operator_info); | |||
| TensorInfo loss_grad_tensor_info; | |||
| size_t op_output_size = operator_info->outputs_tensor_info().size(); | |||
| @@ -1809,7 +1809,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||
| if (sens_tensor_node->isa<Parameter>()) { | |||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | |||
| MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); | |||
| sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| } | |||
| MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; | |||
| return; | |||
| @@ -1834,7 +1834,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||
| cloned_abstract->set_shape(parallel_shape); | |||
| sens_tensor_node->set_abstract(cloned_abstract); | |||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | |||
| sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| return; | |||
| } | |||
| MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; | |||
| @@ -2125,7 +2125,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| if (operator_info) { | |||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | |||
| continue; | |||
| @@ -83,6 +83,9 @@ class TensorLayout { | |||
| TensorLayout SqueezeShape() const; | |||
| // Key for user data. | |||
| constexpr static char key[] = "TLayout"; | |||
| private: | |||
| std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement( | |||
| const Arrangement &expanded_shape) const; | |||
| @@ -0,0 +1,160 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPERATOR_OPS_H_ | |||
| #define MINDSPORE_CORE_OPERATOR_OPS_H_ | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace prim { | |||
| // Maths | |||
| inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | |||
| inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | |||
| inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | |||
| inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | |||
| inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | |||
| inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | |||
| inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | |||
| inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | |||
| inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | |||
| inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | |||
| inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | |||
| inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | |||
| inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||
| inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | |||
| inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | |||
| inline const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); | |||
| inline const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); | |||
| inline const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); | |||
| inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); | |||
| inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | |||
| inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | |||
| inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||
| inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||
| inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||
| inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | |||
| inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | |||
| // Statements | |||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||
| inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch"); | |||
| inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer"); | |||
| inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | |||
| inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | |||
| inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub"); | |||
| inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select"); | |||
| inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call"); | |||
| // Structures | |||
| inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||
| inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||
| inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple"); | |||
| inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict"); | |||
| inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||
| inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg"); | |||
| inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice"); | |||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||
| inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem"); | |||
| inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem"); | |||
| inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem"); | |||
| inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem"); | |||
| inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||
| inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem"); | |||
| inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||
| inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||
| inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||
| inline const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr"); | |||
| inline const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len"); | |||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||
| inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||
| inline const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len"); | |||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||
| inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||
| inline const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape"); | |||
| inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||
| inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||
| inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||
| inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||
| inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index"); | |||
| inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | |||
| inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||
| inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||
| inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||
| inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||
| inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg"); | |||
| // Debug ops | |||
| inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||
| inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | |||
| inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||
| inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | |||
| inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug"); | |||
| // Other miscellaneous | |||
| inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | |||
| inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend"); | |||
| inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial"); | |||
| inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity"); | |||
| inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem"); | |||
| inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem"); | |||
| inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||
| inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||
| inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | |||
| inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||
| inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | |||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||
| inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | |||
| inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||
| inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||
| inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend"); | |||
| inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); | |||
| inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); | |||
| inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||
| inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||
| inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||
| : Primitive("S-Prim-" + name), function_(function) {} | |||
| ~DoSignaturePrimitive() override = default; | |||
| MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) | |||
| const ValuePtr function() const { return function_; } | |||
| private: | |||
| ValuePtr function_; | |||
| }; | |||
| using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPERATOR_OPS_H_ | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_USER_DATA_H_ | |||
| #define MINDSPORE_CORE_USER_DATA_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| namespace mindspore { | |||
| class UserData { | |||
| public: | |||
| template <typename T> | |||
| void set(const std::string &key, const std::shared_ptr<T> &value) { | |||
| if (value == nullptr) { | |||
| data_.erase(key); | |||
| } else { | |||
| data_.insert_or_assign(key, value); | |||
| } | |||
| } | |||
| template <typename T> | |||
| std::shared_ptr<T> get(const std::string &key) const { | |||
| auto iter = data_.find(key); | |||
| if (iter == data_.end()) { | |||
| return nullptr; | |||
| } | |||
| return std::static_pointer_cast<T>(iter->second); | |||
| } | |||
| bool has(const std::string &key) const { return data_.find(key) != data_.end(); } | |||
| private: | |||
| std::map<std::string, std::shared_ptr<void>> data_; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_USER_DATA_H_ | |||
| @@ -26,7 +26,6 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore { | |||
| // namespace to support intermediate representation definition | |||
| @@ -27,6 +27,7 @@ | |||
| #include <utility> | |||
| #include "base/base.h" | |||
| #include "base/user_data.h" | |||
| #include "ir/kernel_info_dev.h" | |||
| #include "ir/scope.h" | |||
| #include "debug/info.h" | |||
| @@ -41,12 +42,6 @@ | |||
| // ANode: Atomic Node | |||
| // CNode: Complex Node | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| class TensorLayout; | |||
| class OperatorInfo; | |||
| } // namespace parallel | |||
| using OperatorInfoPtr = std::shared_ptr<parallel::OperatorInfo>; | |||
| namespace abstract { | |||
| class BaseShape; | |||
| class AbstractBase; | |||
| @@ -157,6 +152,31 @@ class AnfNode : public Base { | |||
| } | |||
| size_t seen_{0}; | |||
| template <typename T> | |||
| void SetUserData(const std::string &key, const std::shared_ptr<T> &value) { | |||
| user_data_.set<T>(key, value); | |||
| } | |||
| template <typename T> | |||
| void SetUserData(const std::shared_ptr<T> &value) { | |||
| user_data_.set<T>(T::key, value); | |||
| } | |||
| template <typename T> | |||
| std::shared_ptr<T> GetUserData(const std::string &key) const { | |||
| return user_data_.get<T>(key); | |||
| } | |||
| template <typename T> | |||
| std::shared_ptr<T> GetUserData() const { | |||
| return user_data_.get<T>(T::key); | |||
| } | |||
| bool HasUserData(const std::string &key) const { return user_data_.has(key); } | |||
| template <typename T> | |||
| bool HasUserData() const { return user_data_.has(T::key); } | |||
| protected: | |||
| // Hold a weak ref to Graph as Graph also hold ref to AnfNode. | |||
| // Otherwise, func_graph_ and AnfNode will make a reference cycle. | |||
| @@ -170,6 +190,7 @@ class AnfNode : public Base { | |||
| std::hash<const AnfNode *> hash_; | |||
| ScopePtr scope_; | |||
| KernelInfoDevicePtr kernel_info_; | |||
| UserData user_data_; | |||
| }; | |||
| // CNode represents the complex node with a set of arguments. | |||
| @@ -212,9 +233,6 @@ class CNode : public AnfNode { | |||
| std::string DebugString(int recursive_level = 1) const override; | |||
| std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } | |||
| OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info); | |||
| OperatorInfoPtr operator_info() { return operator_info_; } | |||
| void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } | |||
| bool in_forward_flag() const { return in_forward_flag_; } | |||
| @@ -224,7 +242,6 @@ class CNode : public AnfNode { | |||
| std::vector<AnfNodePtr> inputs_; | |||
| VarPtr func_graph_as_var_; | |||
| bool stop_gradient_; | |||
| OperatorInfoPtr operator_info_ = nullptr; | |||
| bool in_forward_flag_ = false; | |||
| }; | |||
| @@ -244,7 +261,7 @@ class ANode : public AnfNode { | |||
| class Parameter : public ANode { | |||
| public: | |||
| explicit Parameter(const FuncGraphPtr &func_graph) | |||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} | |||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} | |||
| ~Parameter() override = default; | |||
| MS_DECLARE_PARENT(Parameter, ANode); | |||
| @@ -261,11 +278,6 @@ class Parameter : public ANode { | |||
| } | |||
| ParamValuePtr default_param() const { return default_param_; } | |||
| std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; } | |||
| void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) { | |||
| tensor_layout_ = tensor_layout; | |||
| } | |||
| bool operator==(const AnfNode &other) const override { | |||
| if (!other.isa<Parameter>()) { | |||
| return false; | |||
| @@ -281,7 +293,6 @@ class Parameter : public ANode { | |||
| std::string name_; | |||
| bool has_default_; | |||
| ParamValuePtr default_param_; | |||
| std::shared_ptr<parallel::TensorLayout> tensor_layout_; | |||
| }; | |||
| using ParameterPtr = std::shared_ptr<Parameter>; | |||
| @@ -23,8 +23,7 @@ | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||
| #include "base/core_ops.h" | |||
| #include "debug/label.h" | |||
| namespace mindspore { | |||
| @@ -37,18 +36,6 @@ std::string AnfNode::ToString() const { | |||
| return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | |||
| } | |||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { | |||
| if (operator_info_ != nullptr) { | |||
| MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | |||
| << ", using the new one: " << operator_info->name(); | |||
| auto old_ptr = operator_info_; | |||
| operator_info_ = operator_info; | |||
| return old_ptr; | |||
| } | |||
| operator_info_ = operator_info; | |||
| return nullptr; | |||
| } | |||
| std::string CNode::fullname_with_scope() { | |||
| // if full name is set, return its name immediately | |||
| if (!fullname_with_scope_.empty()) { | |||
| @@ -24,7 +24,6 @@ | |||
| #include "debug/trace.h" | |||
| #include "ir/manager.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/convert_utils_base.h" | |||
| @@ -20,7 +20,7 @@ | |||
| #include "ir/manager.h" | |||
| #include "ir/param_value.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "base/core_ops.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/profile.h" | |||
| @@ -22,7 +22,7 @@ | |||
| #include "ir/manager.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "base/core_ops.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| @@ -26,7 +26,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "utils/profile.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| @@ -17,10 +17,8 @@ | |||
| */ | |||
| #include "ir/meta_func_graph.h" | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "pipeline/jit/static_analysis/abstract_function.h" | |||
| #include "base/core_ops.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "frontend/operator/ops.h" | |||
| // namespace to support intermediate representation definition | |||
| namespace mindspore { | |||
| @@ -22,9 +22,9 @@ | |||
| #include <tuple> | |||
| #include <vector> | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| /// | |||
| @@ -25,7 +25,6 @@ | |||
| #include "ir/dtype/type.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "frontend/parallel/ops_info/operator_info.h" | |||
| #include "utils/base_ref_extends.h" | |||
| namespace mindspore { | |||
| @@ -18,7 +18,6 @@ | |||
| #include <mutex> | |||
| #include <utility> | |||
| #include "ir/signature.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "./common.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| @@ -28,7 +28,6 @@ | |||
| #include <type_traits> | |||
| #include <typeinfo> | |||
| #include "runtime/device/device_address.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| @@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { | |||
| StrategyPtr strategyPtr; | |||
| std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); | |||
| node->set_operator_info(matmul_info); | |||
| node->SetUserData<OperatorInfo>(matmul_info); | |||
| std::string name_expect = "MatMulInfo00"; | |||
| std::string name_test = matmul_info->name(); | |||
| ASSERT_EQ(name_expect, name_test); | |||
| @@ -525,8 +525,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) { | |||
| std::vector<Shapes> shape = {inputs_shape, outputs_shape}; | |||
| OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); | |||
| matmul_info->Init(strategyPtr); | |||
| node->set_operator_info(matmul_info); | |||
| OperatorInfoPtr distribute_operator_pre = node->operator_info(); | |||
| node->SetUserData<OperatorInfo>(matmul_info); | |||
| OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>(); | |||
| TensorLayout tensorlayout_e; | |||
| std::vector<int32_t> array = {64, 64}; | |||
| TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); | |||