| @@ -590,7 +590,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n | |||||
| op_nodes = nodes; | op_nodes = nodes; | ||||
| } else { | } else { | ||||
| // When there are basic and composite ops, the composite ops should be inline to the basic ones' graph, | // When there are basic and composite ops, the composite ops should be inline to the basic ones' graph, | ||||
| // so a new graph generation should be done (beacuse they may in the main graph!). | |||||
| // so a new graph generation should be done (because they may in the main graph!). | |||||
| // If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now. | // If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now. | ||||
| MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!"; | MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!"; | ||||
| } | } | ||||
| @@ -717,7 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||||
| prim::kPrimMinimumGrad, | prim::kPrimMinimumGrad, | ||||
| prim::kPrimGkDropout, | prim::kPrimGkDropout, | ||||
| prim::kPrimDropoutGrad, | prim::kPrimDropoutGrad, | ||||
| prim::kPrimSoftMax, | |||||
| prim::kPrimSoftmax, | |||||
| prim::kPrimLayerNorm, | prim::kPrimLayerNorm, | ||||
| prim::kPrimLayerNormGrad, | prim::kPrimLayerNormGrad, | ||||
| #endif | #endif | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "abstract/utils.h" | #include "abstract/utils.h" | ||||
| #include "backend/kernel_compiler/common_utils.h" | #include "backend/kernel_compiler/common_utils.h" | ||||
| @@ -148,7 +148,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o | |||||
| } | } | ||||
| } | } | ||||
| tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); | tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); | ||||
| // if in paynative mode,data only copyed to host when user want to print data | |||||
| // if in paynative mode,data only copied to host when user want to print data | |||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | ||||
| @@ -1313,8 +1313,8 @@ void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector< | |||||
| input_abstracts.emplace_back(abstract); | input_abstracts.emplace_back(abstract); | ||||
| } | } | ||||
| auto prim = AnfAlgo::GetCNodePrimitive(node); | auto prim = AnfAlgo::GetCNodePrimitive(node); | ||||
| if (prim->isa<PrimitiveC>()) { | |||||
| auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>(); | |||||
| if (prim->isa<ops::PrimitiveC>()) { | |||||
| auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>(); | |||||
| MS_EXCEPTION_IF_NULL(prim_c); | MS_EXCEPTION_IF_NULL(prim_c); | ||||
| auto abstract = prim_c->Infer(input_abstracts); | auto abstract = prim_c->Infer(input_abstracts); | ||||
| node->set_abstract(abstract); | node->set_abstract(abstract); | ||||
| @@ -1835,7 +1835,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| // PS embeddingLookup cache check. | // PS embeddingLookup cache check. | ||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in " | |||||
| MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in " | |||||
| "parameter server training mode."; | "parameter server training mode."; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | ||||
| @@ -5,15 +5,14 @@ add_subdirectory(gvar) | |||||
| message("************ build core ***************") | message("************ build core ***************") | ||||
| file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| "abstract/*.cc" | "abstract/*.cc" | ||||
| "base/*.cc" | "base/*.cc" | ||||
| "c_ops/*.cc" | |||||
| "ops/*.cc" | |||||
| "ir/*.cc" | "ir/*.cc" | ||||
| "utils/*.cc" | "utils/*.cc" | ||||
| "load_mindir/*.cc" | "load_mindir/*.cc" | ||||
| ) | |||||
| ) | |||||
| if (CMAKE_SYSTEM_NAME MATCHES "Windows") | if (CMAKE_SYSTEM_NAME MATCHES "Windows") | ||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF") | ||||
| add_compile_definitions(BUILDING_DLL) | add_compile_definitions(BUILDING_DLL) | ||||
| @@ -92,7 +92,9 @@ inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelS | |||||
| inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | ||||
| // Arrays | // Arrays | ||||
| inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo"); | |||||
| inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | ||||
| inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK"); | |||||
| inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | ||||
| inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | ||||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | ||||
| @@ -100,17 +102,24 @@ inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_ | |||||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | ||||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | ||||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | ||||
| inline const PrimitivePtr kPrimUnsqueeze = std::make_shared<Primitive>("Unsqueeze"); | |||||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | ||||
| inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | ||||
| inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD"); | inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD"); | ||||
| inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>("Gather"); | |||||
| inline const PrimitivePtr kPrimGatherND = std::make_shared<Primitive>("GatherND"); | |||||
| inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2"); | inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2"); | ||||
| inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Primitive>("SparseToDense"); | |||||
| inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | ||||
| inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>("StridedSlice"); | |||||
| inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape"); | inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape"); | ||||
| inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | ||||
| inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | ||||
| inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | ||||
| inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | ||||
| inline const PrimitivePtr kPrimArgMin = std::make_shared<Primitive>("Argmin"); | |||||
| inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | ||||
| inline const PrimitivePtr kPrimUnpack = std::make_shared<Primitive>("Unpack"); | |||||
| inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax"); | inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax"); | ||||
| inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | ||||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | ||||
| @@ -121,7 +130,10 @@ inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCac | |||||
| inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | ||||
| inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits"); | inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits"); | ||||
| inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | ||||
| inline const PrimitivePtr kPrimDynamicAssign = std::make_shared<Primitive>("DynamicAssign"); | |||||
| inline const PrimitivePtr kPrimPadAndShift = std::make_shared<Primitive>("PadAndShift"); | |||||
| inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); | inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); | ||||
| inline const PrimitivePtr kPrimSliceFusion = std::make_shared<Primitive>("SliceFusion"); | |||||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | ||||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | ||||
| inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | ||||
| @@ -141,16 +153,38 @@ inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("Scat | |||||
| inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform"); | inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform"); | ||||
| inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | ||||
| inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask"); | inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask"); | ||||
| inline const PrimitivePtr kPrimDepthToSpace = std::make_shared<Primitive>("DepthToSpace"); | |||||
| inline const PrimitivePtr kPrimBatchToSpace = std::make_shared<Primitive>("BatchToSpace"); | |||||
| inline const PrimitivePtr kPrimSpaceToBatch = std::make_shared<Primitive>("SpaceToBatch"); | |||||
| inline const PrimitivePtr kPrimScatterNd = std::make_shared<Primitive>("ScatterNd"); | |||||
| inline const PrimitivePtr kPrimConstantOfShape = std::make_shared<Primitive>("ConstantOfShape"); | |||||
| inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference"); | |||||
| inline const PrimitivePtr kPrimSpaceToBatchND = std::make_shared<Primitive>("SpaceToBatchND"); | |||||
| inline const PrimitivePtr kPrimBatchToSpaceND = std::make_shared<Primitive>("BatchToSpaceND"); | |||||
| inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV2"); | |||||
| inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence"); | |||||
| inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank"); | |||||
| inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear"); | |||||
| // NN | // NN | ||||
| inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | |||||
| inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram"); | |||||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | ||||
| inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax"); | |||||
| inline const PrimitivePtr kPrimCrop = std::make_shared<Primitive>("Crop"); | |||||
| inline const PrimitivePtr kPrimFlattenGrad = std::make_shared<Primitive>("FlattenGrad"); | |||||
| inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||||
| inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropy = std::make_shared<Primitive>("SparseSoftmaxCrossEntropy"); | |||||
| inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | ||||
| inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | ||||
| inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm"); | |||||
| inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan"); | |||||
| inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan"); | |||||
| inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin"); | |||||
| inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | ||||
| inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | ||||
| inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | ||||
| inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | ||||
| inline const PrimitivePtr kPrimROIPooling = std::make_shared<Primitive>("ROIPooling"); | |||||
| inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | ||||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | ||||
| inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax"); | inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax"); | ||||
| @@ -164,6 +198,9 @@ inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("Fu | |||||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | ||||
| inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | ||||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | ||||
| inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection"); | |||||
| inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose"); | |||||
| inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput"); | |||||
| inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | ||||
| inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx"); | inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx"); | ||||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | ||||
| @@ -175,20 +212,33 @@ inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive> | |||||
| inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | ||||
| inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); | inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); | ||||
| inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); | inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); | ||||
| inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize"); | |||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | ||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | ||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | ||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | ||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | ||||
| inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess"); | |||||
| inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd"); | inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd"); | ||||
| inline const PrimitivePtr kPrimBiasGrad = std::make_shared<Primitive>("BiasGrad"); | |||||
| inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | ||||
| inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad"); | |||||
| inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy"); | |||||
| inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad"); | |||||
| inline const PrimitivePtr kPrimSmoothL1Loss = std::make_shared<Primitive>("SmoothL1Loss"); | |||||
| inline const PrimitivePtr kPrimSmoothL1LossGrad = std::make_shared<Primitive>("SmoothL1LossGrad"); | |||||
| inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = | inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = | ||||
| std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | ||||
| inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogits = | |||||
| std::make_shared<Primitive>("SigmoidCrossEntropyWithLogits"); | |||||
| inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogitsGrad = | |||||
| std::make_shared<Primitive>("SigmoidCrossEntropyWithLogitsGrad"); | |||||
| inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | ||||
| std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | ||||
| inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | ||||
| inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | ||||
| inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | ||||
| inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("Lrn"); | |||||
| inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | ||||
| inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | ||||
| inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | ||||
| @@ -204,13 +254,16 @@ inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad" | |||||
| inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu"); | inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu"); | ||||
| inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad"); | inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad"); | ||||
| inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | ||||
| inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("ELU"); | |||||
| inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | ||||
| inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | ||||
| inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU"); | |||||
| inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | ||||
| inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | ||||
| inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | ||||
| inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | ||||
| inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | ||||
| inline const PrimitivePtr kPrimFakeQuantWithMinMaxVars = std::make_shared<Primitive>("FakeQuantWithMinMaxVars"); | |||||
| inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | ||||
| inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl"); | inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl"); | ||||
| inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad"); | inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad"); | ||||
| @@ -219,6 +272,8 @@ inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive | |||||
| inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>(kSGD); | inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>(kSGD); | ||||
| inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum"); | inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum"); | ||||
| inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove"); | inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove"); | ||||
| inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize"); | |||||
| inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures"); | |||||
| // Comm ops | // Comm ops | ||||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | ||||
| @@ -233,6 +288,12 @@ inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGathe | |||||
| inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter"); | inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter"); | ||||
| inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async"); | inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async"); | ||||
| inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill"); | inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill"); | ||||
| // Quant ops | |||||
| inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold"); | |||||
| inline const PrimitivePtr kPrimFakeQuantWithMinMaxVarsPerChannel = | |||||
| std::make_shared<Primitive>("FakeQuantWithMinMaxVarsPerChannel"); | |||||
| // Control ops | |||||
| inline const PrimitivePtr kPrimMerge = std::make_shared<Primitive>("Merge"); | |||||
| // RowTensor | // RowTensor | ||||
| inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor"); | inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor"); | ||||
| inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues"); | inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues"); | ||||
| @@ -245,12 +306,22 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv | |||||
| inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | ||||
| inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | ||||
| // TensorList | |||||
| inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor"); | |||||
| inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve"); | |||||
| inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Primitive>("TensorListStack"); | |||||
| inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem"); | |||||
| // Maths | // Maths | ||||
| inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil"); | |||||
| inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | ||||
| inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add"); | |||||
| inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | ||||
| inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag"); | |||||
| inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | ||||
| inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | ||||
| inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | ||||
| inline const PrimitivePtr kPrimReduce = std::make_shared<Primitive>("Reduce"); | |||||
| inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | ||||
| inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | ||||
| inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | ||||
| @@ -258,9 +329,12 @@ inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAn | |||||
| inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | ||||
| inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | ||||
| inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | ||||
| inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin"); | |||||
| inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | ||||
| inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | ||||
| inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div"); | inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div"); | ||||
| inline const PrimitivePtr kPrimMod = std::make_shared<Primitive>("Mod"); | |||||
| inline const PrimitivePtr kPrimFloor = std::make_shared<Primitive>("Floor"); | |||||
| inline const PrimitivePtr kPrimDivNoNan = std::make_shared<Primitive>("DivNoNan"); | inline const PrimitivePtr kPrimDivNoNan = std::make_shared<Primitive>("DivNoNan"); | ||||
| inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); | ||||
| inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); | ||||
| @@ -271,6 +345,7 @@ inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscala | |||||
| inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | ||||
| inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | ||||
| inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | ||||
| inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power"); | |||||
| inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | ||||
| inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | ||||
| inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); | inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); | ||||
| @@ -280,10 +355,17 @@ inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs"); | |||||
| inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round"); | inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round"); | ||||
| inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp"); | ||||
| inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | ||||
| inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd"); | |||||
| inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr"); | |||||
| inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | ||||
| inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | ||||
| inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | ||||
| inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot"); | |||||
| inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression"); | |||||
| inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign"); | inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign"); | ||||
| inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv"); | |||||
| inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"); | |||||
| inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); | |||||
| // Statements | // Statements | ||||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>(kReturn); | inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>(kReturn); | ||||
| @@ -309,6 +391,7 @@ inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>( | |||||
| inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | ||||
| // Debug ops | // Debug ops | ||||
| inline const PrimitivePtr kPrimAssert = std::make_shared<Primitive>("Assert"); | |||||
| inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | ||||
| inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | ||||
| inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | ||||
| @@ -334,6 +417,13 @@ inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||||
| inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | ||||
| inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | ||||
| inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | ||||
| inline const PrimitivePtr kPrimLshProjection = std::make_shared<Primitive>("LshProjection"); | |||||
| inline const PrimitivePtr kPrimHashtableLookup = std::make_shared<Primitive>("HashtableLookup"); | |||||
| inline const PrimitivePtr kPrimCustomPredict = std::make_shared<Primitive>("CustomPredict"); | |||||
| inline const PrimitivePtr kPrimStack = std::make_shared<Primitive>("Stack"); | |||||
| inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox"); | |||||
| inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast"); | |||||
| inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While"); | |||||
| // Structures | // Structures | ||||
| inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | ||||
| @@ -356,7 +446,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_ | |||||
| inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | ||||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | ||||
| // Other primitve not used by backend but used in core; | |||||
| // Other primitive not used by backend but used in core; | |||||
| inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | ||||
| inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | ||||
| @@ -366,6 +456,43 @@ inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict | |||||
| // GraphKernel ops | // GraphKernel ops | ||||
| inline const PrimitivePtr kPrimInplaceAssign = std::make_shared<Primitive>("InplaceAssign"); | inline const PrimitivePtr kPrimInplaceAssign = std::make_shared<Primitive>("InplaceAssign"); | ||||
| // Only used in lite | |||||
| inline const PrimitivePtr kPrimLeakyRelu = std::make_shared<Primitive>("LeakyRelu"); | |||||
| inline const PrimitivePtr kPrimConstant = std::make_shared<Primitive>("Constant"); | |||||
| inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range"); | |||||
| inline const PrimitivePtr kPrimLocalResponseNormalization = std::make_shared<Primitive>("LocalResponseNormalization"); | |||||
| inline const PrimitivePtr kPrimFftReal = std::make_shared<Primitive>("FftReal"); | |||||
| inline const PrimitivePtr kPrimMfcc = std::make_shared<Primitive>("Mfcc"); | |||||
| inline const PrimitivePtr kPrimRfft = std::make_shared<Primitive>("Rfft"); | |||||
| inline const PrimitivePtr kPrimFftImag = std::make_shared<Primitive>("FftImag"); | |||||
| inline const PrimitivePtr kPrimSkipGram = std::make_shared<Primitive>("SkipGram"); | |||||
| inline const PrimitivePtr kPrimConv2DFusion = std::make_shared<Primitive>("Conv2DFusion"); | |||||
| inline const PrimitivePtr kPrimConv2dTransposeFusion = std::make_shared<Primitive>("Conv2dTransposeFusion"); | |||||
| inline const PrimitivePtr kPrimDepthWiseConv2DFusion = std::make_shared<Primitive>("DepthWiseConv2DFusion"); | |||||
| inline const PrimitivePtr kPrimAddFusion = std::make_shared<Primitive>("AddFusion"); | |||||
| inline const PrimitivePtr kPrimScaleFusion = std::make_shared<Primitive>("ScaleFusion"); | |||||
| inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusion"); | |||||
| inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion"); | |||||
| inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid"); | |||||
| inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip"); | |||||
| inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh"); | |||||
| inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion = | |||||
| std::make_shared<Primitive>("DepthWiseConv2DTransposeFusion"); | |||||
| inline const PrimitivePtr kPrimArgMinFusion = std::make_shared<Primitive>("ArgMinFusion"); | |||||
| inline const PrimitivePtr kPrimArgMaxFusion = std::make_shared<Primitive>("ArgMaxFusion"); | |||||
| inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("SpaceToDepth"); | |||||
| inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion"); | |||||
| inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion"); | |||||
| inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize"); | |||||
| inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose"); | |||||
| inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue"); | |||||
| inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If"); | |||||
| inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion"); | |||||
| inline const PrimitivePtr kPrimMaxPoolFusion = std::make_shared<Primitive>("MaxPoolFusion"); | |||||
| inline const PrimitivePtr kPrimActivation = std::make_shared<Primitive>("Activation"); | |||||
| inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFusion"); | |||||
| inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion"); | |||||
| inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); | |||||
| class DoSignaturePrimitive : public Primitive { | class DoSignaturePrimitive : public Primitive { | ||||
| public: | public: | ||||
| @@ -1,19 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/abs.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameAbs, Abs); | |||||
| } // namespace mindspore | |||||
| @@ -1,51 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/apply_momentum.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| void ApplyMomentum::Init(bool use_nesterov, bool use_locking, float gradient_scale) { | |||||
| this->set_use_nesterov(use_nesterov); | |||||
| this->set_use_locking(use_locking); | |||||
| this->set_gradient_scale(gradient_scale); | |||||
| } | |||||
| void ApplyMomentum::set_use_nesterov(bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); } | |||||
| void ApplyMomentum::set_use_locking(bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); } | |||||
| void ApplyMomentum::set_gradient_scale(float gradient_scale) { | |||||
| this->AddAttr(kGradientScale, MakeValue(gradient_scale)); | |||||
| } | |||||
| bool ApplyMomentum::get_use_nesterov() const { | |||||
| auto value_ptr = GetAttr(kUseNesterov); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool ApplyMomentum::get_use_locking() const { | |||||
| auto value_ptr = GetAttr(kUseLocking); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| float ApplyMomentum::get_gradient_scale() { | |||||
| auto value_ptr = GetAttr(kGradientScale); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/assign_add.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameAssignAdd, AssignAdd); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/atan.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameAtan, Atan); | |||||
| } // namespace mindspore | |||||
| @@ -1,54 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/audio_spectrogram.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| void AudioSpectrogram::set_window_size(const int64_t &window_size) { | |||||
| this->AddAttr(kWindowSize, MakeValue(window_size)); | |||||
| } | |||||
| int64_t AudioSpectrogram::get_window_size() const { | |||||
| auto value_ptr = GetAttr(kWindowSize); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void AudioSpectrogram::set_stride(const int64_t &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||||
| int64_t AudioSpectrogram::get_stride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void AudioSpectrogram::set_mag_square(const bool &mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); } | |||||
| bool AudioSpectrogram::get_mag_square() const { | |||||
| auto value_ptr = GetAttr(kMagSquare); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| void AudioSpectrogram::Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square) { | |||||
| this->set_window_size(window_size); | |||||
| this->set_stride(stride); | |||||
| this->set_mag_square(mag_square); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram); | |||||
| } // namespace mindspore | |||||
| @@ -1,59 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "c_ops/batch_norm.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| void BatchNorm::Init(bool is_training, float epsilon, const Format &format) { | |||||
| set_is_training(is_training); | |||||
| set_epsilon(epsilon); | |||||
| set_format(format); | |||||
| } | |||||
| void BatchNorm::set_is_training(bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); } | |||||
| void BatchNorm::set_epsilon(float epsilon) { | |||||
| CheckAndConvertUtils::CheckInRange(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name()); | |||||
| this->AddAttr(kEpsilon, MakeValue(epsilon)); | |||||
| } | |||||
| void BatchNorm::set_format(const Format &format) { | |||||
| int64_t f = format; | |||||
| this->AddAttr(kFormat, MakeValue(f)); | |||||
| } | |||||
| bool BatchNorm::get_is_trainging() { | |||||
| auto value_ptr = GetAttr(kIsTraining); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| float BatchNorm::get_epsilon() { | |||||
| auto value_ptr = GetAttr(kEpsilon); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| Format BatchNorm::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/batch_norm_fold.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold); | |||||
| } // namespace mindspore | |||||
| @@ -1,31 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/binary_cross_entropy_grad.h" | |||||
| namespace mindspore { | |||||
| void BinaryCrossEntropyGrad::Init(const std::string &reduction) { set_reduction(reduction); } | |||||
| void BinaryCrossEntropyGrad::set_reduction(const std::string &reduction) { | |||||
| CheckAndConvertUtils::CheckString(kReduction, reduction, {"none", "mean", "sum"}, name()); | |||||
| this->AddAttr(kReduction, MakeValue(reduction)); | |||||
| } | |||||
| std::string BinaryCrossEntropyGrad::get_reduction() const { | |||||
| auto value_ptr = GetAttr(kReduction); | |||||
| return GetValue<std::string>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropyGrad, BinaryCrossEntropyGrad); | |||||
| } // namespace mindspore | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/broadcast.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| void Broadcast::Init(int64_t root_rank, const std::string &group) { | |||||
| this->set_root_rank(root_rank); | |||||
| this->set_group(group); | |||||
| } | |||||
| void Broadcast::set_root_rank(int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); } | |||||
| void Broadcast::set_group(const std::string &group) { | |||||
| CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name()); | |||||
| this->AddAttr(kGroup, MakeValue(group)); | |||||
| } | |||||
| int64_t Broadcast::get_root_rank() { | |||||
| auto value_ptr = this->GetAttr(kRootRank); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| std::string Broadcast::get_group() const { | |||||
| auto value_ptr = this->GetAttr(kGroup); | |||||
| return GetValue<std::string>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast); | |||||
| } // namespace mindspore | |||||
| @@ -1,33 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/broadcast_to.h" | |||||
| namespace mindspore { | |||||
| void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); } | |||||
| void BroadcastTo::set_shape(const std::vector<int64_t> &shape) { | |||||
| CheckAndConvertUtils::CheckInteger(kShapeSize, shape.size(), kGreaterThan, 0, name()); | |||||
| CheckAndConvertUtils::CheckPositiveVector(kShape, shape, name(), false, true); | |||||
| AddAttr(kShape, MakeValue(shape)); | |||||
| } | |||||
| std::vector<int64_t> BroadcastTo::get_shape() const { | |||||
| auto value_ptr = GetAttr(kShape); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/ceil.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameCeil, Ceil); | |||||
| } | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/cos.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameCos, Cos); | |||||
| } | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/custom_predict.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| void CustomPredict::Init(int64_t outputNum, float weight_threshold) { | |||||
| this->set_outputNum(outputNum); | |||||
| this->set_weight_threshold(weight_threshold); | |||||
| } | |||||
| void CustomPredict::set_outputNum(int64_t outputNum) { this->AddAttr(kOutputNum, MakeValue(outputNum)); } | |||||
| int64_t CustomPredict::get_outputNum() const { | |||||
| auto value_ptr = this->GetAttr(kOutputNum); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void CustomPredict::set_weight_threshold(float weight_threshold) { | |||||
| this->AddAttr(kWeightThreshold, MakeValue(weight_threshold)); | |||||
| } | |||||
| float CustomPredict::get_weight_threshold() const { | |||||
| auto value_ptr = this->GetAttr(kWeightThreshold); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameCustomPredict, CustomPredict); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/div.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameDiv, Div); | |||||
| } // namespace mindspore | |||||
| @@ -1,20 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/equal.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameEqual, Equal); | |||||
| } // namespace mindspore | |||||
| @@ -1,20 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/exp.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameExp, Exp); | |||||
| } // namespace mindspore | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/fake_quant_with_min_max_vars.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| void FakeQuantWithMinMaxVars::Init(const bool &narrow_range, int64_t num_bits) { | |||||
| this->set_narrow_range(narrow_range); | |||||
| this->set_num_bits(num_bits); | |||||
| } | |||||
| void FakeQuantWithMinMaxVars::set_narrow_range(const bool &narrow_range) { | |||||
| this->AddAttr(kNarrowRange, MakeValue(narrow_range)); | |||||
| } | |||||
| bool FakeQuantWithMinMaxVars::get_narrow_range() const { | |||||
| auto value_ptr = this->GetAttr(kNarrowRange); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| void FakeQuantWithMinMaxVars::set_num_bits(int64_t num_bits) { this->AddAttr(kNumBits, MakeValue(num_bits)); } | |||||
| int64_t FakeQuantWithMinMaxVars::get_num_bits() const { | |||||
| auto value_ptr = this->GetAttr(kNumBits); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameFakeQuantWithMinMaxVars, FakeQuantWithMinMaxVars); | |||||
| } // namespace mindspore | |||||
| @@ -1,22 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/fft_imag.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameFftImag, FftImag); | |||||
| } | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/flatten_grad.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameFlattenGrad, FlattenGrad); | |||||
| } | |||||
| @@ -1,22 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/hashtable_lookup.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameHashtableLookup, HashtableLookup); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/less.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameLess, Less); | |||||
| } | |||||
| @@ -1,20 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/less_equal.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameLessEqual, LessEqual); | |||||
| } // namespace mindspore | |||||
| @@ -1,65 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/local_response_normalization.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| void LocalResponseNormalization::set_depth_radius(const int64_t &depth_radius) { | |||||
| this->AddAttr(kDepthRadius, MakeValue(depth_radius)); | |||||
| } | |||||
| int64_t LocalResponseNormalization::get_depth_radius() const { | |||||
| auto value_ptr = GetAttr(kDepthRadius); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void LocalResponseNormalization::set_bias(const float &bias) { this->AddAttr(kBias, MakeValue(bias)); } | |||||
| float LocalResponseNormalization::get_bias() const { | |||||
| auto value_ptr = GetAttr(kBias); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| void LocalResponseNormalization::set_alpha(const float &alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); } | |||||
| float LocalResponseNormalization::get_alpha() const { | |||||
| auto value_ptr = GetAttr(kAlpha); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| void LocalResponseNormalization::set_beta(const float &beta) { this->AddAttr(kBeta, MakeValue(beta)); } | |||||
| float LocalResponseNormalization::get_beta() const { | |||||
| auto value_ptr = GetAttr(kBeta); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| void LocalResponseNormalization::Init(const int64_t &depth_radius, const float &bias, const float &alpha, | |||||
| const float &beta) { | |||||
| this->set_depth_radius(depth_radius); | |||||
| this->set_bias(bias); | |||||
| this->set_alpha(alpha); | |||||
| this->set_beta(beta); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameLocalResponseNormalization, LocalResponseNormalization); | |||||
| } // namespace mindspore | |||||
| @@ -1,21 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/log.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameLog, Log); | |||||
| } | |||||
| @@ -1,20 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/logical_not.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameLogicalNot, LogicalNot); | |||||
| } // namespace mindspore | |||||
| @@ -1,20 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/logical_or.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PRIMITIVE_C(kNameLogicalOr, LogicalOr); | |||||
| } // namespace mindspore | |||||
| @@ -1,82 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "c_ops/lstm.h" | |||||
| namespace mindspore { | |||||
| void LSTM::set_input_size(const int64_t &input_size) { | |||||
| CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name()); | |||||
| AddAttr(kInput_size, MakeValue(input_size)); | |||||
| } | |||||
| int64_t LSTM::get_input_size() const { | |||||
| auto value_ptr = this->GetAttr(kInput_size); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void LSTM::set_hidden_size(const int64_t &hidden_size) { | |||||
| CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name()); | |||||
| AddAttr(kHidden_size, MakeValue(hidden_size)); | |||||
| } | |||||
| int64_t LSTM::get_hidden_size() const { | |||||
| auto value_ptr = this->GetAttr(kHidden_size); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void LSTM::set_num_layers(const int64_t &num_layers) { | |||||
| CheckAndConvertUtils::CheckInteger(kNum_layers, num_layers, kGreaterThan, 0, this->name()); | |||||
| AddAttr(kNum_layers, MakeValue(kNum_layers)); | |||||
| } | |||||
| int64_t LSTM::get_num_layers() const { | |||||
| auto value_ptr = this->GetAttr(kNum_layers); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void LSTM::set_has_bias(const bool &has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); } | |||||
| bool LSTM::get_has_bias() const { | |||||
| auto value_ptr = this->GetAttr(kHasBias); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| void LSTM::set_dropout(const float &dropout) { | |||||
| CheckAndConvertUtils::CheckInRange(kDropout, dropout, kIncludeBoth, {0, 1}, this->name()); | |||||
| AddAttr(kDropout, MakeValue(dropout)); | |||||
| } | |||||
| float LSTM::get_dropout() const { | |||||
| auto value_ptr = this->GetAttr(kDropout); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| void LSTM::set_bidirectional(const bool &bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); } | |||||
| bool LSTM::get_bidirectional() const { | |||||
| auto value_ptr = this->GetAttr(kBidirectional); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| void LSTM::set_num_directions(const int64_t &num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); } | |||||
| int64_t LSTM::get_num_directions() const { | |||||
| auto value_ptr = this->GetAttr(kNumDirections); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void LSTM::Init(const int64_t &input_size, const int64_t &hidden_size, const int64_t &num_layers, const bool &has_bias, | |||||
| const float &dropout, const bool &bidirectional) { | |||||
| this->set_input_size(input_size); | |||||
| this->set_hidden_size(hidden_size); | |||||
| this->set_num_layers(num_layers); | |||||
| this->set_has_bias(has_bias); | |||||
| this->set_dropout(dropout); | |||||
| this->set_bidirectional(bidirectional); | |||||
| if (bidirectional) { | |||||
| this->set_num_directions(2); | |||||
| } else { | |||||
| this->set_num_directions(1); | |||||
| } | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameLSTM, LSTM); | |||||
| } // namespace mindspore | |||||
| @@ -25,7 +25,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| @@ -676,7 +676,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc | |||||
| const std::string &node_type = node_proto.op_type(); | const std::string &node_type = node_proto.op_type(); | ||||
| std::shared_ptr<Primitive> prim; | std::shared_ptr<Primitive> prim; | ||||
| auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap(); | |||||
| auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); | |||||
| if (op_primc_fns.find(node_type) != op_primc_fns.end()) { | if (op_primc_fns.find(node_type) != op_primc_fns.end()) { | ||||
| prim = op_primc_fns[node_type](); | prim = op_primc_fns[node_type](); | ||||
| } else { | } else { | ||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/abs.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs_prim = primitive->cast<PrimAbsPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(abs_prim); | |||||
| auto prim_name = abs_prim->name(); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| return std::make_shared<abstract::Shape>(in_shape); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | |||||
| MS_LOG(EXCEPTION) << "nullptr"; | |||||
| } | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("input_x", input_args[0]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAbs, Abs); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -16,11 +16,15 @@ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ABS_H_ | #ifndef MINDSPORE_CORE_C_OPS_ABS_H_ | ||||
| #define MINDSPORE_CORE_C_OPS_ABS_H_ | #define MINDSPORE_CORE_C_OPS_ABS_H_ | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAbs = "Abs"; | constexpr auto kNameAbs = "Abs"; | ||||
| class Abs : public PrimitiveC { | class Abs : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -29,6 +33,10 @@ class Abs : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(Abs, PrimitiveC); | MS_DECLARE_PARENT(Abs, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAbsPtr = std::shared_ptr<Abs>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ABS_H_ | #endif // MINDSPORE_CORE_C_OPS_ABS_H_ | ||||
| @@ -0,0 +1,88 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/adam.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto Adam_prim = primitive->cast<PrimAdamPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(Adam_prim); | |||||
| auto prim_name = Adam_prim->name(); | |||||
| // infer shape | |||||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name); | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name); | |||||
| auto grad_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name); | |||||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); | |||||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); | |||||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); | |||||
| // infer type | |||||
| auto var_type = input_args[0]->BuildType(); | |||||
| auto m_type = input_args[1]->BuildType(); | |||||
| auto v_type = input_args[2]->BuildType(); | |||||
| auto grad_type = input_args[9]->BuildType(); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name); | |||||
| auto infer_var_type = var_type->cast<TensorTypePtr>()->element(); | |||||
| auto infer_m_type = m_type->cast<TensorTypePtr>()->element(); | |||||
| auto infer_v_type = v_type->cast<TensorTypePtr>()->element(); | |||||
| // auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element(); | |||||
| auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape); | |||||
| auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape); | |||||
| auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape); | |||||
| AbstractBasePtrList output = {output0, output1, output2}; | |||||
| return std::make_shared<abstract::AbstractTuple>(output); | |||||
| } | |||||
| } // namespace | |||||
| void Adam::Init(const bool use_locking, const bool use_nesterov) { | |||||
| this->set_use_locking(use_locking); | |||||
| this->set_use_nesterov(use_nesterov); | |||||
| } | |||||
| void Adam::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); } | |||||
| void Adam::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); } | |||||
| bool Adam::get_use_locking() const { | |||||
| auto value_ptr = GetAttr(kUseLocking); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool Adam::get_use_nesterov() const { | |||||
| auto value_ptr = GetAttr(kUseNesterov); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Adam, prim::kPrimAdam, AdamInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAdam, Adam); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -20,23 +20,28 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAdam = "Adam"; | constexpr auto kNameAdam = "Adam"; | ||||
| class Adam : public PrimitiveC { | class Adam : public PrimitiveC { | ||||
| public: | public: | ||||
| Adam() : PrimitiveC(kNameAdam) {} | Adam() : PrimitiveC(kNameAdam) {} | ||||
| ~Adam() = default; | ~Adam() = default; | ||||
| MS_DECLARE_PARENT(Adam, PrimitiveC); | MS_DECLARE_PARENT(Adam, PrimitiveC); | ||||
| void Init(const bool &use_locking = false, const bool &use_nesteroy = false); | |||||
| void set_use_locking(const bool &use_locking); | |||||
| void set_use_nesteroy(const bool &use_nesteroy); | |||||
| void Init(const bool use_locking = false, const bool use_nesterov = false); | |||||
| void set_use_locking(const bool use_locking); | |||||
| void set_use_nesterov(const bool use_nesterov); | |||||
| bool get_use_locking() const; | bool get_use_locking() const; | ||||
| bool get_use_nesteroy() const; | |||||
| bool get_use_nesterov() const; | |||||
| }; | }; | ||||
| AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAdamPtr = std::shared_ptr<Adam>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ADAM_H_ | #endif // MINDSPORE_CORE_C_OPS_ADAM_H_ | ||||
| @@ -14,22 +14,23 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/add.h" | |||||
| #include "ops/add.h" | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto add_prim = primitive->cast<PrimAddPtr>(); | auto add_prim = primitive->cast<PrimAddPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(add_prim); | MS_EXCEPTION_IF_NULL(add_prim); | ||||
| auto op_name = add_prim->name(); | |||||
| return BroadCastInferShape(op_name, input_args); | |||||
| auto prim_name = add_prim->name(); | |||||
| return BroadCastInferShape(prim_name, input_args); | |||||
| } | } | ||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| @@ -49,6 +50,7 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | ||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimTensorAdd, AddInfer); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAdd, Add); | REGISTER_PRIMITIVE_C(kNameAdd, Add); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,15 +20,17 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAdd = "Add"; | constexpr auto kNameAdd = "Add"; | ||||
| class Add : public PrimitiveC { | class Add : public PrimitiveC { | ||||
| public: | public: | ||||
| Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); } | Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); } | ||||
| explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); } | |||||
| ~Add() = default; | ~Add() = default; | ||||
| MS_DECLARE_PARENT(Add, PrimitiveC); | MS_DECLARE_PARENT(Add, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| @@ -37,6 +39,7 @@ class Add : public PrimitiveC { | |||||
| AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimAddPtr = std::shared_ptr<Add>; | using PrimAddPtr = std::shared_ptr<Add>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ADD_H_ | #endif // MINDSPORE_CORE_C_OPS_ADD_H_ | ||||
| @@ -14,8 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/add_fold.h" | |||||
| #include "ops/add_fold.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| REGISTER_PRIMITIVE_C(kNameAddFold, AddFold); | REGISTER_PRIMITIVE_C(kNameAddFold, AddFold); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,17 +14,18 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_ADDFOLD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ADDFOLD_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ADD_FOLD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ADD_FOLD_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAddFold = "AddFold"; | constexpr auto kNameAddFold = "AddFold"; | ||||
| class AddFold : public PrimitiveC { | class AddFold : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -33,6 +34,7 @@ class AddFold : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(AddFold, PrimitiveC); | MS_DECLARE_PARENT(AddFold, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ADDFOLD_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_ADD_FOLD_H_ | |||||
| @@ -0,0 +1,108 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/adder.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size, | |||||
| const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list, | |||||
| const std::vector<int64_t> &dilation, const int64_t group, const Format &format) { | |||||
| set_in_channel(in_channel); | |||||
| set_out_channel(out_channel); | |||||
| set_kernel_size(kernel_size); | |||||
| set_pad_mode(pad_mode); | |||||
| set_stride(stride); | |||||
| set_pad_list(pad_list); | |||||
| set_dilation(dilation); | |||||
| set_group(group); | |||||
| set_format(format); | |||||
| } | |||||
| void Adder::set_in_channel(const int64_t in_channel) { this->AddAttr(kInChannel, MakeValue(in_channel)); } | |||||
| int64_t Adder::get_in_channel() const { | |||||
| auto value_ptr = GetAttr(kInChannel); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void Adder::set_out_channel(const int64_t out_channel) { this->AddAttr(kOutChannel, MakeValue(out_channel)); } | |||||
| int64_t Adder::get_out_channel() const { | |||||
| auto value_ptr = GetAttr(kOutChannel); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | |||||
| } | |||||
| std::vector<int64_t> Adder::get_kernel_size() const { | |||||
| auto value_ptr = GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| void Adder::set_pad_mode(const PadMode &pad_mode) { | |||||
| int64_t swi = pad_mode; | |||||
| this->AddAttr(kPadMode, MakeValue(swi)); | |||||
| } | |||||
| PadMode Adder::get_pad_mode() const { | |||||
| auto value_ptr = GetAttr(kPadMode); | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| void Adder::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||||
| std::vector<int64_t> Adder::get_stride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||||
| std::vector<int64_t> Adder::get_pad_list() const { | |||||
| auto value_ptr = GetAttr(kPadList); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| void Adder::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||||
| std::vector<int64_t> Adder::get_dilation() const { | |||||
| auto value_ptr = GetAttr(kDilation); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| void Adder::set_group(const int64_t group) { this->AddAttr(kGroup, MakeValue(group)); } | |||||
| int64_t Adder::get_group() const { | |||||
| auto value_ptr = GetAttr(kGroup); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void Adder::set_format(const Format &format) { | |||||
| int64_t swi = format; | |||||
| this->AddAttr(kFormat, MakeValue(swi)); | |||||
| } | |||||
| Format Adder::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNameAdder, Adder); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ADDER_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ADDER_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameAdder = "Adder"; | |||||
| class Adder : public PrimitiveC { | |||||
| public: | |||||
| explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {} | |||||
| ~Adder() = default; | |||||
| MS_DECLARE_PARENT(Adder, PrimitiveC); | |||||
| void Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size, | |||||
| const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list, | |||||
| const std::vector<int64_t> &dilation, const int64_t group, const Format &format); | |||||
| void set_in_channel(const int64_t in_channel); | |||||
| void set_out_channel(const int64_t out_channel); | |||||
| void set_kernel_size(const std::vector<int64_t> &kernel_size); | |||||
| void set_pad_mode(const PadMode &pad_mode); | |||||
| void set_stride(const std::vector<int64_t> &stride); | |||||
| void set_pad_list(const std::vector<int64_t> &pad_list); | |||||
| void set_dilation(const std::vector<int64_t> &dilation); | |||||
| void set_group(const int64_t group); | |||||
| void set_format(const Format &format); | |||||
| int64_t get_in_channel() const; | |||||
| int64_t get_out_channel() const; | |||||
| std::vector<int64_t> get_kernel_size() const; | |||||
| PadMode get_pad_mode() const; | |||||
| std::vector<int64_t> get_stride() const; | |||||
| std::vector<int64_t> get_pad_list() const; | |||||
| std::vector<int64_t> get_dilation() const; | |||||
| int64_t get_group() const; | |||||
| Format get_format() const; | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_ADDER_H_ | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include "ops/addn.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_tuple); | |||||
| auto elements = input_tuple->elements(); | |||||
| CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | |||||
| auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(element0); | |||||
| auto element0_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("element0", element0->BuildType()); | |||||
| for (size_t i = 1; i < elements.size(); ++i) { | |||||
| std::string elementi = "element" + std::to_string(i); | |||||
| auto elementi_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | |||||
| prim_name); | |||||
| for (size_t j = 0; j < element0_shape.size(); ++j) { | |||||
| if (elementi_shape[j] != element0_shape[j]) { | |||||
| MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element."; | |||||
| } | |||||
| } | |||||
| types.emplace(elementi, elements[i]->BuildType()); | |||||
| } | |||||
| std::set<TypeId> valid_types = common_valid_types; | |||||
| valid_types.insert(kNumberTypeBool); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type), | |||||
| std::make_shared<abstract::Shape>(element0_shape)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAddN, AddN); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -16,11 +16,14 @@ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ADDN_H_ | #ifndef MINDSPORE_CORE_C_OPS_ADDN_H_ | ||||
| #define MINDSPORE_CORE_C_OPS_ADDN_H_ | #define MINDSPORE_CORE_C_OPS_ADDN_H_ | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAddN = "AddN"; | constexpr auto kNameAddN = "AddN"; | ||||
| class AddN : public PrimitiveC { | class AddN : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -29,6 +32,10 @@ class AddN : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(AddN, PrimitiveC); | MS_DECLARE_PARENT(AddN, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAddNPtr = std::shared_ptr<AddN>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ADDN_H_ | #endif // MINDSPORE_CORE_C_OPS_ADDN_H_ | ||||
| @@ -14,19 +14,20 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/concat.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/all.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); } | |||||
| void Concat::Init(int64_t axis) { this->set_axis(axis); } | |||||
| int64_t Concat::get_axis() const { | |||||
| auto value_ptr = this->GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void All::set_keep_dims(const int64_t keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); } | |||||
| void Concat::set_axis(int64_t axis) { | |||||
| this->AddAttr(kAxis, MakeValue(CheckAndConvertUtils::CheckInteger(kAxis, axis, kGreaterEqual, 0, this->name()))); | |||||
| int64_t All::get_keep_dims() const { | |||||
| auto value_ptr = GetAttr(kKeepDims); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameConcat, Concat); | |||||
| REGISTER_PRIMITIVE_C(kNameAll, All); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ALL_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ALL_H_ | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameAll = "All"; | |||||
| class All : public PrimitiveC { | |||||
| public: | |||||
| All() : PrimitiveC(kNameAll) {} | |||||
| ~All() = default; | |||||
| MS_DECLARE_PARENT(All, PrimitiveC); | |||||
| void Init(const int64_t keep_dims); | |||||
| void set_keep_dims(const int64_t keep_dims); | |||||
| int64_t get_keep_dims() const; | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_ALL_H_ | |||||
| @@ -0,0 +1,89 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "ops/apply_momentum.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const float gradient_scale) { | |||||
| this->set_use_nesterov(use_nesterov); | |||||
| this->set_use_locking(use_locking); | |||||
| this->set_gradient_scale(gradient_scale); | |||||
| } | |||||
| void ApplyMomentum::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); } | |||||
| void ApplyMomentum::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); } | |||||
| void ApplyMomentum::set_gradient_scale(const float gradient_scale) { | |||||
| this->AddAttr(kGradientScale, MakeValue(gradient_scale)); | |||||
| } | |||||
| bool ApplyMomentum::get_use_nesterov() const { | |||||
| auto value_ptr = GetAttr(kUseNesterov); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| bool ApplyMomentum::get_use_locking() const { | |||||
| auto value_ptr = GetAttr(kUseLocking); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| float ApplyMomentum::get_gradient_scale() const { | |||||
| auto value_ptr = GetAttr(kGradientScale); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(momentum_prim); | |||||
| auto prim_name = momentum_prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | |||||
| // Infer shape | |||||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); | |||||
| // Infer type | |||||
| auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name); | |||||
| const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32), | |||||
| TypeIdToType(kNumberTypeFloat64)}; | |||||
| std::map<std::string, TypePtr> args; | |||||
| args.insert({"l_type", l_type}); | |||||
| args.insert({"g_type", g_type}); | |||||
| args.insert({"m_type", m_type}); | |||||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(g_type, v_shape); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,13 +14,17 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_ | |||||
| #include "c_ops/primitive_c.h" | |||||
| #ifndef MINDSPORE_CORE_C_OPS_APPLY_MOMENTUM_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_APPLY_MOMENTUM_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameApplyMomentum = "ApplyMomentum"; | constexpr auto kNameApplyMomentum = "ApplyMomentum"; | ||||
| class ApplyMomentum : public PrimitiveC { | class ApplyMomentum : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -29,14 +33,18 @@ class ApplyMomentum : public PrimitiveC { | |||||
| } | } | ||||
| ~ApplyMomentum() = default; | ~ApplyMomentum() = default; | ||||
| MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC); | MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC); | ||||
| void Init(bool use_nesterov, bool use_locking, float gradient_scale); | |||||
| void set_use_nesterov(bool use_nesterov); | |||||
| void set_use_locking(bool use_locking); | |||||
| void set_gradient_scale(float gradient_scale); | |||||
| void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0); | |||||
| void set_use_nesterov(const bool use_nesterov); | |||||
| void set_use_locking(const bool use_locking); | |||||
| void set_gradient_scale(const float gradient_scale); | |||||
| bool get_use_nesterov() const; | bool get_use_nesterov() const; | ||||
| bool get_use_locking() const; | bool get_use_locking() const; | ||||
| float get_gradient_scale(); | |||||
| float get_gradient_scale() const; | |||||
| }; | }; | ||||
| AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_APPLY_MOMENTUM_H_ | |||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/arg_max.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| auto prim = primitive->cast<PrimArgMaxPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto axis = prim->get_axis(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_rank = SizeToLong(x_shape.size()); | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | |||||
| axis = axis < 0 ? axis + x_rank : axis; | |||||
| std::vector<int64_t> out_shape; | |||||
| for (size_t i = 0; i < x_shape.size(); ++i) { | |||||
| if (SizeToLong(i) != axis) { | |||||
| out_shape.emplace_back(x_shape[i]); | |||||
| } | |||||
| } | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name()); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| return kInt32; | |||||
| } | |||||
| } // namespace | |||||
| void ArgMax::Init(const int64_t axis, const TypeId output_type) { | |||||
| set_axis(axis); | |||||
| set_output_type(output_type); | |||||
| } | |||||
| void ArgMax::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | |||||
| void ArgMax::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } | |||||
| int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } | |||||
| TypeId ArgMax::get_output_type() const { | |||||
| auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element(); | |||||
| return type_ptr->type_id(); | |||||
| } | |||||
| AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ArgMax, prim::kPrimArgMax, ArgMaxInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ARG_MAX_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ARG_MAX_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameArgMax = "Argmax"; | |||||
| class ArgMax : public PrimitiveC { | |||||
| public: | |||||
| ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); } | |||||
| explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); } | |||||
| ~ArgMax() = default; | |||||
| MS_DECLARE_PARENT(ArgMax, PrimitiveC); | |||||
| void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32); | |||||
| void set_axis(const int64_t axis); | |||||
| void set_output_type(const TypeId output_type); | |||||
| int64_t get_axis() const; | |||||
| TypeId get_output_type() const; | |||||
| }; | |||||
| AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimArgMaxPtr = std::shared_ptr<ArgMax>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_ARG_MAX_H_ | |||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include "ops/arg_min.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void ArgMin::Init(const int64_t axis, const TypeId output_type) { | |||||
| set_axis(axis); | |||||
| set_output_type(output_type); | |||||
| } | |||||
| void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | |||||
| void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } | |||||
| int64_t ArgMin::get_axis() const { | |||||
| auto value_ptr = GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| TypeId ArgMin::get_output_type() const { | |||||
| auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element(); | |||||
| return type_ptr->type_id(); | |||||
| } | |||||
| AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto argmin_prim = primitive->cast<PrimArgMin>(); | |||||
| MS_EXCEPTION_IF_NULL(argmin_prim); | |||||
| auto prim_name = argmin_prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name); | |||||
| // Infer shape | |||||
| auto axis = argmin_prim->get_axis(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto x_rank = SizeToLong(x_shape.size()); | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | |||||
| if (axis < 0) { | |||||
| axis += x_rank; | |||||
| } | |||||
| std::vector<int64_t> out_shape; | |||||
| for (int64_t i = 0; i < x_rank; i++) { | |||||
| if (i != axis) { | |||||
| out_shape.push_back(x_shape[i]); | |||||
| } | |||||
| } | |||||
| // Infer type | |||||
| auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)}; | |||||
| CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name); | |||||
| return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ArgMin, prim::kPrimArgMin, ArgMinInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,32 +14,37 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_ARGMIN_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ARGMIN_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ARG_MIN_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ARG_MIN_H_ | |||||
| #include <string> | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameArgMin = "ArgMin"; | constexpr auto kNameArgMin = "ArgMin"; | ||||
| class ArgMin : public PrimitiveC { | class ArgMin : public PrimitiveC { | ||||
| public: | public: | ||||
| ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); } | ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); } | ||||
| explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); } | |||||
| ~ArgMin() = default; | ~ArgMin() = default; | ||||
| MS_DECLARE_PARENT(ArgMin, PrimitiveC); | MS_DECLARE_PARENT(ArgMin, PrimitiveC); | ||||
| void Init(bool keep_dims, int64_t axis = -1); | |||||
| void set_axis(int64_t axis); | |||||
| void set_keep_dims(bool keep_dims); | |||||
| int64_t get_axis(); | |||||
| bool get_keep_dims(); | |||||
| void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32); | |||||
| void set_axis(const int64_t axis); | |||||
| void set_output_type(const TypeId output_type); | |||||
| int64_t get_axis() const; | |||||
| TypeId get_output_type() const; | |||||
| }; | }; | ||||
| AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimArgMin = std::shared_ptr<ArgMin>; | using PrimArgMin = std::shared_ptr<ArgMin>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ARGMIN_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_ARG_MIN_H_ | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/asin.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto asin_prim = primitive->cast<PrimAsinPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(asin_prim); | |||||
| auto prim_name = asin_prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); | |||||
| // Infer Shape | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | |||||
| // Infer Type | |||||
| auto dtype = input_args[0]->BuildType(); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name); | |||||
| auto tensor_type = dtype->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto element = tensor_type->element(); | |||||
| MS_EXCEPTION_IF_NULL(element); | |||||
| auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id())); | |||||
| return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Asin, prim::kPrimAsin, AsinInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAsin, Asin); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -16,11 +16,15 @@ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ASIN_H_ | #ifndef MINDSPORE_CORE_C_OPS_ASIN_H_ | ||||
| #define MINDSPORE_CORE_C_OPS_ASIN_H_ | #define MINDSPORE_CORE_C_OPS_ASIN_H_ | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAsin = "Asin"; | constexpr auto kNameAsin = "Asin"; | ||||
| class Asin : public PrimitiveC { | class Asin : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -29,6 +33,10 @@ class Asin : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(Asin, PrimitiveC); | MS_DECLARE_PARENT(Asin, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr ASinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAsinPtr = std::shared_ptr<Asin>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ASIN_H_ | #endif // MINDSPORE_CORE_C_OPS_ASIN_H_ | ||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/assert.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void Assert::Init(const int64_t summarize) { set_summarize(summarize); } | |||||
| void Assert::set_summarize(const int64_t summarize) { this->AddAttr(kSummarize, MakeValue(summarize)); } | |||||
| int64_t Assert::get_summarize() const { | |||||
| auto value_ptr = GetAttr(kSummarize); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto Assert_prim = primitive->cast<PrimAssertPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(Assert_prim); | |||||
| auto op_name = Assert_prim->name(); | |||||
| TypePtr condition; | |||||
| if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) { | |||||
| auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue()); | |||||
| CheckAndConvertUtils::CheckInteger("condition's rank", condition_value.size(), kLessEqual, 1, op_name); | |||||
| if (condition_value.size() == 1) { | |||||
| CheckAndConvertUtils::CheckInteger("condition[0]", condition_value[0], kEqual, 1, op_name); | |||||
| } | |||||
| condition = TypeIdToType(kNumberTypeBool); | |||||
| } else { | |||||
| auto condition_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||||
| CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name); | |||||
| if (condition_shape[0] == 1) { | |||||
| auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c()); | |||||
| MS_EXCEPTION_IF_NULL(condition_value); | |||||
| // auto condition_value = GetValue<bool>(input_args[0]->BuildValue()); | |||||
| CheckAndConvertUtils::CheckInteger("condition[0]", *condition_value, kEqual, 1, op_name); | |||||
| } | |||||
| condition = input_args[0]->BuildType(); | |||||
| } | |||||
| std::vector<int64_t> output_shape = {1}; | |||||
| std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)}; | |||||
| std::map<std::string, TypePtr> args = {{"condition", condition}}; | |||||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name); | |||||
| auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements(); | |||||
| for (auto dtype : inputs_type) { | |||||
| std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)}; | |||||
| CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name); | |||||
| } | |||||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Assert, prim::kPrimAssert, AssertInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAssert, Assert); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ASSERT_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ASSERT_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameAssert = "Assert"; | |||||
| class Assert : public PrimitiveC { | |||||
| public: | |||||
| Assert() : PrimitiveC(kNameAssert) {} | |||||
| ~Assert() = default; | |||||
| MS_DECLARE_PARENT(Assert, PrimitiveC); | |||||
| void Init(const int64_t summarize = 3); | |||||
| void set_summarize(const int64_t summarize); | |||||
| int64_t get_summarize() const; | |||||
| }; | |||||
| AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAssertPtr = std::shared_ptr<Assert>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_ASSERT_H_ | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "ops/assign.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "ir/dtype/ref.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| REGISTER_PRIMITIVE_C(kNameAssign, Assign); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -16,11 +16,15 @@ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ASSIGN_H_ | #ifndef MINDSPORE_CORE_C_OPS_ASSIGN_H_ | ||||
| #define MINDSPORE_CORE_C_OPS_ASSIGN_H_ | #define MINDSPORE_CORE_C_OPS_ASSIGN_H_ | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAssign = "Assign"; | constexpr auto kNameAssign = "Assign"; | ||||
| class Assign : public PrimitiveC { | class Assign : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -29,6 +33,7 @@ class Assign : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(Assign, PrimitiveC); | MS_DECLARE_PARENT(Assign, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ASSIGN_H_ | #endif // MINDSPORE_CORE_C_OPS_ASSIGN_H_ | ||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "ops/assign_add.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto assignadd_prim = primitive->cast<PrimAssignAddPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(assignadd_prim); | |||||
| auto prim_name = assignadd_prim->name(); | |||||
| auto value_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); | |||||
| return std::make_shared<abstract::Shape>(value_shape); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("x", input_args[0]->BuildType()); | |||||
| types.emplace("w", input_args[1]->BuildType()); | |||||
| // check_scalar_or_tensor_types_same | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd"); | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAssignAdd, AssignAdd); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,13 +14,17 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_ASSIGNADD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ASSIGNADD_H_ | |||||
| #include "c_ops/primitive_c.h" | |||||
| #ifndef MINDSPORE_CORE_C_OPS_ASSIGN_ADD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_ASSIGN_ADD_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAssignAdd = "AssignAdd"; | constexpr auto kNameAssignAdd = "AssignAdd"; | ||||
| class AssignAdd : public PrimitiveC { | class AssignAdd : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -29,6 +33,10 @@ class AssignAdd : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(AssignAdd, PrimitiveC); | MS_DECLARE_PARENT(AssignAdd, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAssignAddPtr = std::shared_ptr<AssignAdd>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ASSIGNADD_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_ASSIGN_ADD_H_ | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include "ops/atan.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto atan_prim = primitive->cast<PrimAtanPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(atan_prim); | |||||
| auto prim_name = atan_prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | |||||
| // Infer Shape | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | |||||
| // Infer Type | |||||
| auto dtype = input_args[0]->BuildType(); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name); | |||||
| auto tensor_type = dtype->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto element = tensor_type->element(); | |||||
| MS_EXCEPTION_IF_NULL(element); | |||||
| auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id())); | |||||
| return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Atan, prim::kPrimAtan, AtanInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAtan, Atan); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -20,11 +20,12 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAtan = "Atan"; | constexpr auto kNameAtan = "Atan"; | ||||
| class Atan : public PrimitiveC { | class Atan : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -33,6 +34,10 @@ class Atan : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(Atan, PrimitiveC); | MS_DECLARE_PARENT(Atan, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr ATanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAtanPtr = std::shared_ptr<Atan>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_ATAN_H_ | #endif // MINDSPORE_CORE_C_OPS_ATAN_H_ | ||||
| @@ -0,0 +1,125 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/audio_spectrogram.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto audio_spectrogram_prim = primitive->cast<PrimAudioSpectrogramPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(audio_spectrogram_prim); | |||||
| auto prim_name = audio_spectrogram_prim->name(); | |||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||||
| if (input_shape.size() != 2) { | |||||
| MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | |||||
| } | |||||
| if (audio_spectrogram_prim->get_window_size() < 2) { | |||||
| MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size(); | |||||
| } | |||||
| if (audio_spectrogram_prim->get_stride() < 1) { | |||||
| MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride(); | |||||
| } | |||||
| std::vector<int64_t> infer_shape; | |||||
| infer_shape.push_back(input_shape[1]); | |||||
| int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size(); | |||||
| infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride()); | |||||
| int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size()); | |||||
| infer_shape.push_back(fft_length / 2 + 1); | |||||
| MS_LOG(ERROR) << infer_shape; | |||||
| return std::make_shared<abstract::Shape>(infer_shape); | |||||
| } | |||||
| TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto infer_type = input_args[0]->BuildType(); | |||||
| auto tensor_type = infer_type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto data_type = tensor_type->element(); | |||||
| MS_EXCEPTION_IF_NULL(data_type); | |||||
| return data_type; | |||||
| } | |||||
| } // namespace | |||||
| void AudioSpectrogram::set_window_size(const int64_t window_size) { | |||||
| this->AddAttr(kWindowSize, MakeValue(window_size)); | |||||
| } | |||||
| int64_t AudioSpectrogram::get_window_size() const { | |||||
| auto value_ptr = GetAttr(kWindowSize); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void AudioSpectrogram::set_stride(const int64_t stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||||
| int64_t AudioSpectrogram::get_stride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t AudioSpectrogram::Log2Ceil(int64_t length) { | |||||
| if (length == 0) { | |||||
| return -1; | |||||
| } | |||||
| int64_t floor = 0; | |||||
| for (int64_t i = 4; i >= 0; --i) { | |||||
| const int64_t shift = (int64_t)(1 << i); | |||||
| int64_t tmp = length >> shift; | |||||
| if (tmp != 0) { | |||||
| length = tmp; | |||||
| floor += shift; | |||||
| } | |||||
| } | |||||
| return length == (length & ~(length - 1)) ? floor : floor + 1; | |||||
| } | |||||
| int64_t AudioSpectrogram::GetFftLength(int64_t length) { | |||||
| int64_t shift = Log2Ceil(length); | |||||
| return 1 << shift; | |||||
| } | |||||
| void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); } | |||||
| bool AudioSpectrogram::get_mag_square() const { | |||||
| auto value_ptr = GetAttr(kMagSquare); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, const bool mag_square) { | |||||
| this->set_window_size(window_size); | |||||
| this->set_stride(stride); | |||||
| this->set_mag_square(mag_square); | |||||
| } | |||||
| AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args), | |||||
| AudioSpectrogramInferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AudioSpectrogram, prim::kPrimAudioSpectrogram, AudioSpectrogramInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,31 +14,38 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAudioSpectrogram = "AudioSpectrogram"; | constexpr auto kNameAudioSpectrogram = "AudioSpectrogram"; | ||||
| class AudioSpectrogram : public PrimitiveC { | class AudioSpectrogram : public PrimitiveC { | ||||
| public: | public: | ||||
| AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {} | AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {} | ||||
| ~AudioSpectrogram() = default; | ~AudioSpectrogram() = default; | ||||
| MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC); | MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC); | ||||
| void Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square); | |||||
| void set_window_size(const int64_t &window_size); | |||||
| void set_stride(const int64_t &stride); | |||||
| void set_mag_square(const bool &mag_square); | |||||
| void Init(const int64_t window_size, const int64_t stride, const bool mag_square); | |||||
| void set_window_size(const int64_t window_size); | |||||
| void set_stride(const int64_t stride); | |||||
| void set_mag_square(const bool mag_square); | |||||
| int64_t get_window_size() const; | int64_t get_window_size() const; | ||||
| int64_t get_stride() const; | int64_t get_stride() const; | ||||
| bool get_mag_square() const; | bool get_mag_square() const; | ||||
| int64_t Log2Ceil(int64_t length); | |||||
| int64_t GetFftLength(int64_t length); | |||||
| }; | }; | ||||
| AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimAudioSpectrogramPtr = std::shared_ptr<AudioSpectrogram>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||||
| @@ -14,25 +14,26 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/avg_pool.h" | |||||
| #include "ops/avg_pool.h" | |||||
| #include <string> | #include <string> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | #include <set> | ||||
| #include <vector> | #include <vector> | ||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| void AvgPool::set_padding(const std::string &padding) { | |||||
| CheckAndConvertUtils::CheckString(kPadding, padding, {kValid, kSame}, this->name()); | |||||
| this->AddAttr(kPadding, MakeValue(padding)); | |||||
| namespace ops { | |||||
| void AvgPool::set_pad_mode(const PadMode &pad_mode) { | |||||
| int64_t swi = pad_mode; | |||||
| this->AddAttr(kPadMode, MakeValue(swi)); | |||||
| } | } | ||||
| std::string AvgPool::get_padding() const { | |||||
| auto value_ptr = GetAttr(kPadding); | |||||
| return GetValue<std::string>(value_ptr); | |||||
| PadMode AvgPool::get_pad_mode() const { | |||||
| auto value_ptr = GetAttr(kPadMode); | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||||
| } | } | ||||
| void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), | this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), | ||||
| @@ -44,12 +45,12 @@ std::vector<int64_t> AvgPool::get_kernel_size() const { | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| void AvgPool::set_strides(const std::vector<int64_t> &strides) { | void AvgPool::set_strides(const std::vector<int64_t> &strides) { | ||||
| this->AddAttr(kStride, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, strides, this->name(), false, true))); | |||||
| this->AddAttr(kStrides, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true))); | |||||
| } | } | ||||
| std::vector<int64_t> AvgPool::get_strides() const { | std::vector<int64_t> AvgPool::get_strides() const { | ||||
| auto value_ptr = GetAttr(kStride); | |||||
| auto value_ptr = GetAttr(kStrides); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| @@ -70,20 +71,19 @@ std::vector<int64_t> AvgPool::get_pad() const { | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| void AvgPool::set_round_mode(const int64_t &round_mode) { | |||||
| CheckAndConvertUtils::CheckInRange(kRoundMode, round_mode, kIncludeBoth, {0, 1}, this->name()); | |||||
| this->AddAttr(kRoundMode, MakeValue(round_mode)); | |||||
| void AvgPool::set_round_mode(const RoundMode &round_mode) { | |||||
| int64_t swi = round_mode; | |||||
| this->AddAttr(kRoundMode, MakeValue(swi)); | |||||
| } | } | ||||
| int64_t AvgPool::get_round_mode() const { | |||||
| RoundMode AvgPool::get_round_mode() const { | |||||
| auto value_ptr = GetAttr(kRoundMode); | auto value_ptr = GetAttr(kRoundMode); | ||||
| return GetValue<int64_t>(value_ptr); | |||||
| return RoundMode(GetValue<int64_t>(value_ptr)); | |||||
| } | } | ||||
| void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, | |||||
| const std::string &padding, const Format &format, const std::vector<int64_t> &pad, | |||||
| const int64_t &round_mode) { | |||||
| this->set_padding(padding); | |||||
| void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode, | |||||
| const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) { | |||||
| this->set_pad_mode(pad_mode); | |||||
| this->set_kernel_size(kernel_size); | this->set_kernel_size(kernel_size); | ||||
| this->set_strides(stride); | this->set_strides(stride); | ||||
| this->set_format(format); | this->set_format(format); | ||||
| @@ -98,9 +98,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| MS_EXCEPTION_IF_NULL(pool_prim); | MS_EXCEPTION_IF_NULL(pool_prim); | ||||
| auto op_name = pool_prim->name(); | auto op_name = pool_prim->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | ||||
| auto kernel_size = pool_prim->get_kernel_size(); | auto kernel_size = pool_prim->get_kernel_size(); | ||||
| auto pad_mode = pool_prim->get_padding(); | |||||
| auto pad_mode = pool_prim->get_pad_mode(); | |||||
| auto batch = in_shape[0]; | auto batch = in_shape[0]; | ||||
| auto channel = in_shape[1]; | auto channel = in_shape[1]; | ||||
| auto in_h = in_shape[2]; | auto in_h = in_shape[2]; | ||||
| @@ -113,14 +116,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| auto stride_w = strides[3]; | auto stride_w = strides[3]; | ||||
| int64_t out_h = -1; | int64_t out_h = -1; | ||||
| int64_t out_w = -1; | int64_t out_w = -1; | ||||
| if (pad_mode == "valid") { | |||||
| if (pad_mode == VALID) { | |||||
| out_h = ceil((in_h - (kernel_h - 1)) / stride_h); | out_h = ceil((in_h - (kernel_h - 1)) / stride_h); | ||||
| out_w = ceil((in_w - (kernel_w - 1)) / stride_w); | out_w = ceil((in_w - (kernel_w - 1)) / stride_w); | ||||
| } else if (pad_mode == "same") { | |||||
| } else if (pad_mode == SAME) { | |||||
| out_h = ceil(in_h / stride_h); | out_h = ceil(in_h / stride_h); | ||||
| out_w = ceil(in_w / stride_w); | out_w = ceil(in_w / stride_w); | ||||
| } | } | ||||
| std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | ||||
| if (pool_prim->get_format() == NHWC) { | |||||
| out_shape = {batch, out_h, out_w, channel}; | |||||
| } | |||||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | ||||
| MS_LOG(EXCEPTION) << "Kernel size is not valid."; | MS_LOG(EXCEPTION) << "Kernel size is not valid."; | ||||
| } | } | ||||
| @@ -142,4 +148,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); | ||||
| REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool); | REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,38 +21,41 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameAvgPool = "AvgPool"; | constexpr auto kNameAvgPool = "AvgPool"; | ||||
| class AvgPool : public PrimitiveC { | class AvgPool : public PrimitiveC { | ||||
| public: | public: | ||||
| AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); } | AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); } | ||||
| explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); } | |||||
| ~AvgPool() = default; | ~AvgPool() = default; | ||||
| MS_DECLARE_PARENT(AvgPool, PrimitiveC); | MS_DECLARE_PARENT(AvgPool, PrimitiveC); | ||||
| void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1}, | void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1}, | ||||
| const std::string &padding = "valid", const Format &format = NCHW, | |||||
| const std::vector<int64_t> &pad = {0, 0, 0, 0}, const int64_t &round_mode = 0); | |||||
| void set_padding(const std::string &padding); | |||||
| const PadMode &pad_mode = VALID, const Format &format = NCHW, | |||||
| const std::vector<int64_t> &pad = {0, 0, 0, 0}, const RoundMode &round_mode = FLOOR); | |||||
| void set_pad_mode(const PadMode &pad_mode); | |||||
| void set_kernel_size(const std::vector<int64_t> &kernel_size); | void set_kernel_size(const std::vector<int64_t> &kernel_size); | ||||
| void set_strides(const std::vector<int64_t> &strides); | void set_strides(const std::vector<int64_t> &strides); | ||||
| void set_format(const Format &format); | void set_format(const Format &format); | ||||
| void set_pad(const std::vector<int64_t> &pad); | void set_pad(const std::vector<int64_t> &pad); | ||||
| void set_round_mode(const int64_t &round_mode); | |||||
| void set_round_mode(const RoundMode &round_mode); | |||||
| std::vector<int64_t> get_kernel_size() const; | std::vector<int64_t> get_kernel_size() const; | ||||
| std::vector<int64_t> get_strides() const; | std::vector<int64_t> get_strides() const; | ||||
| std::string get_padding() const; | |||||
| PadMode get_pad_mode() const; | |||||
| Format get_format() const; | Format get_format() const; | ||||
| std::vector<int64_t> get_pad() const; | std::vector<int64_t> get_pad() const; | ||||
| int64_t get_round_mode() const; | |||||
| RoundMode get_round_mode() const; | |||||
| }; | }; | ||||
| AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimAvgPoolPtr = std::shared_ptr<AvgPool>; | using PrimAvgPoolPtr = std::shared_ptr<AvgPool>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | #endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | ||||
| @@ -0,0 +1,140 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "ops/batch_norm.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void BatchNorm::Init(const bool is_training, const float epsilon, const float momentum, const Format &format) { | |||||
| set_is_training(is_training); | |||||
| set_epsilon(epsilon); | |||||
| set_format(format); | |||||
| set_momentum(momentum); | |||||
| } | |||||
| void BatchNorm::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); } | |||||
| void BatchNorm::set_epsilon(const float epsilon) { | |||||
| CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name()); | |||||
| this->AddAttr(kEpsilon, MakeValue(epsilon)); | |||||
| } | |||||
| void BatchNorm::set_format(const Format &format) { | |||||
| int64_t f = format; | |||||
| this->AddAttr(kFormat, MakeValue(f)); | |||||
| } | |||||
| void BatchNorm::set_momentum(const float momentun) { | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name()); | |||||
| this->AddAttr(kMomentum, MakeValue(momentun)); | |||||
| } | |||||
| float BatchNorm::get_momentum() const { | |||||
| auto value_ptr = GetAttr(kMomentum); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| bool BatchNorm::get_is_training() const { | |||||
| auto value_ptr = GetAttr(kIsTraining); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| float BatchNorm::get_epsilon() const { | |||||
| auto value_ptr = GetAttr(kEpsilon); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| Format BatchNorm::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| // Infer shape | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto batch_prim = primitive->cast<PrimBatchNormPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(batch_prim); | |||||
| auto prim_name = batch_prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); | |||||
| auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | |||||
| if (batch_prim->get_format() == NHWC) { | |||||
| input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; | |||||
| } | |||||
| auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); | |||||
| auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name); | |||||
| auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name); | |||||
| auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); | |||||
| std::vector<int64_t> input_shape_norm; | |||||
| if (batch_prim->get_format() == NCHW) { | |||||
| input_shape_norm = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| } else { | |||||
| input_shape_norm.push_back(input_x[0]); | |||||
| input_shape_norm.push_back(input_x[3]); | |||||
| input_shape_norm.push_back(input_x[1]); | |||||
| input_shape_norm.push_back(input_x[2]); | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("scale rank", scale.size(), kEqual, 1, prim_name); | |||||
| CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError); | |||||
| CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name, | |||||
| TypeError); | |||||
| if (!batch_prim->get_is_training()) { | |||||
| CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name); | |||||
| CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError); | |||||
| CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError); | |||||
| } | |||||
| // Infer type | |||||
| auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); | |||||
| std::map<std::string, TypePtr> args; | |||||
| args.emplace("scale", input_args[1]->BuildType()); | |||||
| args.emplace("bias", input_args[2]->BuildType()); | |||||
| CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); | |||||
| std::map<std::string, TypePtr> args_moving; | |||||
| args_moving.emplace("scale", input_args[2]->BuildType()); | |||||
| args_moving.emplace("bias", input_args[3]->BuildType()); | |||||
| CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name); | |||||
| auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x); | |||||
| auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | |||||
| auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | |||||
| auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | |||||
| if (batch_prim->get_format() == NHWC) { | |||||
| output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | |||||
| output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | |||||
| output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | |||||
| } | |||||
| AbstractBasePtrList output = {output0, output1, output2, output3, output3}; | |||||
| return std::make_shared<abstract::AbstractTuple>(output); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -20,11 +20,12 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include "c_ops/op_utils.h" | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameBatchNorm = "BatchNorm"; | constexpr auto kNameBatchNorm = "BatchNorm"; | ||||
| class BatchNorm : public PrimitiveC { | class BatchNorm : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -34,19 +35,23 @@ class BatchNorm : public PrimitiveC { | |||||
| } | } | ||||
| ~BatchNorm() = default; | ~BatchNorm() = default; | ||||
| MS_DECLARE_PARENT(BatchNorm, PrimitiveC); | MS_DECLARE_PARENT(BatchNorm, PrimitiveC); | ||||
| void Init(bool is_training = false, float epsilon = 1e-5, const Format &format = NCHW); | |||||
| void set_is_training(bool is_training); | |||||
| void set_epsilon(float epsilon); | |||||
| void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentun = 0.1, | |||||
| const Format &format = NCHW); | |||||
| void set_is_training(const bool is_training); | |||||
| void set_epsilon(const float epsilon); | |||||
| void set_format(const Format &format); | void set_format(const Format &format); | ||||
| bool get_is_trainging(); | |||||
| float get_epsilon(); | |||||
| void set_momentum(const float momentum); | |||||
| bool get_is_training() const; | |||||
| float get_epsilon() const; | |||||
| Format get_format() const; | Format get_format() const; | ||||
| float get_momentum() const; | |||||
| }; | }; | ||||
| AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimBatchNormPtr = std::shared_ptr<BatchNorm>; | using PrimBatchNormPtr = std::shared_ptr<BatchNorm>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_BatchNorm_H_ | #endif // MINDSPORE_CORE_C_OPS_BatchNorm_H_ | ||||
| @@ -0,0 +1,116 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "ops/batch_norm_fold.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void BatchNormFold::Init(const float momentum, const float epsilon, const bool is_training, const int64_t freeze_bn) { | |||||
| set_momentum(momentum); | |||||
| set_epsilon(epsilon); | |||||
| set_is_training(is_training); | |||||
| set_freeze_bn(freeze_bn); | |||||
| } | |||||
| void BatchNormFold::set_momentum(const float momentum) { | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentum, kIncludeBoth, {0.0, 1.0}, this->name()); | |||||
| this->AddAttr(kMomentum, MakeValue(momentum)); | |||||
| } | |||||
| float BatchNormFold::get_momentum() const { | |||||
| auto value_ptr = GetAttr(kMomentum); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| void BatchNormFold::set_epsilon(const float epsilon) { | |||||
| float match_value = 0.0; | |||||
| CheckAndConvertUtils::CheckValue(kEpsilon, epsilon, kGreaterThan, match_value, this->name()); | |||||
| this->AddAttr(kEpsilon, MakeValue(epsilon)); | |||||
| } | |||||
| float BatchNormFold::get_epsilon() const { | |||||
| auto value_ptr = GetAttr(kEpsilon); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| void BatchNormFold::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); } | |||||
| bool BatchNormFold::get_is_training() const { | |||||
| auto value_ptr = GetAttr(kIsTraining); | |||||
| return GetValue<bool>(value_ptr); | |||||
| } | |||||
| void BatchNormFold::set_freeze_bn(const int64_t freeze_bn) { this->AddAttr(kFreezeBn, MakeValue(freeze_bn)); } | |||||
| int64_t BatchNormFold::get_freeze_bn() const { | |||||
| auto value_ptr = GetAttr(kFreezeBn); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto BatchNormFold_prim = primitive->cast<PrimBatchNormFoldPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(BatchNormFold_prim); | |||||
| auto op_name = BatchNormFold_prim->name(); | |||||
| auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); | |||||
| auto variance_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||||
| auto global_step_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name); | |||||
| CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name); | |||||
| CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name); | |||||
| CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name); | |||||
| auto mean_type = input_args[1]->BuildType(); | |||||
| auto variance_type = input_args[2]->BuildType(); | |||||
| auto x_type = input_args[0]->BuildType(); | |||||
| auto global_step_type = input_args[3]->BuildType(); | |||||
| std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}}; | |||||
| CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name); | |||||
| auto tensor_type0 = x_type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type0); | |||||
| auto element0 = tensor_type0->element(); | |||||
| auto tensor_type1 = mean_type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type1); | |||||
| auto element1 = tensor_type1->element(); | |||||
| auto tensor_type2 = variance_type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type2); | |||||
| auto element2 = tensor_type2->element(); | |||||
| CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name); | |||||
| CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name); | |||||
| auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape); | |||||
| AbstractBasePtrList output1 = {output, output, output, output}; | |||||
| return std::make_shared<abstract::AbstractTuple>(output1); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchNormFold, prim::kPrimBatchNormFold, BatchNormFoldInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BATCH_NORM_FOLD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BATCH_NORM_FOLD_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameBatchNormFold = "BatchNormFold"; | |||||
| class BatchNormFold : public PrimitiveC { | |||||
| public: | |||||
| BatchNormFold() : PrimitiveC(kNameBatchNormFold) { | |||||
| InitIOName({"x", "mean", "variance", "global_step"}, {"batch_mean", "batch_std", "running_mean", "running_std"}); | |||||
| } | |||||
| ~BatchNormFold() = default; | |||||
| MS_DECLARE_PARENT(BatchNormFold, PrimitiveC); | |||||
| void Init(const float momentum = 0.9, const float epsilon = 1e-5, const bool is_training = true, | |||||
| const int64_t freeze_bn = 0); | |||||
| void set_momentum(const float momentum); | |||||
| void set_epsilon(const float epsilon); | |||||
| void set_is_training(const bool is_training); | |||||
| void set_freeze_bn(const int64_t freeze_bn); | |||||
| float get_momentum() const; | |||||
| float get_epsilon() const; | |||||
| bool get_is_training() const; | |||||
| int64_t get_freeze_bn() const; | |||||
| }; | |||||
| AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBatchNormFoldPtr = std::shared_ptr<BatchNormFold>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_BATCH_NORM_FOLD_H_ | |||||
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/batch_to_space.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) { | |||||
| this->set_block_size(block_size); | |||||
| this->set_crops(crops); | |||||
| } | |||||
| void BatchToSpace::set_block_size(const std::vector<int64_t> &block_size) { | |||||
| this->AddAttr(kBlockSize, MakeValue(block_size)); | |||||
| } | |||||
| std::vector<int64_t> BatchToSpace::get_block_size() const { | |||||
| auto value_ptr = this->GetAttr(kBlockSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| void BatchToSpace::set_crops(const std::vector<std::vector<int64_t>> &crops) { | |||||
| this->AddAttr(kCrops, MakeValue(crops)); | |||||
| } | |||||
| std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const { | |||||
| auto value_ptr = this->GetAttr(kCrops); | |||||
| return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim = primitive->cast<PrimBatchToSpacePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||||
| auto block_size = prim->get_block_size(); | |||||
| auto crops = prim->get_crops(); | |||||
| auto out_shape = x_shape; | |||||
| for (size_t i = 0; i < 2; ++i) { | |||||
| auto x_block_prod = out_shape[i + 2] * block_size[i]; | |||||
| auto crops_sum = crops[i][0] + crops[i][1]; | |||||
| CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name); | |||||
| out_shape[i + 2] = x_block_prod - crops_sum; | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])", | |||||
| out_shape[0] % (block_size[0] * block_size[1]), kEqual, 0, prim_name); | |||||
| out_shape[0] /= block_size[0] * block_size[1]; | |||||
| auto ret = input_args[0]->Broaden(); | |||||
| ret->set_shape(std::make_shared<abstract::Shape>(out_shape)); | |||||
| return ret; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BATCH_TO_SPACE_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BATCH_TO_SPACE_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameBatchToSpace = "BatchToSpace"; | |||||
| class BatchToSpace : public PrimitiveC { | |||||
| public: | |||||
| BatchToSpace() : PrimitiveC(kNameBatchToSpace) {} | |||||
| ~BatchToSpace() = default; | |||||
| MS_DECLARE_PARENT(BatchToSpace, PrimitiveC); | |||||
| void Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops); | |||||
| void set_block_size(const std::vector<int64_t> &block_size); | |||||
| void set_crops(const std::vector<std::vector<int64_t>> &crops); | |||||
| std::vector<int64_t> get_block_size() const; | |||||
| std::vector<std::vector<int64_t>> get_crops() const; | |||||
| }; | |||||
| AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBatchToSpacePtr = std::shared_ptr<BatchToSpace>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_BATCH_TO_SPACE_H_ | |||||
| @@ -0,0 +1,112 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/batch_to_space_nd.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto batch_prim = primitive->cast<PrimBatchToSpaceNDPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(batch_prim); | |||||
| auto prim_name = batch_prim->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | |||||
| auto out_shape = x_shape; | |||||
| int64_t block_shape_prod = 1; | |||||
| int64_t offset = 2; | |||||
| if (x_shape.size() <= 4) { | |||||
| offset = 1; | |||||
| } | |||||
| auto block_shape = batch_prim->get_block_shape(); | |||||
| auto crops = batch_prim->get_crops(); | |||||
| int64_t size = block_shape.size(); | |||||
| for (int64_t i = 0; i < size; i++) { | |||||
| block_shape_prod = block_shape_prod * block_shape[i]; | |||||
| auto x_block_prod = out_shape[i + offset] * block_shape[i]; | |||||
| auto crops_sum = crops[i][0] + crops[i][1]; | |||||
| CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", crops_sum, prim_name); | |||||
| out_shape[i + offset] = x_block_prod - crops_sum; | |||||
| } | |||||
| if (out_shape[0] % block_shape_prod != 0) { | |||||
| MS_EXCEPTION(ValueError) << prim_name << " input_x dimension 0 " << out_shape[0] | |||||
| << " should be divisible by block_shape_prod " << block_shape_prod; | |||||
| } | |||||
| out_shape[0] = int64_t(floor(out_shape[0] / block_shape_prod)); | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto infer_type = input_args[0]->BuildType(); | |||||
| return infer_type; | |||||
| } | |||||
| } // namespace | |||||
| void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) { | |||||
| CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name()); | |||||
| int64_t h = crops.size(); | |||||
| int64_t w = crops[0].size(); | |||||
| std::vector<int64_t> temp_w = {2, 2}; | |||||
| CheckAndConvertUtils::Check(kCrops, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name()); | |||||
| for (int64_t i = 0; i < h; i++) { | |||||
| for (int64_t j = 0; j < w; j++) { | |||||
| CheckAndConvertUtils::CheckInteger(kCrops, crops[i][j], kGreaterEqual, 0, this->name()); | |||||
| } | |||||
| } | |||||
| this->AddAttr(kCrops, MakeValue(crops)); | |||||
| } | |||||
| std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const { | |||||
| auto value_ptr = GetAttr(kCrops); | |||||
| return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); | |||||
| } | |||||
| void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) { | |||||
| CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name()); | |||||
| for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) { | |||||
| CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name()); | |||||
| } | |||||
| this->AddAttr(kBlockShape, MakeValue(block_shape)); | |||||
| } | |||||
| std::vector<int64_t> BatchToSpaceND::get_block_shape() const { | |||||
| auto value_ptr = GetAttr(kBlockShape); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops) { | |||||
| this->set_crops(crops); | |||||
| this->set_block_shape(block_shape); | |||||
| } | |||||
| AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpaceND, prim::kPrimBatchToSpaceND, BatchToSpaceNDInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BATCH_TO_SPACE_ND_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BATCH_TO_SPACE_ND_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameBatchToSpaceND = "BatchToSpaceND"; | |||||
| class BatchToSpaceND : public PrimitiveC { | |||||
| public: | |||||
| BatchToSpaceND() : PrimitiveC(kNameBatchToSpaceND) {} | |||||
| ~BatchToSpaceND() = default; | |||||
| MS_DECLARE_PARENT(BatchToSpaceND, PrimitiveC); | |||||
| void Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops); | |||||
| void set_crops(std::vector<std::vector<int64_t>> crops); | |||||
| void set_block_shape(std::vector<int64_t> block_shape); | |||||
| std::vector<int64_t> get_block_shape() const; | |||||
| std::vector<std::vector<int64_t>> get_crops() const; | |||||
| }; | |||||
| AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBatchToSpaceNDPtr = std::shared_ptr<BatchToSpaceND>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_BATCH_TO_SPACE_ND_H_ | |||||
| @@ -0,0 +1,86 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/bias_add.h" | |||||
| #include <memory> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| // Add | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // Add | |||||
| namespace { | |||||
| abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto biasadd_prim = primitive->cast<PrimBiasAddPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(biasadd_prim); | |||||
| auto prim_name = biasadd_prim->name(); | |||||
| // check | |||||
| CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); | |||||
| auto format = biasadd_prim->get_format(); | |||||
| auto x_channel = x_shape[1]; | |||||
| if (format != NCHW) { | |||||
| x_channel = x_shape[x_shape.size() - 1]; | |||||
| } | |||||
| CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, biasadd_prim->name()); | |||||
| std::vector<int64_t> out_shape = x_shape; | |||||
| return std::make_shared<abstract::Shape>(out_shape); | |||||
| } | |||||
| TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("input_x", input_args[0]->BuildType()); | |||||
| types.emplace("bias", input_args[1]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| BiasAdd::BiasAdd() : PrimitiveC(kNameBiasAdd) { InitIOName({"x", "b"}, {"output"}); } | |||||
| void BiasAdd::set_format(const Format &format) { | |||||
| int64_t f = format; | |||||
| this->AddAttr(kFormat, MakeValue(f)); | |||||
| } | |||||
| Format BiasAdd::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| void BiasAdd::Init(const Format &format) { this->set_format(format); } | |||||
| // Add | |||||
| AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(BiasAddInferType(primitive, input_args), | |||||
| BiasAddInferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| // Add | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,27 +14,38 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_BIASADD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BIASADD_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BIAS_ADD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BIAS_ADD_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| // Add | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameBiasAdd = "BiasAdd"; | constexpr auto kNameBiasAdd = "BiasAdd"; | ||||
| class BiasAdd : public PrimitiveC { | class BiasAdd : public PrimitiveC { | ||||
| public: | public: | ||||
| BiasAdd() : PrimitiveC(kNameBiasAdd) { InitIOName({"x", "b"}, {"output"}); } | |||||
| // BiasAdd() : PrimitiveC(kNameBiasAdd) { InitIOName({"x", "b"}, {"output"}); } | |||||
| BiasAdd(); | |||||
| ~BiasAdd() = default; | ~BiasAdd() = default; | ||||
| MS_DECLARE_PARENT(BiasAdd, PrimitiveC); | MS_DECLARE_PARENT(BiasAdd, PrimitiveC); | ||||
| void Init(const Format &format = NCHW); | void Init(const Format &format = NCHW); | ||||
| void set_format(const Format &format); | void set_format(const Format &format); | ||||
| Format get_format() const; | Format get_format() const; | ||||
| }; | }; | ||||
| // Add | |||||
| AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBiasAddPtr = std::shared_ptr<BiasAdd>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_BIASADD_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_BIAS_ADD_H_ | |||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include "ops/binary_cross_entropy.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto binary_cross_entropy_prim = primitive->cast<PrimBinaryCrossEntropyPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim); | |||||
| auto prim_name = binary_cross_entropy_prim->name(); | |||||
| CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto weight_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); | |||||
| std::vector<int64_t> infer_shape; | |||||
| if (weight_shape.size() < 1) { | |||||
| CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); | |||||
| } | |||||
| if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM && | |||||
| binary_cross_entropy_prim->get_reduction() != MEAN) { | |||||
| infer_shape = {x_shape.begin(), infer_shape.end()}; | |||||
| } | |||||
| return std::make_shared<abstract::Shape>(infer_shape); | |||||
| } | |||||
| TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", input_args.size(), kEqual, 3, prim->name()); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("x_shape", input_args[0]->BuildType()); | |||||
| types.emplace("y_shape", input_args[1]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||||
| if (input_args[3]->BuildType() != nullptr) { | |||||
| types.emplace("x_shape", input_args[0]->BuildType()); | |||||
| types.emplace("weight_shape", input_args[2]->BuildType()); | |||||
| infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||||
| } | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| void BinaryCrossEntropy::set_reduction(const Reduction &reduction) { | |||||
| int64_t swi = reduction; | |||||
| this->AddAttr(kReduction, MakeValue(swi)); | |||||
| } | |||||
| Reduction BinaryCrossEntropy::get_reduction() const { | |||||
| auto value_ptr = GetAttr(kReduction); | |||||
| return Reduction(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| void BinaryCrossEntropy::Init(const Reduction &reduction) { this->set_reduction(reduction); } | |||||
| AbstractBasePtr BinaryCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyInferType(primitive, input_args), | |||||
| BinaryCrossEntroyInferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BinaryCrossEntropy, prim::kPrimBinaryCrossEntropy, BinaryCrossEntropyInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropy, BinaryCrossEntropy); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,25 +14,32 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameBinaryCrossEntropy = "BinaryCrossEntropy"; | constexpr auto kNameBinaryCrossEntropy = "BinaryCrossEntropy"; | ||||
| class BinaryCrossEntropy : public PrimitiveC { | class BinaryCrossEntropy : public PrimitiveC { | ||||
| public: | public: | ||||
| BinaryCrossEntropy() : PrimitiveC(kNameBinaryCrossEntropy) {} | BinaryCrossEntropy() : PrimitiveC(kNameBinaryCrossEntropy) {} | ||||
| ~BinaryCrossEntropy() = default; | ~BinaryCrossEntropy() = default; | ||||
| MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC); | MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC); | ||||
| void Init(const std::string &reduction = "mean"); | |||||
| void set_reduction(const std::string &reduction); | |||||
| std::string get_reduction() const; | |||||
| void Init(const Reduction &reduction = MEAN); | |||||
| void set_reduction(const Reduction &reduction); | |||||
| Reduction get_reduction() const; | |||||
| }; | }; | ||||
| AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBinaryCrossEntropyPtr = std::shared_ptr<BinaryCrossEntropy>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_H_ | |||||
| @@ -14,13 +14,14 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/black_box.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/black_box.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| void BlackBox::Init(const std::string &id, int64_t size, const std::vector<int64_t> &address) { | |||||
| namespace ops { | |||||
| void BlackBox::Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address) { | |||||
| this->set_id(id); | this->set_id(id); | ||||
| this->set_size(size); | this->set_size(size); | ||||
| this->set_address(address); | this->set_address(address); | ||||
| @@ -33,7 +34,7 @@ std::string BlackBox::get_id() const { | |||||
| return GetValue<std::string>(value_ptr); | return GetValue<std::string>(value_ptr); | ||||
| } | } | ||||
| void BlackBox::set_size(int64_t size) { this->AddAttr(kSize, MakeValue(size)); } | |||||
| void BlackBox::set_size(const int64_t size) { this->AddAttr(kSize, MakeValue(size)); } | |||||
| int64_t BlackBox::get_size() const { | int64_t BlackBox::get_size() const { | ||||
| auto value_ptr = this->GetAttr(kSize); | auto value_ptr = this->GetAttr(kSize); | ||||
| @@ -47,4 +48,5 @@ std::vector<int64_t> BlackBox::get_address() const { | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameBlackBox, BlackBox); | REGISTER_PRIMITIVE_C(kNameBlackBox, BlackBox); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,25 +14,26 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_BLACKBOX_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BLACKBOX_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BLACK_BOX_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BLACK_BOX_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameBlackBox = "BlackBox"; | constexpr auto kNameBlackBox = "BlackBox"; | ||||
| class BlackBox : public PrimitiveC { | class BlackBox : public PrimitiveC { | ||||
| public: | public: | ||||
| BlackBox() : PrimitiveC(kNameBlackBox) {} | BlackBox() : PrimitiveC(kNameBlackBox) {} | ||||
| ~BlackBox() = default; | ~BlackBox() = default; | ||||
| MS_DECLARE_PARENT(BlackBox, PrimitiveC); | MS_DECLARE_PARENT(BlackBox, PrimitiveC); | ||||
| void Init(const std::string &id, int64_t size, const std::vector<int64_t> &address); | |||||
| void Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address); | |||||
| void set_id(const std::string &id); | void set_id(const std::string &id); | ||||
| void set_size(int64_t size); | |||||
| void set_size(const int64_t size); | |||||
| void set_address(const std::vector<int64_t> &address); | void set_address(const std::vector<int64_t> &address); | ||||
| std::string get_id() const; | std::string get_id() const; | ||||
| int64_t get_size() const; | int64_t get_size() const; | ||||
| @@ -40,6 +41,7 @@ class BlackBox : public PrimitiveC { | |||||
| }; | }; | ||||
| using PrimBlackBoxPtr = std::shared_ptr<BlackBox>; | using PrimBlackBoxPtr = std::shared_ptr<BlackBox>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_BLACKBOX_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_BLACK_BOX_H_ | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/broadcast.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void Broadcast::Init(const int64_t root_rank, const std::string &group) { | |||||
| this->set_root_rank(root_rank); | |||||
| this->set_group(group); | |||||
| } | |||||
| void Broadcast::set_root_rank(const int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); } | |||||
| void Broadcast::set_group(const std::string &group) { | |||||
| CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name()); | |||||
| this->AddAttr(kGroup, MakeValue(group)); | |||||
| } | |||||
| int64_t Broadcast::get_root_rank() const { | |||||
| auto value_ptr = this->GetAttr(kRootRank); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| std::string Broadcast::get_group() const { | |||||
| auto value_ptr = this->GetAttr(kGroup); | |||||
| return GetValue<std::string>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto broadcast_prim = primitive->cast<PrimBroadcast>(); | |||||
| MS_EXCEPTION_IF_NULL(broadcast_prim); | |||||
| auto prim_name = broadcast_prim->name(); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| // infer shape | |||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| // infer type | |||||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| std::vector<TypePtr> output_types; | |||||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; | |||||
| for (size_t i = 0; i < input_args.size(); i++) { | |||||
| auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| output_types.push_back(out_type); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("index_type", out_type, valid_types, prim_name); | |||||
| } | |||||
| return std::make_shared<abstract::AbstractTensor>(x_type, in_shape); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Broadcast, prim::kPrimBroadcast, BroadcastInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -17,24 +17,31 @@ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_ | #ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_ | ||||
| #define MINDSPORE_CORE_C_OPS_BROADCAST_H_ | #define MINDSPORE_CORE_C_OPS_BROADCAST_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameBroadcast = "Broadcast"; | constexpr auto kNameBroadcast = "Broadcast"; | ||||
| class Broadcast : public PrimitiveC { | class Broadcast : public PrimitiveC { | ||||
| public: | public: | ||||
| Broadcast() : PrimitiveC(kNameBroadcast) {} | Broadcast() : PrimitiveC(kNameBroadcast) {} | ||||
| ~Broadcast() = default; | ~Broadcast() = default; | ||||
| MS_DECLARE_PARENT(Broadcast, PrimitiveC); | MS_DECLARE_PARENT(Broadcast, PrimitiveC); | ||||
| void Init(int64_t root_rank, const std::string &group = "hccl_world_group"); | |||||
| void set_root_rank(int64_t root_rank); | |||||
| void Init(const int64_t root_rank, const std::string &group = "hccl_world_group"); | |||||
| void set_root_rank(const int64_t root_rank); | |||||
| void set_group(const std::string &group); | void set_group(const std::string &group); | ||||
| int64_t get_root_rank(); | |||||
| int64_t get_root_rank() const; | |||||
| std::string get_group() const; | std::string get_group() const; | ||||
| }; | }; | ||||
| AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBroadcast = std::shared_ptr<Broadcast>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_ | #endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_ | ||||
| @@ -0,0 +1,88 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include "ops/broadcast_to.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(broad_cast_to); | |||||
| auto prim_name = broad_cast_to->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto input_x = broad_cast_to->get_shape(); | |||||
| int64_t outer_dim_offset = input_x.size() - x_shape.size(); | |||||
| CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name); | |||||
| bool flag = true; | |||||
| if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) { | |||||
| flag = false; | |||||
| } else { | |||||
| flag = true; | |||||
| } | |||||
| if (flag == true) { | |||||
| for (int64_t i = 0; i < (int64_t)input_x.size(); i++) { | |||||
| if (input_x[i] == -1) { | |||||
| if (i < outer_dim_offset) { | |||||
| MS_EXCEPTION(ValueError) << " -1 in init shape is in an incompatible " | |||||
| "location with given input tensor, -1 index in init shape: " | |||||
| << i << " but -1 can only be in index" << x_shape.size() | |||||
| << "onwards for this input."; | |||||
| } | |||||
| input_x[i] = x_shape[i - outer_dim_offset]; | |||||
| } | |||||
| } | |||||
| } | |||||
| std::reverse(input_x.begin(), input_x.end()); | |||||
| return std::make_shared<abstract::Shape>(input_x); | |||||
| } | |||||
| TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||||
| std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)}; | |||||
| CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name()); | |||||
| auto infer_dtype = input_args[0]->BuildType()->type_id(); | |||||
| return TypeIdToType(infer_dtype); | |||||
| } | |||||
| } // namespace | |||||
| void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); } | |||||
| void BroadcastTo::set_shape(const std::vector<int64_t> &shape) { | |||||
| CheckAndConvertUtils::CheckInteger(kShapeSize, shape.size(), kGreaterThan, 0, name()); | |||||
| AddAttr(kShape, MakeValue(shape)); | |||||
| } | |||||
| std::vector<int64_t> BroadcastTo::get_shape() const { | |||||
| auto value_ptr = GetAttr(kShape); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args), | |||||
| BroadcastToInferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -14,18 +14,19 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_BROADCAST_TO_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_BROADCAST_TO_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/op_utils.h" | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameBroadcastTo = "BroadcastTo"; | constexpr auto kNameBroadcastTo = "BroadcastTo"; | ||||
| class BroadcastTo : public PrimitiveC { | class BroadcastTo : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -41,6 +42,7 @@ AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimBroadcastToPtr = std::shared_ptr<BroadcastTo>; | using PrimBroadcastToPtr = std::shared_ptr<BroadcastTo>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_BROADCAST_TO_H_ | |||||
| @@ -14,8 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/cast.h" | |||||
| #include "ops/cast.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| REGISTER_PRIMITIVE_C(kNameCast, Cast); | REGISTER_PRIMITIVE_C(kNameCast, Cast); | ||||
| } | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -19,12 +19,13 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameCast = "Cast"; | constexpr auto kNameCast = "Cast"; | ||||
| class Cast : public PrimitiveC { | class Cast : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -35,5 +36,6 @@ class Cast : public PrimitiveC { | |||||
| AbstractBasePtr CastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr CastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimCast = std::shared_ptr<Cast>; | using PrimCast = std::shared_ptr<Cast>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CAST_H_ | #endif // MINDSPORE_CORE_C_OPS_CAST_H_ | ||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ops/ceil.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil"); | |||||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||||
| auto infer_type = input_args[0]->BuildType(); | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name()); | |||||
| MS_EXCEPTION_IF_NULL(infer_type); | |||||
| auto tensor_type = infer_type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto data_type = tensor_type->element(); | |||||
| MS_EXCEPTION_IF_NULL(data_type); | |||||
| return std::make_shared<abstract::AbstractTensor>(data_type, x_shape); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Ceil, prim::kPrimCeil, CeilInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameCeil, Ceil); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -19,21 +19,24 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameCeil = "Ceil"; | constexpr auto kNameCeil = "Ceil"; | ||||
| class Ceil : public PrimitiveC { | class Ceil : public PrimitiveC { | ||||
| public: | public: | ||||
| Ceil() : PrimitiveC(kNameCeil) { InitIOName({"x"}, {"y"}); } | Ceil() : PrimitiveC(kNameCeil) { InitIOName({"x"}, {"y"}); } | ||||
| ~Ceil() = default; | ~Ceil() = default; | ||||
| MS_DECLARE_PARENT(Ceil, PrimitiveC); | MS_DECLARE_PARENT(Ceil, PrimitiveC); | ||||
| void init() {} | |||||
| }; | }; | ||||
| AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimCeil = std::shared_ptr<Ceil>; | |||||
| using PrimCeilPtr = std::shared_ptr<Ceil>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CEIL_H_ | #endif // MINDSPORE_CORE_C_OPS_CEIL_H_ | ||||
| @@ -14,12 +14,13 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/clip.h" | |||||
| #include "ops/clip.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| void Clip::Init(const float max, const float min) { | void Clip::Init(const float max, const float min) { | ||||
| this->set_max(max); | this->set_max(max); | ||||
| this->set_min(min); | this->set_min(min); | ||||
| @@ -39,4 +40,5 @@ float Clip::get_min() const { | |||||
| return GetValue<float>(value_ptr); | return GetValue<float>(value_ptr); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameClip, Clip); | REGISTER_PRIMITIVE_C(kNameClip, Clip); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,11 +17,12 @@ | |||||
| #define MINDSPORE_CORE_C_OPS_CLIP_H_ | #define MINDSPORE_CORE_C_OPS_CLIP_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameClip = "Clip"; | constexpr auto kNameClip = "Clip"; | ||||
| class Clip : public PrimitiveC { | class Clip : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -36,6 +37,7 @@ class Clip : public PrimitiveC { | |||||
| }; | }; | ||||
| using PrimClipPtr = std::shared_ptr<Clip>; | using PrimClipPtr = std::shared_ptr<Clip>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CLIP_H_ | #endif // MINDSPORE_CORE_C_OPS_CLIP_H_ | ||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "ops/concat.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| void Concat::Init(const int64_t axis) { this->set_axis(axis); } | |||||
| int64_t Concat::get_axis() const { | |||||
| auto value_ptr = this->GetAttr(kAxis); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | |||||
| AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim = primitive->cast<PrimConcatPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_tuple); | |||||
| auto elements = input_tuple->elements(); | |||||
| CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | |||||
| auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(element0); | |||||
| auto element0_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | |||||
| auto element0_rank = SizeToLong(element0_shape.size()); | |||||
| auto axis = prim->get_axis(); | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, | |||||
| prim_name); | |||||
| axis = axis < 0 ? axis + element0_rank : axis; | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("element0", element0->BuildType()); | |||||
| int64_t all_shp = element0_shape[axis]; | |||||
| for (size_t i = 1; i < elements.size(); ++i) { | |||||
| std::string elementi = "element" + std::to_string(i); | |||||
| auto elementi_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); | |||||
| CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | |||||
| prim_name); | |||||
| for (int64_t j = 0; j < element0_rank; ++j) { | |||||
| if (j != axis && elementi_shape[j] != element0_shape[j]) { | |||||
| MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element."; | |||||
| } | |||||
| } | |||||
| all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis]; | |||||
| types.emplace(elementi, elements[i]->BuildType()); | |||||
| } | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name); | |||||
| auto ret_shape = element0_shape; | |||||
| ret_shape[axis] = all_shp; | |||||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type), | |||||
| std::make_shared<abstract::Shape>(ret_shape)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameConcat, Concat); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -19,24 +19,26 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameConcat = "Concat"; | constexpr auto kNameConcat = "Concat"; | ||||
| class Concat : public PrimitiveC { | class Concat : public PrimitiveC { | ||||
| public: | public: | ||||
| Concat() : PrimitiveC(kNameConcat) {} | Concat() : PrimitiveC(kNameConcat) {} | ||||
| ~Concat() = default; | ~Concat() = default; | ||||
| MS_DECLARE_PARENT(Concat, PrimitiveC); | MS_DECLARE_PARENT(Concat, PrimitiveC); | ||||
| void Init(int64_t axis = 0); | |||||
| void set_axis(int64_t axis); | |||||
| void Init(const int64_t axis = 0); | |||||
| void set_axis(const int64_t axis); | |||||
| int64_t get_axis() const; | int64_t get_axis() const; | ||||
| }; | }; | ||||
| AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimConcatPtr = std::shared_ptr<Concat>; | using PrimConcatPtr = std::shared_ptr<Concat>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CONCAT_H_ | #endif // MINDSPORE_CORE_C_OPS_CONCAT_H_ | ||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/constant.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto x = input_args[0]->BuildShape(); | |||||
| auto shape_element = x->cast<abstract::ShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape_element); | |||||
| return shape_element; | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name()); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| std::map<std::string, TypePtr> types; | |||||
| types.emplace("x", input_args[0]->BuildType()); | |||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||||
| return TypeIdToType(infer_type); | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Constant, prim::kPrimConstant, ConstantInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameConstant, Constant); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_CONSTANT_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_CONSTANT_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameConstant = "Constant"; | |||||
| class Constant : public PrimitiveC { | |||||
| public: | |||||
| Constant() : PrimitiveC(kNameConstant) {} | |||||
| ~Constant() = default; | |||||
| MS_DECLARE_PARENT(Constant, PrimitiveC); | |||||
| void Init() {} | |||||
| }; | |||||
| AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimConstantPtr = std::shared_ptr<Constant>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_C_OPS_CONSTANT_H_ | |||||
| @@ -14,12 +14,30 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/constant_of_shape.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/constant_of_shape.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape"); | |||||
| auto input_shape = | |||||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape"); | |||||
| return std::make_shared<abstract::Shape>(input_shape); | |||||
| } | |||||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto constant_prim = primitive->cast<PrimConstantOfShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(constant_prim); | |||||
| auto data_type = TypeId(constant_prim->get_data_type()); | |||||
| return TypeIdToType(data_type); | |||||
| } | |||||
| } // namespace | |||||
| void ConstantOfShape::Init(int64_t data_type, const std::vector<float> &value) { | void ConstantOfShape::Init(int64_t data_type, const std::vector<float> &value) { | ||||
| this->set_data_type(data_type); | this->set_data_type(data_type); | ||||
| this->set_value(value); | this->set_value(value); | ||||
| @@ -38,5 +56,12 @@ std::vector<float> ConstantOfShape::get_value() const { | |||||
| auto value_ptr = this->GetAttr(kValue); | auto value_ptr = this->GetAttr(kValue); | ||||
| return GetValue<std::vector<float>>(value_ptr); | return GetValue<std::vector<float>>(value_ptr); | ||||
| } | } | ||||
| AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||||
| InferShape(primitive, input_args)->shape()); | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ConstantOfShape, prim::kPrimConstantOfShape, ConstantOfShapeInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameConstantOfShape, ConstantOfShape); | REGISTER_PRIMITIVE_C(kNameConstantOfShape, ConstantOfShape); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,15 +14,16 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_ | |||||
| #ifndef MINDSPORE_CORE_C_OPS_CONSTANT_OF_SHAPE_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_CONSTANT_OF_SHAPE_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameConstantOfShape = "ConstantOfShape"; | constexpr auto kNameConstantOfShape = "ConstantOfShape"; | ||||
| class ConstantOfShape : public PrimitiveC { | class ConstantOfShape : public PrimitiveC { | ||||
| public: | public: | ||||
| @@ -36,7 +37,10 @@ class ConstantOfShape : public PrimitiveC { | |||||
| std::vector<float> get_value() const; | std::vector<float> get_value() const; | ||||
| }; | }; | ||||
| AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimConstantOfShapePtr = std::shared_ptr<ConstantOfShape>; | using PrimConstantOfShapePtr = std::shared_ptr<ConstantOfShape>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_ | |||||
| #endif // MINDSPORE_CORE_C_OPS_CONSTANT_OF_SHAPE_H_ | |||||
| @@ -14,19 +14,21 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/control_depend.h" | |||||
| #include "ops/control_depend.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| void ControlDepend::Init(int64_t depend_mode) { this->set_depend_mode(depend_mode); } | |||||
| namespace ops { | |||||
| void ControlDepend::Init(const int64_t depend_mode) { this->set_depend_mode(depend_mode); } | |||||
| void ControlDepend::set_depend_mode(int64_t depend_mode) { | |||||
| CheckAndConvertUtils::CheckInRange(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name()); | |||||
| void ControlDepend::set_depend_mode(const int64_t depend_mode) { | |||||
| CheckAndConvertUtils::CheckInRange<int64_t>(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name()); | |||||
| AddAttr(kDependMode, MakeValue(depend_mode)); | AddAttr(kDependMode, MakeValue(depend_mode)); | ||||
| } | } | ||||
| int64_t ControlDepend::get_depend_mode() { | |||||
| int64_t ControlDepend::get_depend_mode() const { | |||||
| auto value_ptr = GetAttr(kDependMode); | auto value_ptr = GetAttr(kDependMode); | ||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameControlDepend, ControlDepend); | REGISTER_PRIMITIVE_C(kNameControlDepend, ControlDepend); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,24 +19,24 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/primitive_c.h" | |||||
| #include "c_ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameControlDepend = "ControlDepend"; | constexpr auto kNameControlDepend = "ControlDepend"; | ||||
| class ControlDepend : public PrimitiveC { | class ControlDepend : public PrimitiveC { | ||||
| public: | public: | ||||
| ControlDepend() : PrimitiveC(kNameControlDepend) {} | ControlDepend() : PrimitiveC(kNameControlDepend) {} | ||||
| ~ControlDepend() = default; | ~ControlDepend() = default; | ||||
| MS_DECLARE_PARENT(ControlDepend, PrimitiveC); | MS_DECLARE_PARENT(ControlDepend, PrimitiveC); | ||||
| void Init(int64_t depend_mode); | |||||
| void set_depend_mode(int64_t depend_mode); | |||||
| int64_t get_depend_mode(); | |||||
| void Init(const int64_t depend_mode); | |||||
| void set_depend_mode(const int64_t depend_mode = 0); | |||||
| int64_t get_depend_mode() const; | |||||
| }; | }; | ||||
| AbstractBasePtr ControlDependInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimControlDepend = std::shared_ptr<ControlDepend>; | using PrimControlDepend = std::shared_ptr<ControlDepend>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CONTROl_DEPEND_H_ | #endif // MINDSPORE_CORE_C_OPS_CONTROl_DEPEND_H_ | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -14,110 +14,38 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "c_ops/conv2d.h" | |||||
| #include "ops/conv2d.h" | |||||
| #include <string> | #include <string> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | #include <set> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ir/dtype/tensor_type.h" | |||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| #include "ops/control_depend.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| Conv2D::Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); } | |||||
| void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, | |||||
| const std::string &pad_mode, const std::vector<int64_t> &pad, const std::vector<int64_t> &stride, | |||||
| const std::vector<int64_t> &dilation, int64_t group) { | |||||
| auto prim_name = this->name(); | |||||
| this->AddAttr("data_format", MakeValue("NCHW")); | |||||
| this->AddAttr("offset_a", MakeValue(static_cast<int64_t>(0))); | |||||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | |||||
| this->set_stride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true)); | |||||
| this->set_dilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true)); | |||||
| this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); | |||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name); | |||||
| if (pad_mode == "pad") { | |||||
| for (auto item : pad) { | |||||
| CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name); | |||||
| } | |||||
| } else { | |||||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | |||||
| } | |||||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||||
| this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name)); | |||||
| this->set_out_channel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); | |||||
| this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_kernel_size() const { | |||||
| auto value_ptr = GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_stride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_dilation() const { | |||||
| auto value_ptr = GetAttr(kDilation); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::string Conv2D::get_pad_mode() const { | |||||
| auto value_ptr = this->GetAttr(kPadMode); | |||||
| return GetValue<string>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_pad() const { | |||||
| auto value_ptr = this->GetAttr(kPad); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_pad_list() const { | |||||
| auto value_ptr = this->GetAttr(kPadList); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| int64_t Conv2D::get_mode() const { | |||||
| auto value_ptr = this->GetAttr(kMode); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Conv2D::get_group() const { | |||||
| auto value_ptr = this->GetAttr(kGroup); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Conv2D::get_output_channel() const { | |||||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | |||||
| } | |||||
| void Conv2D::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||||
| void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||||
| void Conv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||||
| void Conv2D::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||||
| void Conv2D::set_mode(int64_t mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||||
| void Conv2D::set_group(int64_t group) { this->AddAttr(kGroup, MakeValue(group)); } | |||||
| void Conv2D::set_out_channel(int64_t output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||||
| void Conv2D::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||||
| namespace ops { | |||||
| namespace { | |||||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | auto conv_prim = primitive->cast<PrimConv2dPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(conv_prim); | MS_EXCEPTION_IF_NULL(conv_prim); | ||||
| auto prim_name = conv_prim->name(); | auto prim_name = conv_prim->name(); | ||||
| CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | |||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||||
| if (conv_prim->get_format() == NHWC) { | |||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | ||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | ||||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]", | CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]", | ||||
| w_shape[1], conv_prim->name()); | w_shape[1], conv_prim->name()); | ||||
| auto out_channel = conv_prim->get_output_channel(); | |||||
| auto out_channel = conv_prim->get_out_channel(); | |||||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | ||||
| std::vector<int64_t> temp_w; | std::vector<int64_t> temp_w; | ||||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | ||||
| @@ -136,10 +64,10 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| int64_t w_out = -1; | int64_t w_out = -1; | ||||
| std::vector<int64_t> pad_list(4, 0); | std::vector<int64_t> pad_list(4, 0); | ||||
| auto pad_mode = conv_prim->get_pad_mode(); | auto pad_mode = conv_prim->get_pad_mode(); | ||||
| if (pad_mode == "valid") { | |||||
| if (pad_mode == VALID) { | |||||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | ||||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | ||||
| } else if (pad_mode == "same") { | |||||
| } else if (pad_mode == SAME) { | |||||
| h_out = ceil(x_shape[2] / stride_h); | h_out = ceil(x_shape[2] / stride_h); | ||||
| w_out = ceil(x_shape[3] / stride_w); | w_out = ceil(x_shape[3] / stride_w); | ||||
| @@ -152,7 +80,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| auto pad_left = floor(pad_needed_w / 2); | auto pad_left = floor(pad_needed_w / 2); | ||||
| pad_list.emplace_back(pad_left); | pad_list.emplace_back(pad_left); | ||||
| pad_list.emplace_back(pad_needed_h - pad_left); | pad_list.emplace_back(pad_needed_h - pad_left); | ||||
| } else if (pad_mode == "pad") { | |||||
| } else if (pad_mode == PAD) { | |||||
| std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | ||||
| auto pad_top = conv_prim->get_pad()[0]; | auto pad_top = conv_prim->get_pad()[0]; | ||||
| auto pad_bottom = conv_prim->get_pad()[1]; | auto pad_bottom = conv_prim->get_pad()[1]; | ||||
| @@ -164,13 +92,17 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| h_out = floor(h_out); | h_out = floor(h_out); | ||||
| w_out = floor(w_out); | w_out = floor(w_out); | ||||
| } | } | ||||
| conv_prim->set_pad_list(pad_list); | |||||
| conv_prim->set_pad(pad_list); | |||||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | ||||
| if (conv_prim->get_format() == NHWC) { | |||||
| out_shape = {x_shape[0], h_out, w_out, out_channel}; | |||||
| } | |||||
| return std::make_shared<abstract::Shape>(out_shape); | return std::make_shared<abstract::Shape>(out_shape); | ||||
| } | } | ||||
| TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | ||||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); | |||||
| CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); | |||||
| for (const auto &item : input_args) { | for (const auto &item : input_args) { | ||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| } | } | ||||
| @@ -181,16 +113,126 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase | |||||
| types.emplace("w", input_args[1]->BuildType()); | types.emplace("w", input_args[1]->BuildType()); | ||||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | ||||
| if (infer_type == kNumberTypeInt8) { | if (infer_type == kNumberTypeInt8) { | ||||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||||
| return TypeIdToType(kNumberTypeInt32); | |||||
| } | } | ||||
| return TypeIdToType(infer_type); | return TypeIdToType(infer_type); | ||||
| } | } | ||||
| } // namespace | |||||
| void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode, | |||||
| const std::vector<int64_t> &pad, const std::vector<int64_t> &stride, | |||||
| const std::vector<int64_t> &dilation, int64_t group, const Format &format) { | |||||
| AddAttr(kOffsetA, MakeValue(static_cast<int64_t>(0))); | |||||
| set_kernel_size(kernel_size); | |||||
| set_stride(stride); | |||||
| set_dilation(dilation); | |||||
| set_pad(pad); | |||||
| set_pad_mode(pad_mode); | |||||
| set_mode(mode); | |||||
| set_out_channel(out_channel); | |||||
| set_group(group); | |||||
| set_format(format); | |||||
| } | |||||
| void Conv2D::set_out_channel(int64_t out_channel) { | |||||
| AddAttr(kOutChannel, | |||||
| MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name()))); | |||||
| } | |||||
| void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||||
| AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name()))); | |||||
| } | |||||
| void Conv2D::set_stride(const std::vector<int64_t> &stride) { | |||||
| AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); | |||||
| } | |||||
| void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { | |||||
| AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); | |||||
| } | |||||
| void Conv2D::set_pad_mode(const PadMode &pad_mode) { | |||||
| std::vector<int64_t> pad = get_pad(); | |||||
| if (pad_mode == PAD) { | |||||
| for (auto item : pad) { | |||||
| CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name()); | |||||
| } | |||||
| } else { | |||||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name()); | |||||
| } | |||||
| int64_t swi = pad_mode; | |||||
| AddAttr(kPadMode, MakeValue(swi)); | |||||
| } | |||||
| void Conv2D::set_pad(const std::vector<int64_t> &pad) { | |||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | |||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); | |||||
| } | |||||
| void Conv2D::set_mode(int64_t mode) { | |||||
| AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name()))); | |||||
| } | |||||
| void Conv2D::set_group(int64_t group) { | |||||
| AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name()))); | |||||
| } | |||||
| void Conv2D::set_format(const Format &format) { | |||||
| int64_t f = format; | |||||
| AddAttr(kFormat, MakeValue(f)); | |||||
| } | |||||
| int64_t Conv2D::get_out_channel() const { | |||||
| auto value_ptr = GetAttr(kOutChannel); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_kernel_size() const { | |||||
| auto value_ptr = GetAttr(kKernelSize); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_stride() const { | |||||
| auto value_ptr = GetAttr(kStride); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_dilation() const { | |||||
| auto value_ptr = GetAttr(kDilation); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| PadMode Conv2D::get_pad_mode() const { | |||||
| auto value_ptr = GetAttr(kPadMode); | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| std::vector<int64_t> Conv2D::get_pad() const { | |||||
| auto value_ptr = GetAttr(kPad); | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||||
| } | |||||
| int64_t Conv2D::get_mode() const { | |||||
| auto value_ptr = GetAttr(kMode); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| int64_t Conv2D::get_group() const { | |||||
| auto value_ptr = GetAttr(kGroup); | |||||
| return GetValue<int64_t>(value_ptr); | |||||
| } | |||||
| Format Conv2D::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args), | ||||
| Conv2dInferShape(primitive, input_args)->shape()); | Conv2dInferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); | REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); | ||||
| REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D); | REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D); | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -21,42 +21,45 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "c_ops/op_utils.h" | |||||
| #include "c_ops/primitive_c.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | |||||
| constexpr auto kNameConv2D = "Conv2D"; | constexpr auto kNameConv2D = "Conv2D"; | ||||
| class Conv2D : public PrimitiveC { | class Conv2D : public PrimitiveC { | ||||
| public: | public: | ||||
| Conv2D(); | |||||
| Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); } | |||||
| explicit Conv2D(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "w"}, {"output"}); } | |||||
| ~Conv2D() = default; | ~Conv2D() = default; | ||||
| MS_DECLARE_PARENT(Conv2D, PrimitiveC); | MS_DECLARE_PARENT(Conv2D, PrimitiveC); | ||||
| void Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1, | void Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1, | ||||
| const std::string &pad_mode = "valid", const std::vector<int64_t> &pad = {0, 0, 0, 0}, | |||||
| const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0}, | |||||
| const std::vector<int64_t> &stride = {1, 1, 1, 1}, const std::vector<int64_t> &dilation = {1, 1, 1, 1}, | const std::vector<int64_t> &stride = {1, 1, 1, 1}, const std::vector<int64_t> &dilation = {1, 1, 1, 1}, | ||||
| int64_t group = 1); | |||||
| std::vector<int64_t> get_kernel_size() const; | |||||
| std::vector<int64_t> get_stride() const; | |||||
| std::vector<int64_t> get_dilation() const; | |||||
| std::string get_pad_mode() const; | |||||
| std::vector<int64_t> get_pad() const; | |||||
| std::vector<int64_t> get_pad_list() const; | |||||
| int64_t get_mode() const; | |||||
| int64_t get_group() const; | |||||
| int64_t get_output_channel() const; | |||||
| int64_t group = 1, const Format &format = NCHW); | |||||
| void set_kernel_size(const std::vector<int64_t> &kernel_size); | void set_kernel_size(const std::vector<int64_t> &kernel_size); | ||||
| void set_stride(const std::vector<int64_t> &stride); | void set_stride(const std::vector<int64_t> &stride); | ||||
| void set_dilation(const std::vector<int64_t> &dilation); | void set_dilation(const std::vector<int64_t> &dilation); | ||||
| void set_pad_mode(const std::string &pad_mode); | |||||
| void set_pad_mode(const PadMode &pad_mode); | |||||
| void set_pad(const std::vector<int64_t> &pad); | void set_pad(const std::vector<int64_t> &pad); | ||||
| void set_mode(int64_t mode); | void set_mode(int64_t mode); | ||||
| void set_group(int64_t group); | void set_group(int64_t group); | ||||
| void set_out_channel(int64_t output_channel); | |||||
| void set_pad_list(const std::vector<int64_t> &pad_list); | |||||
| void set_out_channel(int64_t out_channel); | |||||
| void set_format(const Format &format); | |||||
| std::vector<int64_t> get_kernel_size() const; | |||||
| std::vector<int64_t> get_stride() const; | |||||
| std::vector<int64_t> get_dilation() const; | |||||
| PadMode get_pad_mode() const; | |||||
| std::vector<int64_t> get_pad() const; | |||||
| int64_t get_mode() const; | |||||
| int64_t get_group() const; | |||||
| int64_t get_out_channel() const; | |||||
| Format get_format() const; | |||||
| }; | }; | ||||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimConv2dPtr = std::shared_ptr<Conv2D>; | using PrimConv2dPtr = std::shared_ptr<Conv2D>; | ||||
| } // namespace ops | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_CONV2D_H_ | #endif // MINDSPORE_CORE_C_OPS_CONV2D_H_ | ||||