From: @jinyaohui Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -333,14 +333,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| target_link_libraries(mindspore mindspore::pybind11_module) | |||
| target_link_libraries(mindspore mindspore_gvar) | |||
| target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) | |||
| elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| target_link_libraries(mindspore mindspore::pybind11_module) | |||
| target_link_libraries(mindspore mindspore_gvar) | |||
| target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load) | |||
| else() | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf | |||
| mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| else () | |||
| if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | |||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| @@ -717,7 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimMinimumGrad, | |||
| prim::kPrimGkDropout, | |||
| prim::kPrimDropoutGrad, | |||
| prim::kPrimSoftMax, | |||
| prim::kPrimSoftmax, | |||
| prim::kPrimLayerNorm, | |||
| prim::kPrimLayerNormGrad, | |||
| #endif | |||
| @@ -20,7 +20,7 @@ | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ir/manager.h" | |||
| #include "abstract/utils.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| @@ -1750,8 +1750,8 @@ void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector< | |||
| input_abstracts.emplace_back(abstract); | |||
| } | |||
| auto prim = AnfAlgo::GetCNodePrimitive(node); | |||
| if (prim->isa<PrimitiveC>()) { | |||
| auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>(); | |||
| if (prim->isa<ops::PrimitiveC>()) { | |||
| auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>(); | |||
| MS_EXCEPTION_IF_NULL(prim_c); | |||
| auto abstract = prim_c->Infer(input_abstracts); | |||
| node->set_abstract(abstract); | |||
| @@ -8,16 +8,16 @@ endif() | |||
| message("************ build core ***************") | |||
| file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "abstract/*.cc" | |||
| "base/*.cc" | |||
| "c_ops/*.cc" | |||
| "ir/*.cc" | |||
| "utils/*.cc" | |||
| "load_mindir/*.cc" | |||
| ) | |||
| "abstract/*.cc" | |||
| "base/*.cc" | |||
| "ops/*.cc" | |||
| "ir/*.cc" | |||
| "utils/*.cc" | |||
| "load_mindir/*.cc" | |||
| ) | |||
| if(CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF") | |||
| add_compile_definitions(BUILDING_DLL) | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF") | |||
| add_compile_definitions(BUILDING_DLL) | |||
| elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \ | |||
| -Wuser-defined-warnings -Winconsistent-missing-override -Wno-delete-non-abstract-non-virtual-dtor") | |||
| @@ -28,5 +28,5 @@ add_library(mindspore_core STATIC ${CORE_SRC_LIST}) | |||
| target_link_libraries(mindspore_core PRIVATE mindspore_gvar) | |||
| if(USE_GLOG) | |||
| target_link_libraries(mindspore_core PRIVATE mindspore::glog) | |||
| target_link_libraries(mindspore_core PRIVATE mindspore::glog) | |||
| endif() | |||
| @@ -91,7 +91,9 @@ inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelS | |||
| inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet"); | |||
| // Arrays | |||
| inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo"); | |||
| inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | |||
| inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK"); | |||
| inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar"); | |||
| inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape"); | |||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||
| @@ -99,17 +101,25 @@ inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_ | |||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||
| inline const PrimitivePtr kPrimUnsqueeze = std::make_shared<Primitive>("Unsqueeze"); | |||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||
| inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | |||
| inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD"); | |||
| inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>(kGather); | |||
| inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>("Gather"); | |||
| inline const PrimitivePtr kPrimGatherND = std::make_shared<Primitive>("GatherND"); | |||
| inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2"); | |||
| inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Primitive>("SparseToDense"); | |||
| inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape"); | |||
| inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>("StridedSlice"); | |||
| inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape"); | |||
| inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup"); | |||
| inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad"); | |||
| inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||
| inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | |||
| inline const PrimitivePtr kPrimArgMin = std::make_shared<Primitive>("Argmin"); | |||
| inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | |||
| inline const PrimitivePtr kPrimUnpack = std::make_shared<Primitive>("Unpack"); | |||
| inline const PrimitivePtr kPrimUnstack = std::make_shared<Primitive>("Unstack"); | |||
| inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax"); | |||
| inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | |||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||
| @@ -123,6 +133,7 @@ inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("Cac | |||
| inline const PrimitivePtr kPrimDynamicAssign = std::make_shared<Primitive>("DynamicAssign"); | |||
| inline const PrimitivePtr kPrimPadAndShift = std::make_shared<Primitive>("PadAndShift"); | |||
| inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); | |||
| inline const PrimitivePtr kPrimSliceFusion = std::make_shared<Primitive>("SliceFusion"); | |||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | |||
| inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | |||
| @@ -145,16 +156,36 @@ inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("Seque | |||
| inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range"); | |||
| inline const PrimitivePtr kPrimSpaceToBatchND = std::make_shared<Primitive>("SpaceToBatchND"); | |||
| inline const PrimitivePtr kPrimBatchToSpaceND = std::make_shared<Primitive>("BatchToSpaceND"); | |||
| inline const PrimitivePtr kPrimDepthToSpace = std::make_shared<Primitive>("DepthToSpace"); | |||
| inline const PrimitivePtr kPrimBatchToSpace = std::make_shared<Primitive>("BatchToSpace"); | |||
| inline const PrimitivePtr kPrimSpaceToBatch = std::make_shared<Primitive>("SpaceToBatch"); | |||
| inline const PrimitivePtr kPrimScatterNd = std::make_shared<Primitive>("ScatterNd"); | |||
| inline const PrimitivePtr kPrimConstantOfShape = std::make_shared<Primitive>("ConstantOfShape"); | |||
| inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference"); | |||
| inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV2"); | |||
| inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence"); | |||
| inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank"); | |||
| inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | |||
| inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram"); | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax"); | |||
| inline const PrimitivePtr kPrimCrop = std::make_shared<Primitive>("Crop"); | |||
| inline const PrimitivePtr kPrimFlattenGrad = std::make_shared<Primitive>("FlattenGrad"); | |||
| inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax"); | |||
| inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropy = std::make_shared<Primitive>("SparseSoftmaxCrossEntropy"); | |||
| inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); | |||
| inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); | |||
| inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm"); | |||
| inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan"); | |||
| inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan"); | |||
| inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin"); | |||
| inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); | |||
| inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad"); | |||
| inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | |||
| inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | |||
| inline const PrimitivePtr kPrimROIPooling = std::make_shared<Primitive>("ROIPooling"); | |||
| inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||
| inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax"); | |||
| @@ -168,6 +199,9 @@ inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("Fu | |||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | |||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||
| inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection"); | |||
| inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose"); | |||
| inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx"); | |||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||
| @@ -179,21 +213,34 @@ inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive> | |||
| inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | |||
| inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); | |||
| inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); | |||
| inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | |||
| inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | |||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | |||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput"); | |||
| inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess"); | |||
| inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd"); | |||
| inline const PrimitivePtr kPrimBiasGrad = std::make_shared<Primitive>("BiasGrad"); | |||
| inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad"); | |||
| inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad"); | |||
| inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy"); | |||
| inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad"); | |||
| inline const PrimitivePtr kPrimSmoothL1Loss = std::make_shared<Primitive>("SmoothL1Loss"); | |||
| inline const PrimitivePtr kPrimSmoothL1LossGrad = std::make_shared<Primitive>("SmoothL1LossGrad"); | |||
| inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = | |||
| std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits"); | |||
| inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogits = | |||
| std::make_shared<Primitive>("SigmoidCrossEntropyWithLogits"); | |||
| inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogitsGrad = | |||
| std::make_shared<Primitive>("SigmoidCrossEntropyWithLogitsGrad"); | |||
| inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = | |||
| std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits"); | |||
| inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum"); | |||
| inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum"); | |||
| inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm"); | |||
| inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("Lrn"); | |||
| inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | |||
| inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | |||
| inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | |||
| @@ -204,18 +251,22 @@ inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout"); | |||
| inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal"); | |||
| inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal"); | |||
| inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | |||
| inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>("Gelu"); | |||
| inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | |||
| inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | |||
| inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu"); | |||
| inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad"); | |||
| inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); | |||
| inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu"); | |||
| inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | |||
| inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | |||
| inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU"); | |||
| inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | |||
| inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | |||
| inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); | |||
| inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); | |||
| inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); | |||
| inline const PrimitivePtr kPrimFakeQuantWithMinMaxVars = std::make_shared<Primitive>("FakeQuantWithMinMaxVars"); | |||
| inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp"); | |||
| inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl"); | |||
| inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad"); | |||
| @@ -224,6 +275,8 @@ inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive | |||
| inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD"); | |||
| inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum"); | |||
| inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove"); | |||
| inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize"); | |||
| inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures"); | |||
| // Comm ops | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| @@ -239,6 +292,12 @@ inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGathe | |||
| inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter"); | |||
| inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async"); | |||
| inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill"); | |||
| // Quant ops | |||
| inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold"); | |||
| inline const PrimitivePtr kPrimFakeQuantWithMinMaxVarsPerChannel = | |||
| std::make_shared<Primitive>("FakeQuantWithMinMaxVarsPerChannel"); | |||
| // Control ops | |||
| inline const PrimitivePtr kPrimMerge = std::make_shared<Primitive>("Merge"); | |||
| // RowTensor | |||
| inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor"); | |||
| inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues"); | |||
| @@ -251,12 +310,22 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv | |||
| inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices"); | |||
| inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape"); | |||
| // TensorList | |||
| inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor"); | |||
| inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve"); | |||
| inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Primitive>("TensorListStack"); | |||
| inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem"); | |||
| // Maths | |||
| inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil"); | |||
| inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); | |||
| inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add"); | |||
| inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | |||
| inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag"); | |||
| inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | |||
| inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | |||
| inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad"); | |||
| inline const PrimitivePtr kPrimReduce = std::make_shared<Primitive>("Reduce"); | |||
| inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean"); | |||
| inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum"); | |||
| inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll"); | |||
| @@ -264,6 +333,8 @@ inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAn | |||
| inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | |||
| inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | |||
| inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | |||
| inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin"); | |||
| inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos"); | |||
| inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub"); | |||
| inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); | |||
| inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div"); | |||
| @@ -279,6 +350,7 @@ inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscala | |||
| inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); | |||
| inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | |||
| inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | |||
| inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power"); | |||
| inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | |||
| inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv"); | |||
| inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | |||
| @@ -292,12 +364,13 @@ inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | |||
| inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | |||
| inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | |||
| inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | |||
| inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression"); | |||
| inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign"); | |||
| inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference"); | |||
| inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin"); | |||
| inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos"); | |||
| inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad"); | |||
| inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad"); | |||
| inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"); | |||
| inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); | |||
| // Statements | |||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||
| @@ -323,6 +396,7 @@ inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>( | |||
| inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index"); | |||
| // Debug ops | |||
| inline const PrimitivePtr kPrimAssert = std::make_shared<Primitive>("Assert"); | |||
| inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||
| inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary"); | |||
| inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||
| @@ -349,6 +423,13 @@ inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); | |||
| inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); | |||
| inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat"); | |||
| inline const PrimitivePtr kPrimLshProjection = std::make_shared<Primitive>("LshProjection"); | |||
| inline const PrimitivePtr kPrimHashtableLookup = std::make_shared<Primitive>("HashtableLookup"); | |||
| inline const PrimitivePtr kPrimCustomPredict = std::make_shared<Primitive>("CustomPredict"); | |||
| inline const PrimitivePtr kPrimStack = std::make_shared<Primitive>("Stack"); | |||
| inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox"); | |||
| inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast"); | |||
| inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While"); | |||
| // Structures | |||
| inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list"); | |||
| @@ -371,7 +452,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_ | |||
| inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||
| // Other primitive not used by backend but used in core; | |||
| // Other primitve not used by backend but used in core; | |||
| inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||
| inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J"); | |||
| @@ -382,6 +463,44 @@ inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict | |||
| // GraphKernel ops | |||
| inline const PrimitivePtr kPrimInplaceAssign = std::make_shared<Primitive>("InplaceAssign"); | |||
| // Only used in lite | |||
| inline const PrimitivePtr kPrimLeakyRelu = std::make_shared<Primitive>("LeakyRelu"); | |||
| inline const PrimitivePtr kPrimConstant = std::make_shared<Primitive>("Constant"); | |||
| inline const PrimitivePtr kPrimLocalResponseNormalization = std::make_shared<Primitive>("LocalResponseNormalization"); | |||
| inline const PrimitivePtr kPrimFftReal = std::make_shared<Primitive>("FftReal"); | |||
| inline const PrimitivePtr kPrimMfcc = std::make_shared<Primitive>("Mfcc"); | |||
| inline const PrimitivePtr kPrimRfft = std::make_shared<Primitive>("Rfft"); | |||
| inline const PrimitivePtr kPrimFftImag = std::make_shared<Primitive>("FftImag"); | |||
| inline const PrimitivePtr kPrimSkipGram = std::make_shared<Primitive>("SkipGram"); | |||
| inline const PrimitivePtr kPrimConv2DFusion = std::make_shared<Primitive>("Conv2DFusion"); | |||
| inline const PrimitivePtr kPrimConv2dTransposeFusion = std::make_shared<Primitive>("Conv2dTransposeFusion"); | |||
| inline const PrimitivePtr kPrimDepthWiseConv2DFusion = std::make_shared<Primitive>("DepthWiseConv2DFusion"); | |||
| inline const PrimitivePtr kPrimAddFusion = std::make_shared<Primitive>("AddFusion"); | |||
| inline const PrimitivePtr kPrimScaleFusion = std::make_shared<Primitive>("ScaleFusion"); | |||
| inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusion"); | |||
| inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion"); | |||
| inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid"); | |||
| inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip"); | |||
| inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh"); | |||
| inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion = | |||
| std::make_shared<Primitive>("DepthWiseConv2DTransposeFusion"); | |||
| inline const PrimitivePtr kPrimArgMinFusion = std::make_shared<Primitive>("ArgMinFusion"); | |||
| inline const PrimitivePtr kPrimArgMaxFusion = std::make_shared<Primitive>("ArgMaxFusion"); | |||
| inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("SpaceToDepth"); | |||
| inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion"); | |||
| inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion"); | |||
| inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize"); | |||
| inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose"); | |||
| inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue"); | |||
| inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If"); | |||
| inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion"); | |||
| inline const PrimitivePtr kPrimMaxPoolFusion = std::make_shared<Primitive>("MaxPoolFusion"); | |||
| inline const PrimitivePtr kPrimActivation = std::make_shared<Primitive>("Activation"); | |||
| inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFusion"); | |||
| inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion"); | |||
| inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion"); | |||
| inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion"); | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||
| @@ -1,19 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/abs.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameAbs, Abs); | |||
| } // namespace mindspore | |||
| @@ -1,51 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/apply_momentum.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| void ApplyMomentum::Init(bool use_nesterov, bool use_locking, float gradient_scale) { | |||
| this->set_use_nesterov(use_nesterov); | |||
| this->set_use_locking(use_locking); | |||
| this->set_gradient_scale(gradient_scale); | |||
| } | |||
| void ApplyMomentum::set_use_nesterov(bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); } | |||
| void ApplyMomentum::set_use_locking(bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); } | |||
| void ApplyMomentum::set_gradient_scale(float gradient_scale) { | |||
| this->AddAttr(kGradientScale, MakeValue(gradient_scale)); | |||
| } | |||
| bool ApplyMomentum::get_use_nesterov() const { | |||
| auto value_ptr = GetAttr(kUseNesterov); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| bool ApplyMomentum::get_use_locking() const { | |||
| auto value_ptr = GetAttr(kUseLocking); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| float ApplyMomentum::get_gradient_scale() { | |||
| auto value_ptr = GetAttr(kGradientScale); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/assign_add.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameAssignAdd, AssignAdd); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/atan.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameAtan, Atan); | |||
| } // namespace mindspore | |||
| @@ -1,54 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/audio_spectrogram.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void AudioSpectrogram::set_window_size(const int64_t &window_size) { | |||
| this->AddAttr(kWindowSize, MakeValue(window_size)); | |||
| } | |||
| int64_t AudioSpectrogram::get_window_size() const { | |||
| auto value_ptr = GetAttr(kWindowSize); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void AudioSpectrogram::set_stride(const int64_t &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| int64_t AudioSpectrogram::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void AudioSpectrogram::set_mag_square(const bool &mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); } | |||
| bool AudioSpectrogram::get_mag_square() const { | |||
| auto value_ptr = GetAttr(kMagSquare); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| void AudioSpectrogram::Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square) { | |||
| this->set_window_size(window_size); | |||
| this->set_stride(stride); | |||
| this->set_mag_square(mag_square); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram); | |||
| } // namespace mindspore | |||
| @@ -1,59 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/batch_norm.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| void BatchNorm::Init(bool is_training, float epsilon, const Format &format) { | |||
| set_is_training(is_training); | |||
| set_epsilon(epsilon); | |||
| set_format(format); | |||
| } | |||
| void BatchNorm::set_is_training(bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); } | |||
| void BatchNorm::set_epsilon(float epsilon) { | |||
| CheckAndConvertUtils::CheckInRange(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name()); | |||
| this->AddAttr(kEpsilon, MakeValue(epsilon)); | |||
| } | |||
| void BatchNorm::set_format(const Format &format) { | |||
| int64_t f = format; | |||
| this->AddAttr(kFormat, MakeValue(f)); | |||
| } | |||
| bool BatchNorm::get_is_trainging() { | |||
| auto value_ptr = GetAttr(kIsTraining); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| float BatchNorm::get_epsilon() { | |||
| auto value_ptr = GetAttr(kEpsilon); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| Format BatchNorm::get_format() const { | |||
| auto value_ptr = GetAttr(kFormat); | |||
| return Format(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/batch_norm_fold.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold); | |||
| } // namespace mindspore | |||
| @@ -1,31 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/binary_cross_entropy_grad.h" | |||
| namespace mindspore { | |||
| void BinaryCrossEntropyGrad::Init(const std::string &reduction) { set_reduction(reduction); } | |||
| void BinaryCrossEntropyGrad::set_reduction(const std::string &reduction) { | |||
| CheckAndConvertUtils::CheckString(kReduction, reduction, {"none", "mean", "sum"}, name()); | |||
| this->AddAttr(kReduction, MakeValue(reduction)); | |||
| } | |||
| std::string BinaryCrossEntropyGrad::get_reduction() const { | |||
| auto value_ptr = GetAttr(kReduction); | |||
| return GetValue<std::string>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropyGrad, BinaryCrossEntropyGrad); | |||
| } // namespace mindspore | |||
| @@ -1,42 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/broadcast.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| void Broadcast::Init(int64_t root_rank, const std::string &group) { | |||
| this->set_root_rank(root_rank); | |||
| this->set_group(group); | |||
| } | |||
| void Broadcast::set_root_rank(int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); } | |||
| void Broadcast::set_group(const std::string &group) { | |||
| CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name()); | |||
| this->AddAttr(kGroup, MakeValue(group)); | |||
| } | |||
| int64_t Broadcast::get_root_rank() { | |||
| auto value_ptr = this->GetAttr(kRootRank); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| std::string Broadcast::get_group() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<std::string>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast); | |||
| } // namespace mindspore | |||
| @@ -1,33 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/broadcast_to.h" | |||
| namespace mindspore { | |||
| void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); } | |||
| void BroadcastTo::set_shape(const std::vector<int64_t> &shape) { | |||
| CheckAndConvertUtils::CheckInteger(kShapeSize, shape.size(), kGreaterThan, 0, name()); | |||
| CheckAndConvertUtils::CheckPositiveVector(kShape, shape, name(), false, true); | |||
| AddAttr(kShape, MakeValue(shape)); | |||
| } | |||
| std::vector<int64_t> BroadcastTo::get_shape() const { | |||
| auto value_ptr = GetAttr(kShape); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/ceil.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameCeil, Ceil); | |||
| } | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/cos.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameCos, Cos); | |||
| } | |||
| @@ -1,44 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/custom_predict.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void CustomPredict::Init(int64_t outputNum, float weight_threshold) { | |||
| this->set_outputNum(outputNum); | |||
| this->set_weight_threshold(weight_threshold); | |||
| } | |||
| void CustomPredict::set_outputNum(int64_t outputNum) { this->AddAttr(kOutputNum, MakeValue(outputNum)); } | |||
| int64_t CustomPredict::get_outputNum() const { | |||
| auto value_ptr = this->GetAttr(kOutputNum); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void CustomPredict::set_weight_threshold(float weight_threshold) { | |||
| this->AddAttr(kWeightThreshold, MakeValue(weight_threshold)); | |||
| } | |||
| float CustomPredict::get_weight_threshold() const { | |||
| auto value_ptr = this->GetAttr(kWeightThreshold); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameCustomPredict, CustomPredict); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/div.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameDiv, Div); | |||
| } // namespace mindspore | |||
| @@ -1,20 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/equal.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameEqual, Equal); | |||
| } // namespace mindspore | |||
| @@ -1,20 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/exp.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameExp, Exp); | |||
| } // namespace mindspore | |||
| @@ -1,44 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/fake_quant_with_min_max_vars.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void FakeQuantWithMinMaxVars::Init(const bool &narrow_range, int64_t num_bits) { | |||
| this->set_narrow_range(narrow_range); | |||
| this->set_num_bits(num_bits); | |||
| } | |||
| void FakeQuantWithMinMaxVars::set_narrow_range(const bool &narrow_range) { | |||
| this->AddAttr(kNarrowRange, MakeValue(narrow_range)); | |||
| } | |||
| bool FakeQuantWithMinMaxVars::get_narrow_range() const { | |||
| auto value_ptr = this->GetAttr(kNarrowRange); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| void FakeQuantWithMinMaxVars::set_num_bits(int64_t num_bits) { this->AddAttr(kNumBits, MakeValue(num_bits)); } | |||
| int64_t FakeQuantWithMinMaxVars::get_num_bits() const { | |||
| auto value_ptr = this->GetAttr(kNumBits); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameFakeQuantWithMinMaxVars, FakeQuantWithMinMaxVars); | |||
| } // namespace mindspore | |||
| @@ -1,22 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/fft_imag.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameFftImag, FftImag); | |||
| } | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/flatten_grad.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameFlattenGrad, FlattenGrad); | |||
| } | |||
| @@ -1,22 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/hashtable_lookup.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameHashtableLookup, HashtableLookup); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/less.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameLess, Less); | |||
| } | |||
| @@ -1,20 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/less_equal.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameLessEqual, LessEqual); | |||
| } // namespace mindspore | |||
| @@ -1,65 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/local_response_normalization.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void LocalResponseNormalization::set_depth_radius(const int64_t &depth_radius) { | |||
| this->AddAttr(kDepthRadius, MakeValue(depth_radius)); | |||
| } | |||
| int64_t LocalResponseNormalization::get_depth_radius() const { | |||
| auto value_ptr = GetAttr(kDepthRadius); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void LocalResponseNormalization::set_bias(const float &bias) { this->AddAttr(kBias, MakeValue(bias)); } | |||
| float LocalResponseNormalization::get_bias() const { | |||
| auto value_ptr = GetAttr(kBias); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| void LocalResponseNormalization::set_alpha(const float &alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); } | |||
| float LocalResponseNormalization::get_alpha() const { | |||
| auto value_ptr = GetAttr(kAlpha); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| void LocalResponseNormalization::set_beta(const float &beta) { this->AddAttr(kBeta, MakeValue(beta)); } | |||
| float LocalResponseNormalization::get_beta() const { | |||
| auto value_ptr = GetAttr(kBeta); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| void LocalResponseNormalization::Init(const int64_t &depth_radius, const float &bias, const float &alpha, | |||
| const float &beta) { | |||
| this->set_depth_radius(depth_radius); | |||
| this->set_bias(bias); | |||
| this->set_alpha(alpha); | |||
| this->set_beta(beta); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameLocalResponseNormalization, LocalResponseNormalization); | |||
| } // namespace mindspore | |||
| @@ -1,21 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/log.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameLog, Log); | |||
| } | |||
| @@ -1,20 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/logical_not.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameLogicalNot, LogicalNot); | |||
| } // namespace mindspore | |||
| @@ -1,20 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/logical_or.h" | |||
| namespace mindspore { | |||
| REGISTER_PRIMITIVE_C(kNameLogicalOr, LogicalOr); | |||
| } // namespace mindspore | |||
| @@ -1,82 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/lstm.h" | |||
| namespace mindspore { | |||
| void LSTM::set_input_size(const int64_t &input_size) { | |||
| CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name()); | |||
| AddAttr(kInput_size, MakeValue(input_size)); | |||
| } | |||
| int64_t LSTM::get_input_size() const { | |||
| auto value_ptr = this->GetAttr(kInput_size); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void LSTM::set_hidden_size(const int64_t &hidden_size) { | |||
| CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name()); | |||
| AddAttr(kHidden_size, MakeValue(hidden_size)); | |||
| } | |||
| int64_t LSTM::get_hidden_size() const { | |||
| auto value_ptr = this->GetAttr(kHidden_size); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void LSTM::set_num_layers(const int64_t &num_layers) { | |||
| CheckAndConvertUtils::CheckInteger(kNum_layers, num_layers, kGreaterThan, 0, this->name()); | |||
| AddAttr(kNum_layers, MakeValue(kNum_layers)); | |||
| } | |||
| int64_t LSTM::get_num_layers() const { | |||
| auto value_ptr = this->GetAttr(kNum_layers); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void LSTM::set_has_bias(const bool &has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); } | |||
| bool LSTM::get_has_bias() const { | |||
| auto value_ptr = this->GetAttr(kHasBias); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| void LSTM::set_dropout(const float &dropout) { | |||
| CheckAndConvertUtils::CheckInRange(kDropout, dropout, kIncludeBoth, {0, 1}, this->name()); | |||
| AddAttr(kDropout, MakeValue(dropout)); | |||
| } | |||
| float LSTM::get_dropout() const { | |||
| auto value_ptr = this->GetAttr(kDropout); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| void LSTM::set_bidirectional(const bool &bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); } | |||
| bool LSTM::get_bidirectional() const { | |||
| auto value_ptr = this->GetAttr(kBidirectional); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| void LSTM::set_num_directions(const int64_t &num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); } | |||
| int64_t LSTM::get_num_directions() const { | |||
| auto value_ptr = this->GetAttr(kNumDirections); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void LSTM::Init(const int64_t &input_size, const int64_t &hidden_size, const int64_t &num_layers, const bool &has_bias, | |||
| const float &dropout, const bool &bidirectional) { | |||
| this->set_input_size(input_size); | |||
| this->set_hidden_size(hidden_size); | |||
| this->set_num_layers(num_layers); | |||
| this->set_has_bias(has_bias); | |||
| this->set_dropout(dropout); | |||
| this->set_bidirectional(bidirectional); | |||
| if (bidirectional) { | |||
| this->set_num_directions(2); | |||
| } else { | |||
| this->set_num_directions(1); | |||
| } | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameLSTM, LSTM); | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,7 @@ | |||
| #include <utility> | |||
| #include "ir/tensor.h" | |||
| #include "ir/param_info.h" | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/shape_utils.h" | |||
| @@ -676,7 +676,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc | |||
| const std::string &node_type = node_proto.op_type(); | |||
| std::shared_ptr<Primitive> prim; | |||
| auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap(); | |||
| auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); | |||
| if (op_primc_fns.find(node_type) != op_primc_fns.end()) { | |||
| prim = op_primc_fns[node_type](); | |||
| } else { | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/abs.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto abs_prim = primitive->cast<PrimAbsPtr>(); | |||
| MS_EXCEPTION_IF_NULL(abs_prim); | |||
| auto prim_name = abs_prim->name(); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| return std::make_shared<abstract::Shape>(in_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("input_x", input_args[0]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAbs, Abs); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,13 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ABS_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ABS_H_ | |||
| #include "c_ops/primitive_c.h" | |||
| #ifndef MINDSPORE_CORE_OPS_ABS_H_ | |||
| #define MINDSPORE_CORE_OPS_ABS_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAbs = "Abs"; | |||
| class Abs : public PrimitiveC { | |||
| public: | |||
| @@ -29,6 +33,10 @@ class Abs : public PrimitiveC { | |||
| MS_DECLARE_PARENT(Abs, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAbsPtr = std::shared_ptr<Abs>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ABS_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ABS_H_ | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/adam.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto Adam_prim = primitive->cast<PrimAdamPtr>(); | |||
| MS_EXCEPTION_IF_NULL(Adam_prim); | |||
| auto prim_name = Adam_prim->name(); | |||
| // infer shape | |||
| auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name); | |||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name); | |||
| auto grad_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); | |||
| CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); | |||
| // infer type | |||
| auto var_type = input_args[0]->BuildType(); | |||
| auto m_type = input_args[1]->BuildType(); | |||
| auto v_type = input_args[2]->BuildType(); | |||
| auto grad_type = input_args[9]->BuildType(); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name); | |||
| auto infer_var_type = var_type->cast<TensorTypePtr>()->element(); | |||
| auto infer_m_type = m_type->cast<TensorTypePtr>()->element(); | |||
| auto infer_v_type = v_type->cast<TensorTypePtr>()->element(); | |||
| // auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element(); | |||
| auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape); | |||
| auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape); | |||
| auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape); | |||
| AbstractBasePtrList output = {output0, output1, output2}; | |||
| return std::make_shared<abstract::AbstractTuple>(output); | |||
| } | |||
| } // namespace | |||
| void Adam::Init(const bool use_locking, const bool use_nesterov) { | |||
| this->set_use_locking(use_locking); | |||
| this->set_use_nesterov(use_nesterov); | |||
| } | |||
| void Adam::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); } | |||
| void Adam::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); } | |||
| bool Adam::get_use_locking() const { | |||
| auto value_ptr = GetAttr(kUseLocking); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| bool Adam::get_use_nesterov() const { | |||
| auto value_ptr = GetAttr(kUseNesterov); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Adam, prim::kPrimAdam, AdamInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAdam, Adam); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,29 +14,34 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ADAM_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ADAM_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_ADAM_H_ | |||
| #define MINDSPORE_CORE_OPS_ADAM_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAdam = "Adam"; | |||
| class Adam : public PrimitiveC { | |||
| public: | |||
| Adam() : PrimitiveC(kNameAdam) {} | |||
| ~Adam() = default; | |||
| MS_DECLARE_PARENT(Adam, PrimitiveC); | |||
| void Init(const bool &use_locking = false, const bool &use_nesteroy = false); | |||
| void set_use_locking(const bool &use_locking); | |||
| void set_use_nesteroy(const bool &use_nesteroy); | |||
| void Init(const bool use_locking = false, const bool use_nesterov = false); | |||
| void set_use_locking(const bool use_locking); | |||
| void set_use_nesterov(const bool use_nesterov); | |||
| bool get_use_locking() const; | |||
| bool get_use_nesteroy() const; | |||
| bool get_use_nesterov() const; | |||
| }; | |||
| AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAdamPtr = std::shared_ptr<Adam>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ADAM_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ADAM_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/add.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto add_prim = primitive->cast<PrimAddPtr>(); | |||
| MS_EXCEPTION_IF_NULL(add_prim); | |||
| auto prim_name = add_prim->name(); | |||
| return BroadCastInferShape(prim_name, input_args); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| types.emplace("y", input_args[1]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAdd, Add); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,21 +14,23 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ADD_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ADD_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_ADD_H_ | |||
| #define MINDSPORE_CORE_OPS_ADD_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAdd = "Add"; | |||
| class Add : public PrimitiveC { | |||
| public: | |||
| Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); } | |||
| explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); } | |||
| ~Add() = default; | |||
| MS_DECLARE_PARENT(Add, PrimitiveC); | |||
| void Init() {} | |||
| @@ -37,6 +39,7 @@ class Add : public PrimitiveC { | |||
| AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAddPtr = std::shared_ptr<Add>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ADD_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ADD_H_ | |||
| @@ -14,8 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/add_fold.h" | |||
| #include "ops/add_fold.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| REGISTER_PRIMITIVE_C(kNameAddFold, AddFold); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,17 +14,18 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ADDFOLD_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ADDFOLD_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_ADD_FOLD_H_ | |||
| #define MINDSPORE_CORE_OPS_ADD_FOLD_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAddFold = "AddFold"; | |||
| class AddFold : public PrimitiveC { | |||
| public: | |||
| @@ -33,6 +34,7 @@ class AddFold : public PrimitiveC { | |||
| MS_DECLARE_PARENT(AddFold, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ADDFOLD_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ADD_FOLD_H_ | |||
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/adder.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size, | |||
| const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list, | |||
| const std::vector<int64_t> &dilation, const int64_t group, const Format &format) { | |||
| set_in_channel(in_channel); | |||
| set_out_channel(out_channel); | |||
| set_kernel_size(kernel_size); | |||
| set_pad_mode(pad_mode); | |||
| set_stride(stride); | |||
| set_pad_list(pad_list); | |||
| set_dilation(dilation); | |||
| set_group(group); | |||
| set_format(format); | |||
| } | |||
| void Adder::set_in_channel(const int64_t in_channel) { this->AddAttr(kInChannel, MakeValue(in_channel)); } | |||
| int64_t Adder::get_in_channel() const { | |||
| auto value_ptr = GetAttr(kInChannel); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void Adder::set_out_channel(const int64_t out_channel) { this->AddAttr(kOutChannel, MakeValue(out_channel)); } | |||
| int64_t Adder::get_out_channel() const { | |||
| auto value_ptr = GetAttr(kOutChannel); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | |||
| } | |||
| std::vector<int64_t> Adder::get_kernel_size() const { | |||
| auto value_ptr = GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void Adder::set_pad_mode(const PadMode &pad_mode) { | |||
| int64_t swi = pad_mode; | |||
| this->AddAttr(kPadMode, MakeValue(swi)); | |||
| } | |||
| PadMode Adder::get_pad_mode() const { | |||
| auto value_ptr = GetAttr(kPadMode); | |||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| void Adder::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| std::vector<int64_t> Adder::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||
| std::vector<int64_t> Adder::get_pad_list() const { | |||
| auto value_ptr = GetAttr(kPadList); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void Adder::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| std::vector<int64_t> Adder::get_dilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void Adder::set_group(const int64_t group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| int64_t Adder::get_group() const { | |||
| auto value_ptr = GetAttr(kGroup); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void Adder::set_format(const Format &format) { | |||
| int64_t swi = format; | |||
| this->AddAttr(kFormat, MakeValue(swi)); | |||
| } | |||
| Format Adder::get_format() const { | |||
| auto value_ptr = GetAttr(kFormat); | |||
| return Format(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameAdder, Adder); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_ADDER_H_ | |||
| #define MINDSPORE_CORE_OPS_ADDER_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAdder = "Adder"; | |||
| class Adder : public PrimitiveC { | |||
| public: | |||
| explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {} | |||
| ~Adder() = default; | |||
| MS_DECLARE_PARENT(Adder, PrimitiveC); | |||
| void Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size, | |||
| const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list, | |||
| const std::vector<int64_t> &dilation, const int64_t group, const Format &format); | |||
| void set_in_channel(const int64_t in_channel); | |||
| void set_out_channel(const int64_t out_channel); | |||
| void set_kernel_size(const std::vector<int64_t> &kernel_size); | |||
| void set_pad_mode(const PadMode &pad_mode); | |||
| void set_stride(const std::vector<int64_t> &stride); | |||
| void set_pad_list(const std::vector<int64_t> &pad_list); | |||
| void set_dilation(const std::vector<int64_t> &dilation); | |||
| void set_group(const int64_t group); | |||
| void set_format(const Format &format); | |||
| int64_t get_in_channel() const; | |||
| int64_t get_out_channel() const; | |||
| std::vector<int64_t> get_kernel_size() const; | |||
| PadMode get_pad_mode() const; | |||
| std::vector<int64_t> get_stride() const; | |||
| std::vector<int64_t> get_pad_list() const; | |||
| std::vector<int64_t> get_dilation() const; | |||
| int64_t get_group() const; | |||
| Format get_format() const; | |||
| }; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_ADDER_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "ops/addn.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_tuple); | |||
| auto elements = input_tuple->elements(); | |||
| CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | |||
| auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(element0); | |||
| auto element0_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("element0", element0->BuildType()); | |||
| for (size_t i = 1; i < elements.size(); ++i) { | |||
| std::string elementi = "element" + std::to_string(i); | |||
| auto elementi_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | |||
| prim_name); | |||
| for (size_t j = 0; j < element0_shape.size(); ++j) { | |||
| if (elementi_shape[j] != element0_shape[j]) { | |||
| MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element."; | |||
| } | |||
| } | |||
| types.emplace(elementi, elements[i]->BuildType()); | |||
| } | |||
| std::set<TypeId> valid_types = common_valid_types; | |||
| valid_types.insert(kNumberTypeBool); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name); | |||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type), | |||
| std::make_shared<abstract::Shape>(element0_shape)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAddN, AddN); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,13 +14,16 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ADDN_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ADDN_H_ | |||
| #include "c_ops/primitive_c.h" | |||
| #ifndef MINDSPORE_CORE_OPS_ADDN_H_ | |||
| #define MINDSPORE_CORE_OPS_ADDN_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAddN = "AddN"; | |||
| class AddN : public PrimitiveC { | |||
| public: | |||
| @@ -29,6 +32,10 @@ class AddN : public PrimitiveC { | |||
| MS_DECLARE_PARENT(AddN, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAddNPtr = std::shared_ptr<AddN>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ADDN_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ADDN_H_ | |||
| @@ -14,19 +14,20 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/concat.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/all.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); } | |||
| void Concat::Init(int64_t axis) { this->set_axis(axis); } | |||
| int64_t Concat::get_axis() const { | |||
| auto value_ptr = this->GetAttr(kAxis); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void All::set_keep_dims(const int64_t keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); } | |||
| void Concat::set_axis(int64_t axis) { | |||
| this->AddAttr(kAxis, MakeValue(CheckAndConvertUtils::CheckInteger(kAxis, axis, kGreaterEqual, 0, this->name()))); | |||
| int64_t All::get_keep_dims() const { | |||
| auto value_ptr = GetAttr(kKeepDims); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameConcat, Concat); | |||
| REGISTER_PRIMITIVE_C(kNameAll, All); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_ALL_H_ | |||
| #define MINDSPORE_CORE_OPS_ALL_H_ | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAll = "All"; | |||
| class All : public PrimitiveC { | |||
| public: | |||
| All() : PrimitiveC(kNameAll) {} | |||
| ~All() = default; | |||
| MS_DECLARE_PARENT(All, PrimitiveC); | |||
| void Init(const int64_t keep_dims); | |||
| void set_keep_dims(const int64_t keep_dims); | |||
| int64_t get_keep_dims() const; | |||
| }; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_ALL_H_ | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <map> | |||
| #include <string> | |||
| #include "ops/apply_momentum.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const float gradient_scale) { | |||
| this->set_use_nesterov(use_nesterov); | |||
| this->set_use_locking(use_locking); | |||
| this->set_gradient_scale(gradient_scale); | |||
| } | |||
| void ApplyMomentum::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); } | |||
| void ApplyMomentum::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); } | |||
| void ApplyMomentum::set_gradient_scale(const float gradient_scale) { | |||
| this->AddAttr(kGradientScale, MakeValue(gradient_scale)); | |||
| } | |||
| bool ApplyMomentum::get_use_nesterov() const { | |||
| auto value_ptr = GetAttr(kUseNesterov); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| bool ApplyMomentum::get_use_locking() const { | |||
| auto value_ptr = GetAttr(kUseLocking); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| float ApplyMomentum::get_gradient_scale() const { | |||
| auto value_ptr = GetAttr(kGradientScale); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>(); | |||
| MS_EXCEPTION_IF_NULL(momentum_prim); | |||
| auto prim_name = momentum_prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | |||
| // Infer shape | |||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name); | |||
| // Infer type | |||
| auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | |||
| CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name); | |||
| const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32), | |||
| TypeIdToType(kNumberTypeFloat64)}; | |||
| std::map<std::string, TypePtr> args; | |||
| args.insert({"l_type", l_type}); | |||
| args.insert({"g_type", g_type}); | |||
| args.insert({"m_type", m_type}); | |||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name); | |||
| return std::make_shared<abstract::AbstractTensor>(g_type, v_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer); | |||
| REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,13 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_ | |||
| #define MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_ | |||
| #include "c_ops/primitive_c.h" | |||
| #ifndef MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_ | |||
| #define MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameApplyMomentum = "ApplyMomentum"; | |||
| class ApplyMomentum : public PrimitiveC { | |||
| public: | |||
| @@ -29,14 +33,18 @@ class ApplyMomentum : public PrimitiveC { | |||
| } | |||
| ~ApplyMomentum() = default; | |||
| MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC); | |||
| void Init(bool use_nesterov, bool use_locking, float gradient_scale); | |||
| void set_use_nesterov(bool use_nesterov); | |||
| void set_use_locking(bool use_locking); | |||
| void set_gradient_scale(float gradient_scale); | |||
| void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0); | |||
| void set_use_nesterov(const bool use_nesterov); | |||
| void set_use_locking(const bool use_locking); | |||
| void set_gradient_scale(const float gradient_scale); | |||
| bool get_use_nesterov() const; | |||
| bool get_use_locking() const; | |||
| float get_gradient_scale(); | |||
| float get_gradient_scale() const; | |||
| }; | |||
| AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_ | |||
| #endif // MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_ | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/arg_max.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| auto prim = primitive->cast<PrimArgMaxPtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto axis = prim->get_axis(); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto x_rank = SizeToLong(x_shape.size()); | |||
| CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | |||
| axis = axis < 0 ? axis + x_rank : axis; | |||
| std::vector<int64_t> out_shape; | |||
| for (size_t i = 0; i < x_shape.size(); ++i) { | |||
| if (SizeToLong(i) != axis) { | |||
| out_shape.emplace_back(x_shape[i]); | |||
| } | |||
| } | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| return kInt32; | |||
| } | |||
| } // namespace | |||
| void ArgMax::Init(const int64_t axis, const TypeId output_type) { | |||
| set_axis(axis); | |||
| set_output_type(output_type); | |||
| } | |||
| void ArgMax::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | |||
| void ArgMax::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } | |||
| int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); } | |||
| TypeId ArgMax::get_output_type() const { | |||
| auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element(); | |||
| return type_ptr->type_id(); | |||
| } | |||
| AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ArgMax, prim::kPrimArgMax, ArgMaxInfer); | |||
| REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_ARG_MAX_H_ | |||
| #define MINDSPORE_CORE_OPS_ARG_MAX_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameArgMax = "Argmax"; | |||
| class ArgMax : public PrimitiveC { | |||
| public: | |||
| ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); } | |||
| explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); } | |||
| ~ArgMax() = default; | |||
| MS_DECLARE_PARENT(ArgMax, PrimitiveC); | |||
| void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32); | |||
| void set_axis(const int64_t axis); | |||
| void set_output_type(const TypeId output_type); | |||
| int64_t get_axis() const; | |||
| TypeId get_output_type() const; | |||
| }; | |||
| AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimArgMaxPtr = std::shared_ptr<ArgMax>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_ARG_MAX_H_ | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include "ops/arg_min.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void ArgMin::Init(const int64_t axis, const TypeId output_type) { | |||
| set_axis(axis); | |||
| set_output_type(output_type); | |||
| } | |||
| void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | |||
| void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } | |||
| int64_t ArgMin::get_axis() const { | |||
| auto value_ptr = GetAttr(kAxis); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| TypeId ArgMin::get_output_type() const { | |||
| auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element(); | |||
| return type_ptr->type_id(); | |||
| } | |||
| AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto argmin_prim = primitive->cast<PrimArgMin>(); | |||
| MS_EXCEPTION_IF_NULL(argmin_prim); | |||
| auto prim_name = argmin_prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name); | |||
| // Infer shape | |||
| auto axis = argmin_prim->get_axis(); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto x_rank = SizeToLong(x_shape.size()); | |||
| CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); | |||
| if (axis < 0) { | |||
| axis += x_rank; | |||
| } | |||
| std::vector<int64_t> out_shape; | |||
| for (int64_t i = 0; i < x_rank; i++) { | |||
| if (i != axis) { | |||
| out_shape.push_back(x_shape[i]); | |||
| } | |||
| } | |||
| // Infer type | |||
| auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)}; | |||
| CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name); | |||
| return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ArgMin, prim::kPrimArgMin, ArgMinInfer); | |||
| REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,32 +14,37 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ARGMIN_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ARGMIN_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_ARG_MIN_H_ | |||
| #define MINDSPORE_CORE_OPS_ARG_MIN_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameArgMin = "ArgMin"; | |||
| class ArgMin : public PrimitiveC { | |||
| public: | |||
| ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); } | |||
| explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); } | |||
| ~ArgMin() = default; | |||
| MS_DECLARE_PARENT(ArgMin, PrimitiveC); | |||
| void Init(bool keep_dims, int64_t axis = -1); | |||
| void set_axis(int64_t axis); | |||
| void set_keep_dims(bool keep_dims); | |||
| int64_t get_axis(); | |||
| bool get_keep_dims(); | |||
| void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32); | |||
| void set_axis(const int64_t axis); | |||
| void set_output_type(const TypeId output_type); | |||
| int64_t get_axis() const; | |||
| TypeId get_output_type() const; | |||
| }; | |||
| AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimArgMin = std::shared_ptr<ArgMin>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ARGMIN_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ARG_MIN_H_ | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/asin.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto asin_prim = primitive->cast<PrimAsinPtr>(); | |||
| MS_EXCEPTION_IF_NULL(asin_prim); | |||
| auto prim_name = asin_prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); | |||
| // Infer Shape | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | |||
| // Infer Type | |||
| auto dtype = input_args[0]->BuildType(); | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32}; | |||
| CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name); | |||
| auto tensor_type = dtype->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto element = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id())); | |||
| return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Asin, prim::kPrimAsin, AsinInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAsin, Asin); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,13 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ASIN_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ASIN_H_ | |||
| #include "c_ops/primitive_c.h" | |||
| #ifndef MINDSPORE_CORE_OPS_ASIN_H_ | |||
| #define MINDSPORE_CORE_OPS_ASIN_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAsin = "Asin"; | |||
| class Asin : public PrimitiveC { | |||
| public: | |||
| @@ -29,6 +33,10 @@ class Asin : public PrimitiveC { | |||
| MS_DECLARE_PARENT(Asin, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr ASinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAsinPtr = std::shared_ptr<Asin>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ASIN_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ASIN_H_ | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <map> | |||
| #include <string> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/assert.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void Assert::Init(const int64_t summarize) { set_summarize(summarize); } | |||
| void Assert::set_summarize(const int64_t summarize) { this->AddAttr(kSummarize, MakeValue(summarize)); } | |||
| int64_t Assert::get_summarize() const { | |||
| auto value_ptr = GetAttr(kSummarize); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto Assert_prim = primitive->cast<PrimAssertPtr>(); | |||
| MS_EXCEPTION_IF_NULL(Assert_prim); | |||
| auto op_name = Assert_prim->name(); | |||
| TypePtr condition; | |||
| if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) { | |||
| auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue()); | |||
| CheckAndConvertUtils::CheckInteger("condition's rank", condition_value.size(), kLessEqual, 1, op_name); | |||
| if (condition_value.size() == 1) { | |||
| CheckAndConvertUtils::CheckInteger("condition[0]", condition_value[0], kEqual, 1, op_name); | |||
| } | |||
| condition = TypeIdToType(kNumberTypeBool); | |||
| } else { | |||
| auto condition_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); | |||
| CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name); | |||
| if (condition_shape[0] == 1) { | |||
| auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c()); | |||
| MS_EXCEPTION_IF_NULL(condition_value); | |||
| // auto condition_value = GetValue<bool>(input_args[0]->BuildValue()); | |||
| CheckAndConvertUtils::CheckInteger("condition[0]", *condition_value, kEqual, 1, op_name); | |||
| } | |||
| condition = input_args[0]->BuildType(); | |||
| } | |||
| std::vector<int64_t> output_shape = {1}; | |||
| std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)}; | |||
| std::map<std::string, TypePtr> args = {{"condition", condition}}; | |||
| CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name); | |||
| auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements(); | |||
| for (auto dtype : inputs_type) { | |||
| std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)}; | |||
| CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name); | |||
| } | |||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Assert, prim::kPrimAssert, AssertInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAssert, Assert); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_ASSERT_H_ | |||
| #define MINDSPORE_CORE_OPS_ASSERT_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAssert = "Assert"; | |||
| class Assert : public PrimitiveC { | |||
| public: | |||
| Assert() : PrimitiveC(kNameAssert) {} | |||
| ~Assert() = default; | |||
| MS_DECLARE_PARENT(Assert, PrimitiveC); | |||
| void Init(const int64_t summarize = 3); | |||
| void set_summarize(const int64_t summarize); | |||
| int64_t get_summarize() const; | |||
| }; | |||
| AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAssertPtr = std::shared_ptr<Assert>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_ASSERT_H_ | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ops/assign.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ir/dtype/ref.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| REGISTER_PRIMITIVE_C(kNameAssign, Assign); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,13 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ASSIGN_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ASSIGN_H_ | |||
| #include "c_ops/primitive_c.h" | |||
| #ifndef MINDSPORE_CORE_OPS_ASSIGN_H_ | |||
| #define MINDSPORE_CORE_OPS_ASSIGN_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAssign = "Assign"; | |||
| class Assign : public PrimitiveC { | |||
| public: | |||
| @@ -29,6 +33,9 @@ class Assign : public PrimitiveC { | |||
| MS_DECLARE_PARENT(Assign, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| using PrimAssignPtr = std::shared_ptr<Assign>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ASSIGN_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ASSIGN_H_ | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <map> | |||
| #include <string> | |||
| #include "ops/assign_add.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto assignadd_prim = primitive->cast<PrimAssignAddPtr>(); | |||
| MS_EXCEPTION_IF_NULL(assignadd_prim); | |||
| auto prim_name = assignadd_prim->name(); | |||
| auto value_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); | |||
| return std::make_shared<abstract::Shape>(value_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| types.emplace("w", input_args[1]->BuildType()); | |||
| // check_scalar_or_tensor_types_same | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd"); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAssignAdd, AssignAdd); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,13 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ASSIGNADD_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ASSIGNADD_H_ | |||
| #include "c_ops/primitive_c.h" | |||
| #ifndef MINDSPORE_CORE_OPS_ASSIGN_ADD_H_ | |||
| #define MINDSPORE_CORE_OPS_ASSIGN_ADD_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAssignAdd = "AssignAdd"; | |||
| class AssignAdd : public PrimitiveC { | |||
| public: | |||
| @@ -29,6 +33,10 @@ class AssignAdd : public PrimitiveC { | |||
| MS_DECLARE_PARENT(AssignAdd, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAssignAddPtr = std::shared_ptr<AssignAdd>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ASSIGNADD_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ASSIGN_ADD_H_ | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include "ops/atan.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto atan_prim = primitive->cast<PrimAtanPtr>(); | |||
| MS_EXCEPTION_IF_NULL(atan_prim); | |||
| auto prim_name = atan_prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | |||
| // Infer Shape | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto infer_shape = std::make_shared<abstract::Shape>(x_shape); | |||
| // Infer Type | |||
| auto dtype = input_args[0]->BuildType(); | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32}; | |||
| CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name); | |||
| auto tensor_type = dtype->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto element = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id())); | |||
| return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Atan, prim::kPrimAtan, AtanInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAtan, Atan); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,17 +14,18 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ATAN_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ATAN_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_ATAN_H_ | |||
| #define MINDSPORE_CORE_OPS_ATAN_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAtan = "Atan"; | |||
| class Atan : public PrimitiveC { | |||
| public: | |||
| @@ -33,6 +34,10 @@ class Atan : public PrimitiveC { | |||
| MS_DECLARE_PARENT(Atan, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr ATanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAtanPtr = std::shared_ptr<Atan>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ATAN_H_ | |||
| #endif // MINDSPORE_CORE_OPS_ATAN_H_ | |||
| @@ -0,0 +1,125 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/audio_spectrogram.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto audio_spectrogram_prim = primitive->cast<PrimAudioSpectrogramPtr>(); | |||
| MS_EXCEPTION_IF_NULL(audio_spectrogram_prim); | |||
| auto prim_name = audio_spectrogram_prim->name(); | |||
| auto input_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); | |||
| if (input_shape.size() != 2) { | |||
| MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | |||
| } | |||
| if (audio_spectrogram_prim->get_window_size() < 2) { | |||
| MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size(); | |||
| } | |||
| if (audio_spectrogram_prim->get_stride() < 1) { | |||
| MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride(); | |||
| } | |||
| std::vector<int64_t> infer_shape; | |||
| infer_shape.push_back(input_shape[1]); | |||
| int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size(); | |||
| infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride()); | |||
| int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size()); | |||
| infer_shape.push_back(fft_length / 2 + 1); | |||
| MS_LOG(ERROR) << infer_shape; | |||
| return std::make_shared<abstract::Shape>(infer_shape); | |||
| } | |||
| TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto infer_type = input_args[0]->BuildType(); | |||
| auto tensor_type = infer_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto data_type = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(data_type); | |||
| return data_type; | |||
| } | |||
| } // namespace | |||
| void AudioSpectrogram::set_window_size(const int64_t window_size) { | |||
| this->AddAttr(kWindowSize, MakeValue(window_size)); | |||
| } | |||
| int64_t AudioSpectrogram::get_window_size() const { | |||
| auto value_ptr = GetAttr(kWindowSize); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void AudioSpectrogram::set_stride(const int64_t stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| int64_t AudioSpectrogram::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t AudioSpectrogram::Log2Ceil(int64_t length) { | |||
| if (length == 0) { | |||
| return -1; | |||
| } | |||
| int64_t floor = 0; | |||
| for (int64_t i = 4; i >= 0; --i) { | |||
| const int64_t shift = (int64_t)(1 << i); | |||
| int64_t tmp = length >> shift; | |||
| if (tmp != 0) { | |||
| length = tmp; | |||
| floor += shift; | |||
| } | |||
| } | |||
| return length == (length & ~(length - 1)) ? floor : floor + 1; | |||
| } | |||
| int64_t AudioSpectrogram::GetFftLength(int64_t length) { | |||
| int64_t shift = Log2Ceil(length); | |||
| return 1 << shift; | |||
| } | |||
| void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); } | |||
| bool AudioSpectrogram::get_mag_square() const { | |||
| auto value_ptr = GetAttr(kMagSquare); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, const bool mag_square) { | |||
| this->set_window_size(window_size); | |||
| this->set_stride(stride); | |||
| this->set_mag_square(mag_square); | |||
| } | |||
| AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args), | |||
| AudioSpectrogramInferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AudioSpectrogram, prim::kPrimAudioSpectrogram, AudioSpectrogramInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,31 +14,38 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_ | |||
| #define MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_ | |||
| #define MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAudioSpectrogram = "AudioSpectrogram"; | |||
| class AudioSpectrogram : public PrimitiveC { | |||
| public: | |||
| AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {} | |||
| ~AudioSpectrogram() = default; | |||
| MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC); | |||
| void Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square); | |||
| void set_window_size(const int64_t &window_size); | |||
| void set_stride(const int64_t &stride); | |||
| void set_mag_square(const bool &mag_square); | |||
| void Init(const int64_t window_size, const int64_t stride, const bool mag_square); | |||
| void set_window_size(const int64_t window_size); | |||
| void set_stride(const int64_t stride); | |||
| void set_mag_square(const bool mag_square); | |||
| int64_t get_window_size() const; | |||
| int64_t get_stride() const; | |||
| bool get_mag_square() const; | |||
| int64_t Log2Ceil(int64_t length); | |||
| int64_t GetFftLength(int64_t length); | |||
| }; | |||
| AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAudioSpectrogramPtr = std::shared_ptr<AudioSpectrogram>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_ | |||
| #endif // MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_ | |||
| @@ -14,25 +14,26 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/avg_pool.h" | |||
| #include "ops/avg_pool.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void AvgPool::set_padding(const std::string &padding) { | |||
| CheckAndConvertUtils::CheckString(kPadding, padding, {kValid, kSame}, this->name()); | |||
| this->AddAttr(kPadding, MakeValue(padding)); | |||
| namespace ops { | |||
| void AvgPool::set_pad_mode(const PadMode &pad_mode) { | |||
| int64_t swi = pad_mode; | |||
| this->AddAttr(kPadMode, MakeValue(swi)); | |||
| } | |||
| std::string AvgPool::get_padding() const { | |||
| auto value_ptr = GetAttr(kPadding); | |||
| return GetValue<std::string>(value_ptr); | |||
| PadMode AvgPool::get_pad_mode() const { | |||
| auto value_ptr = GetAttr(kPadMode); | |||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||
| this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), | |||
| @@ -44,12 +45,12 @@ std::vector<int64_t> AvgPool::get_kernel_size() const { | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void AvgPool::set_strides(const std::vector<int64_t> &strides) { | |||
| this->AddAttr(kStride, | |||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, strides, this->name(), false, true))); | |||
| this->AddAttr(kStrides, | |||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true))); | |||
| } | |||
| std::vector<int64_t> AvgPool::get_strides() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| auto value_ptr = GetAttr(kStrides); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| @@ -70,20 +71,19 @@ std::vector<int64_t> AvgPool::get_pad() const { | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void AvgPool::set_round_mode(const int64_t &round_mode) { | |||
| CheckAndConvertUtils::CheckInRange(kRoundMode, round_mode, kIncludeBoth, {0, 1}, this->name()); | |||
| this->AddAttr(kRoundMode, MakeValue(round_mode)); | |||
| void AvgPool::set_round_mode(const RoundMode &round_mode) { | |||
| int64_t swi = round_mode; | |||
| this->AddAttr(kRoundMode, MakeValue(swi)); | |||
| } | |||
| int64_t AvgPool::get_round_mode() const { | |||
| RoundMode AvgPool::get_round_mode() const { | |||
| auto value_ptr = GetAttr(kRoundMode); | |||
| return GetValue<int64_t>(value_ptr); | |||
| return RoundMode(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, | |||
| const std::string &padding, const Format &format, const std::vector<int64_t> &pad, | |||
| const int64_t &round_mode) { | |||
| this->set_padding(padding); | |||
| void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode, | |||
| const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) { | |||
| this->set_pad_mode(pad_mode); | |||
| this->set_kernel_size(kernel_size); | |||
| this->set_strides(stride); | |||
| this->set_format(format); | |||
| @@ -98,9 +98,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| MS_EXCEPTION_IF_NULL(pool_prim); | |||
| auto op_name = pool_prim->name(); | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||
| if (pool_prim->get_format() == NHWC) { | |||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | |||
| auto kernel_size = pool_prim->get_kernel_size(); | |||
| auto pad_mode = pool_prim->get_padding(); | |||
| auto pad_mode = pool_prim->get_pad_mode(); | |||
| auto batch = in_shape[0]; | |||
| auto channel = in_shape[1]; | |||
| auto in_h = in_shape[2]; | |||
| @@ -113,14 +116,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| auto stride_w = strides[3]; | |||
| int64_t out_h = -1; | |||
| int64_t out_w = -1; | |||
| if (pad_mode == "valid") { | |||
| if (pad_mode == VALID) { | |||
| out_h = ceil((in_h - (kernel_h - 1)) / stride_h); | |||
| out_w = ceil((in_w - (kernel_w - 1)) / stride_w); | |||
| } else if (pad_mode == "same") { | |||
| } else if (pad_mode == SAME) { | |||
| out_h = ceil(in_h / stride_h); | |||
| out_w = ceil(in_w / stride_w); | |||
| } | |||
| std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; | |||
| if (pool_prim->get_format() == NHWC) { | |||
| out_shape = {batch, out_h, out_w, channel}; | |||
| } | |||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | |||
| MS_LOG(EXCEPTION) << "Kernel size is not valid."; | |||
| } | |||
| @@ -142,4 +148,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,45 +14,48 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | |||
| #define MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_AVG_POOL_H_ | |||
| #define MINDSPORE_CORE_OPS_AVG_POOL_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAvgPool = "AvgPool"; | |||
| class AvgPool : public PrimitiveC { | |||
| public: | |||
| AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); } | |||
| explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); } | |||
| ~AvgPool() = default; | |||
| MS_DECLARE_PARENT(AvgPool, PrimitiveC); | |||
| void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1}, | |||
| const std::string &padding = "valid", const Format &format = NCHW, | |||
| const std::vector<int64_t> &pad = {0, 0, 0, 0}, const int64_t &round_mode = 0); | |||
| void set_padding(const std::string &padding); | |||
| const PadMode &pad_mode = VALID, const Format &format = NCHW, | |||
| const std::vector<int64_t> &pad = {0, 0, 0, 0}, const RoundMode &round_mode = FLOOR); | |||
| void set_pad_mode(const PadMode &pad_mode); | |||
| void set_kernel_size(const std::vector<int64_t> &kernel_size); | |||
| void set_strides(const std::vector<int64_t> &strides); | |||
| void set_format(const Format &format); | |||
| void set_pad(const std::vector<int64_t> &pad); | |||
| void set_round_mode(const int64_t &round_mode); | |||
| void set_round_mode(const RoundMode &round_mode); | |||
| std::vector<int64_t> get_kernel_size() const; | |||
| std::vector<int64_t> get_strides() const; | |||
| std::string get_padding() const; | |||
| PadMode get_pad_mode() const; | |||
| Format get_format() const; | |||
| std::vector<int64_t> get_pad() const; | |||
| int64_t get_round_mode() const; | |||
| RoundMode get_round_mode() const; | |||
| }; | |||
| AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAvgPoolPtr = std::shared_ptr<AvgPool>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | |||
| #endif // MINDSPORE_CORE_OPS_AVG_POOL_H_ | |||
| @@ -0,0 +1,140 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/batch_norm.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void BatchNorm::Init(const bool is_training, const float epsilon, const float momentum, const Format &format) { | |||
| set_is_training(is_training); | |||
| set_epsilon(epsilon); | |||
| set_format(format); | |||
| set_momentum(momentum); | |||
| } | |||
| void BatchNorm::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); } | |||
| void BatchNorm::set_epsilon(const float epsilon) { | |||
| CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name()); | |||
| this->AddAttr(kEpsilon, MakeValue(epsilon)); | |||
| } | |||
| void BatchNorm::set_format(const Format &format) { | |||
| int64_t f = format; | |||
| this->AddAttr(kFormat, MakeValue(f)); | |||
| } | |||
| void BatchNorm::set_momentum(const float momentun) { | |||
| CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name()); | |||
| this->AddAttr(kMomentum, MakeValue(momentun)); | |||
| } | |||
| float BatchNorm::get_momentum() const { | |||
| auto value_ptr = GetAttr(kMomentum); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| bool BatchNorm::get_is_training() const { | |||
| auto value_ptr = GetAttr(kIsTraining); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| float BatchNorm::get_epsilon() const { | |||
| auto value_ptr = GetAttr(kEpsilon); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| Format BatchNorm::get_format() const { | |||
| auto value_ptr = GetAttr(kFormat); | |||
| return Format(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| // Infer shape | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto batch_prim = primitive->cast<PrimBatchNormPtr>(); | |||
| MS_EXCEPTION_IF_NULL(batch_prim); | |||
| auto prim_name = batch_prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); | |||
| auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); | |||
| if (batch_prim->get_format() == NHWC) { | |||
| input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; | |||
| } | |||
| auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); | |||
| auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name); | |||
| auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name); | |||
| auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); | |||
| std::vector<int64_t> input_shape_norm; | |||
| if (batch_prim->get_format() == NCHW) { | |||
| input_shape_norm = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| } else { | |||
| input_shape_norm.push_back(input_x[0]); | |||
| input_shape_norm.push_back(input_x[3]); | |||
| input_shape_norm.push_back(input_x[1]); | |||
| input_shape_norm.push_back(input_x[2]); | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("scale rank", scale.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError); | |||
| CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name, | |||
| TypeError); | |||
| if (!batch_prim->get_is_training()) { | |||
| CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError); | |||
| CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError); | |||
| } | |||
| // Infer type | |||
| auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); | |||
| std::map<std::string, TypePtr> args; | |||
| args.emplace("scale", input_args[1]->BuildType()); | |||
| args.emplace("bias", input_args[2]->BuildType()); | |||
| CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); | |||
| std::map<std::string, TypePtr> args_moving; | |||
| args_moving.emplace("scale", input_args[2]->BuildType()); | |||
| args_moving.emplace("bias", input_args[3]->BuildType()); | |||
| CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name); | |||
| auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x); | |||
| auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | |||
| auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | |||
| auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | |||
| if (batch_prim->get_format() == NHWC) { | |||
| output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale); | |||
| output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale); | |||
| output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale); | |||
| } | |||
| AbstractBasePtrList output = {output0, output1, output2, output3, output3}; | |||
| return std::make_shared<abstract::AbstractTuple>(output); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,17 +14,18 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_BATCH_NORMAL_H_ | |||
| #define MINDSPORE_CORE_C_OPS_BATCH_NORMAL_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_BATCH_NORMAL_H_ | |||
| #define MINDSPORE_CORE_OPS_BATCH_NORMAL_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "c_ops/op_utils.h" | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBatchNorm = "BatchNorm"; | |||
| class BatchNorm : public PrimitiveC { | |||
| public: | |||
| @@ -34,19 +35,23 @@ class BatchNorm : public PrimitiveC { | |||
| } | |||
| ~BatchNorm() = default; | |||
| MS_DECLARE_PARENT(BatchNorm, PrimitiveC); | |||
| void Init(bool is_training = false, float epsilon = 1e-5, const Format &format = NCHW); | |||
| void set_is_training(bool is_training); | |||
| void set_epsilon(float epsilon); | |||
| void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentun = 0.1, | |||
| const Format &format = NCHW); | |||
| void set_is_training(const bool is_training); | |||
| void set_epsilon(const float epsilon); | |||
| void set_format(const Format &format); | |||
| bool get_is_trainging(); | |||
| float get_epsilon(); | |||
| void set_momentum(const float momentum); | |||
| bool get_is_training() const; | |||
| float get_epsilon() const; | |||
| Format get_format() const; | |||
| float get_momentum() const; | |||
| }; | |||
| AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBatchNormPtr = std::shared_ptr<BatchNorm>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_BatchNorm_H_ | |||
| #endif // MINDSPORE_CORE_OPS_BatchNorm_H_ | |||
| @@ -0,0 +1,116 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <map> | |||
| #include <string> | |||
| #include "ops/batch_norm_fold.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void BatchNormFold::Init(const float momentum, const float epsilon, const bool is_training, const int64_t freeze_bn) { | |||
| set_momentum(momentum); | |||
| set_epsilon(epsilon); | |||
| set_is_training(is_training); | |||
| set_freeze_bn(freeze_bn); | |||
| } | |||
| void BatchNormFold::set_momentum(const float momentum) { | |||
| CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentum, kIncludeBoth, {0.0, 1.0}, this->name()); | |||
| this->AddAttr(kMomentum, MakeValue(momentum)); | |||
| } | |||
| float BatchNormFold::get_momentum() const { | |||
| auto value_ptr = GetAttr(kMomentum); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| void BatchNormFold::set_epsilon(const float epsilon) { | |||
| float match_value = 0.0; | |||
| CheckAndConvertUtils::CheckValue(kEpsilon, epsilon, kGreaterThan, match_value, this->name()); | |||
| this->AddAttr(kEpsilon, MakeValue(epsilon)); | |||
| } | |||
| float BatchNormFold::get_epsilon() const { | |||
| auto value_ptr = GetAttr(kEpsilon); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| void BatchNormFold::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); } | |||
| bool BatchNormFold::get_is_training() const { | |||
| auto value_ptr = GetAttr(kIsTraining); | |||
| return GetValue<bool>(value_ptr); | |||
| } | |||
| void BatchNormFold::set_freeze_bn(const int64_t freeze_bn) { this->AddAttr(kFreezeBn, MakeValue(freeze_bn)); } | |||
| int64_t BatchNormFold::get_freeze_bn() const { | |||
| auto value_ptr = GetAttr(kFreezeBn); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto BatchNormFold_prim = primitive->cast<PrimBatchNormFoldPtr>(); | |||
| MS_EXCEPTION_IF_NULL(BatchNormFold_prim); | |||
| auto op_name = BatchNormFold_prim->name(); | |||
| auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); | |||
| auto variance_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||
| auto global_step_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name); | |||
| CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name); | |||
| CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name); | |||
| CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name); | |||
| auto mean_type = input_args[1]->BuildType(); | |||
| auto variance_type = input_args[2]->BuildType(); | |||
| auto x_type = input_args[0]->BuildType(); | |||
| auto global_step_type = input_args[3]->BuildType(); | |||
| std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}}; | |||
| CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name); | |||
| auto tensor_type0 = x_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type0); | |||
| auto element0 = tensor_type0->element(); | |||
| auto tensor_type1 = mean_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type1); | |||
| auto element1 = tensor_type1->element(); | |||
| auto tensor_type2 = variance_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type2); | |||
| auto element2 = tensor_type2->element(); | |||
| CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name); | |||
| CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name); | |||
| auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape); | |||
| AbstractBasePtrList output1 = {output, output, output, output}; | |||
| return std::make_shared<abstract::AbstractTuple>(output1); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchNormFold, prim::kPrimBatchNormFold, BatchNormFoldInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_ | |||
| #define MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBatchNormFold = "BatchNormFold"; | |||
| class BatchNormFold : public PrimitiveC { | |||
| public: | |||
| BatchNormFold() : PrimitiveC(kNameBatchNormFold) { | |||
| InitIOName({"x", "mean", "variance", "global_step"}, {"batch_mean", "batch_std", "running_mean", "running_std"}); | |||
| } | |||
| ~BatchNormFold() = default; | |||
| MS_DECLARE_PARENT(BatchNormFold, PrimitiveC); | |||
| void Init(const float momentum = 0.9, const float epsilon = 1e-5, const bool is_training = true, | |||
| const int64_t freeze_bn = 0); | |||
| void set_momentum(const float momentum); | |||
| void set_epsilon(const float epsilon); | |||
| void set_is_training(const bool is_training); | |||
| void set_freeze_bn(const int64_t freeze_bn); | |||
| float get_momentum() const; | |||
| float get_epsilon() const; | |||
| bool get_is_training() const; | |||
| int64_t get_freeze_bn() const; | |||
| }; | |||
| AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBatchNormFoldPtr = std::shared_ptr<BatchNormFold>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_ | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/batch_to_space.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) { | |||
| this->set_block_size(block_size); | |||
| this->set_crops(crops); | |||
| } | |||
| void BatchToSpace::set_block_size(const std::vector<int64_t> &block_size) { | |||
| this->AddAttr(kBlockSize, MakeValue(block_size)); | |||
| } | |||
| std::vector<int64_t> BatchToSpace::get_block_size() const { | |||
| auto value_ptr = this->GetAttr(kBlockSize); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void BatchToSpace::set_crops(const std::vector<std::vector<int64_t>> &crops) { | |||
| this->AddAttr(kCrops, MakeValue(crops)); | |||
| } | |||
| std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const { | |||
| auto value_ptr = this->GetAttr(kCrops); | |||
| return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); | |||
| } | |||
| AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim = primitive->cast<PrimBatchToSpacePtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto prim_name = prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| auto block_size = prim->get_block_size(); | |||
| auto crops = prim->get_crops(); | |||
| auto out_shape = x_shape; | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| auto x_block_prod = out_shape[i + 2] * block_size[i]; | |||
| auto crops_sum = crops[i][0] + crops[i][1]; | |||
| CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name); | |||
| out_shape[i + 2] = x_block_prod - crops_sum; | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])", | |||
| out_shape[0] % (block_size[0] * block_size[1]), kEqual, 0, prim_name); | |||
| out_shape[0] /= block_size[0] * block_size[1]; | |||
| auto ret = input_args[0]->Broaden(); | |||
| ret->set_shape(std::make_shared<abstract::Shape>(out_shape)); | |||
| return ret; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_ | |||
| #define MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBatchToSpace = "BatchToSpace"; | |||
| class BatchToSpace : public PrimitiveC { | |||
| public: | |||
| BatchToSpace() : PrimitiveC(kNameBatchToSpace) {} | |||
| ~BatchToSpace() = default; | |||
| MS_DECLARE_PARENT(BatchToSpace, PrimitiveC); | |||
| void Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops); | |||
| void set_block_size(const std::vector<int64_t> &block_size); | |||
| void set_crops(const std::vector<std::vector<int64_t>> &crops); | |||
| std::vector<int64_t> get_block_size() const; | |||
| std::vector<std::vector<int64_t>> get_crops() const; | |||
| }; | |||
| AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBatchToSpacePtr = std::shared_ptr<BatchToSpace>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_ | |||
| @@ -0,0 +1,109 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/batch_to_space_nd.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto batch_prim = primitive->cast<PrimBatchToSpaceNDPtr>(); | |||
| MS_EXCEPTION_IF_NULL(batch_prim); | |||
| auto prim_name = batch_prim->name(); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| auto out_shape = x_shape; | |||
| int64_t block_shape_prod = 1; | |||
| int64_t offset = 2; | |||
| auto block_shape = batch_prim->get_block_shape(); | |||
| auto crops = batch_prim->get_crops(); | |||
| int64_t size = block_shape.size(); | |||
| for (int64_t i = 0; i < size; i++) { | |||
| block_shape_prod = block_shape_prod * block_shape[i]; | |||
| auto x_block_prod = out_shape[i + offset] * block_shape[i]; | |||
| auto crops_sum = crops[i][0] + crops[i][1]; | |||
| CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", crops_sum, prim_name); | |||
| out_shape[i + offset] = x_block_prod - crops_sum; | |||
| } | |||
| if (out_shape[0] % block_shape_prod != 0) { | |||
| MS_EXCEPTION(ValueError) << prim_name << " input_x dimension 0 " << out_shape[0] | |||
| << " should be divisible by block_shape_prod " << block_shape_prod; | |||
| } | |||
| out_shape[0] = int64_t(floor(out_shape[0] / block_shape_prod)); | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto infer_type = input_args[0]->BuildType(); | |||
| return infer_type; | |||
| } | |||
| } // namespace | |||
| void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) { | |||
| CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name()); | |||
| int64_t h = crops.size(); | |||
| int64_t w = crops[0].size(); | |||
| std::vector<int64_t> temp_w = {2, 2}; | |||
| CheckAndConvertUtils::Check(kCrops, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name()); | |||
| for (int64_t i = 0; i < h; i++) { | |||
| for (int64_t j = 0; j < w; j++) { | |||
| CheckAndConvertUtils::CheckInteger(kCrops, crops[i][j], kGreaterEqual, 0, this->name()); | |||
| } | |||
| } | |||
| this->AddAttr(kCrops, MakeValue(crops)); | |||
| } | |||
| std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const { | |||
| auto value_ptr = GetAttr(kCrops); | |||
| return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); | |||
| } | |||
| void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) { | |||
| CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name()); | |||
| for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) { | |||
| CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name()); | |||
| } | |||
| this->AddAttr(kBlockShape, MakeValue(block_shape)); | |||
| } | |||
| std::vector<int64_t> BatchToSpaceND::get_block_shape() const { | |||
| auto value_ptr = GetAttr(kBlockShape); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops) { | |||
| this->set_crops(crops); | |||
| this->set_block_shape(block_shape); | |||
| } | |||
| AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpaceND, prim::kPrimBatchToSpaceND, BatchToSpaceNDInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_ | |||
| #define MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBatchToSpaceND = "BatchToSpaceND"; | |||
| class BatchToSpaceND : public PrimitiveC { | |||
| public: | |||
| BatchToSpaceND() : PrimitiveC(kNameBatchToSpaceND) {} | |||
| ~BatchToSpaceND() = default; | |||
| MS_DECLARE_PARENT(BatchToSpaceND, PrimitiveC); | |||
| void Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops); | |||
| void set_crops(std::vector<std::vector<int64_t>> crops); | |||
| void set_block_shape(std::vector<int64_t> block_shape); | |||
| std::vector<int64_t> get_block_shape() const; | |||
| std::vector<std::vector<int64_t>> get_crops() const; | |||
| }; | |||
| AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBatchToSpaceNDPtr = std::shared_ptr<BatchToSpaceND>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_ | |||
| @@ -14,12 +14,15 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/bias_add.h" | |||
| #include "ops/bias_add.h" | |||
| #include <memory> | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| // Add | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void BiasAdd::set_format(const Format &format) { | |||
| int64_t f = format; | |||
| this->AddAttr(kFormat, MakeValue(f)); | |||
| @@ -29,5 +32,7 @@ Format BiasAdd::get_format() const { | |||
| return Format(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| void BiasAdd::Init(const Format &format) { this->set_format(format); } | |||
| // Add | |||
| REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,17 +14,20 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_BIASADD_H_ | |||
| #define MINDSPORE_CORE_C_OPS_BIASADD_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_BIAS_ADD_H_ | |||
| #define MINDSPORE_CORE_OPS_BIAS_ADD_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| // Add | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBiasAdd = "BiasAdd"; | |||
| class BiasAdd : public PrimitiveC { | |||
| public: | |||
| @@ -35,6 +38,7 @@ class BiasAdd : public PrimitiveC { | |||
| void set_format(const Format &format); | |||
| Format get_format() const; | |||
| }; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_BIASADD_H_ | |||
| #endif // MINDSPORE_CORE_OPS_BIAS_ADD_H_ | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "ops/binary_cross_entropy.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto binary_cross_entropy_prim = primitive->cast<PrimBinaryCrossEntropyPtr>(); | |||
| MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim); | |||
| auto prim_name = binary_cross_entropy_prim->name(); | |||
| CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); | |||
| auto weight_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name); | |||
| CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); | |||
| std::vector<int64_t> infer_shape; | |||
| if (weight_shape.size() < 1) { | |||
| CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); | |||
| } | |||
| if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM && | |||
| binary_cross_entropy_prim->get_reduction() != MEAN) { | |||
| infer_shape = {x_shape.begin(), infer_shape.end()}; | |||
| } | |||
| return std::make_shared<abstract::Shape>(infer_shape); | |||
| } | |||
| TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", input_args.size(), kEqual, 3, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x_shape", input_args[0]->BuildType()); | |||
| types.emplace("y_shape", input_args[1]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| if (input_args[3]->BuildType() != nullptr) { | |||
| types.emplace("x_shape", input_args[0]->BuildType()); | |||
| types.emplace("weight_shape", input_args[2]->BuildType()); | |||
| infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| } | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| void BinaryCrossEntropy::set_reduction(const Reduction &reduction) { | |||
| int64_t swi = reduction; | |||
| this->AddAttr(kReduction, MakeValue(swi)); | |||
| } | |||
| Reduction BinaryCrossEntropy::get_reduction() const { | |||
| auto value_ptr = GetAttr(kReduction); | |||
| return Reduction(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| void BinaryCrossEntropy::Init(const Reduction &reduction) { this->set_reduction(reduction); } | |||
| AbstractBasePtr BinaryCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyInferType(primitive, input_args), | |||
| BinaryCrossEntroyInferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BinaryCrossEntropy, prim::kPrimBinaryCrossEntropy, BinaryCrossEntropyInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropy, BinaryCrossEntropy); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,25 +14,32 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||
| #define MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_ | |||
| #define MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBinaryCrossEntropy = "BinaryCrossEntropy"; | |||
| class BinaryCrossEntropy : public PrimitiveC { | |||
| public: | |||
| BinaryCrossEntropy() : PrimitiveC(kNameBinaryCrossEntropy) {} | |||
| ~BinaryCrossEntropy() = default; | |||
| MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC); | |||
| void Init(const std::string &reduction = "mean"); | |||
| void set_reduction(const std::string &reduction); | |||
| std::string get_reduction() const; | |||
| void Init(const Reduction &reduction = MEAN); | |||
| void set_reduction(const Reduction &reduction); | |||
| Reduction get_reduction() const; | |||
| }; | |||
| AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBinaryCrossEntropyPtr = std::shared_ptr<BinaryCrossEntropy>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_ | |||
| #endif // MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_ | |||
| @@ -14,13 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/black_box.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/black_box.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void BlackBox::Init(const std::string &id, int64_t size, const std::vector<int64_t> &address) { | |||
| namespace ops { | |||
| void BlackBox::Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address) { | |||
| this->set_id(id); | |||
| this->set_size(size); | |||
| this->set_address(address); | |||
| @@ -33,7 +34,7 @@ std::string BlackBox::get_id() const { | |||
| return GetValue<std::string>(value_ptr); | |||
| } | |||
| void BlackBox::set_size(int64_t size) { this->AddAttr(kSize, MakeValue(size)); } | |||
| void BlackBox::set_size(const int64_t size) { this->AddAttr(kSize, MakeValue(size)); } | |||
| int64_t BlackBox::get_size() const { | |||
| auto value_ptr = this->GetAttr(kSize); | |||
| @@ -47,4 +48,5 @@ std::vector<int64_t> BlackBox::get_address() const { | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameBlackBox, BlackBox); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,25 +14,26 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_BLACKBOX_H_ | |||
| #define MINDSPORE_CORE_C_OPS_BLACKBOX_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_BLACK_BOX_H_ | |||
| #define MINDSPORE_CORE_OPS_BLACK_BOX_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBlackBox = "BlackBox"; | |||
| class BlackBox : public PrimitiveC { | |||
| public: | |||
| BlackBox() : PrimitiveC(kNameBlackBox) {} | |||
| ~BlackBox() = default; | |||
| MS_DECLARE_PARENT(BlackBox, PrimitiveC); | |||
| void Init(const std::string &id, int64_t size, const std::vector<int64_t> &address); | |||
| void Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address); | |||
| void set_id(const std::string &id); | |||
| void set_size(int64_t size); | |||
| void set_size(const int64_t size); | |||
| void set_address(const std::vector<int64_t> &address); | |||
| std::string get_id() const; | |||
| int64_t get_size() const; | |||
| @@ -40,6 +41,7 @@ class BlackBox : public PrimitiveC { | |||
| }; | |||
| using PrimBlackBoxPtr = std::shared_ptr<BlackBox>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_BLACKBOX_H_ | |||
| #endif // MINDSPORE_CORE_OPS_BLACK_BOX_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/broadcast.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void Broadcast::Init(const int64_t root_rank, const std::string &group) { | |||
| this->set_root_rank(root_rank); | |||
| this->set_group(group); | |||
| } | |||
| void Broadcast::set_root_rank(const int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); } | |||
| void Broadcast::set_group(const std::string &group) { | |||
| CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name()); | |||
| this->AddAttr(kGroup, MakeValue(group)); | |||
| } | |||
| int64_t Broadcast::get_root_rank() const { | |||
| auto value_ptr = this->GetAttr(kRootRank); | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| std::string Broadcast::get_group() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<std::string>(value_ptr); | |||
| } | |||
| AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto broadcast_prim = primitive->cast<PrimBroadcast>(); | |||
| MS_EXCEPTION_IF_NULL(broadcast_prim); | |||
| auto prim_name = broadcast_prim->name(); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| // infer shape | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| // infer type | |||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| std::vector<TypePtr> output_types; | |||
| const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| for (size_t i = 0; i < input_args.size(); i++) { | |||
| auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| output_types.push_back(out_type); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("index_type", out_type, valid_types, prim_name); | |||
| } | |||
| return std::make_shared<abstract::AbstractTensor>(x_type, in_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Broadcast, prim::kPrimBroadcast, BroadcastInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,27 +14,34 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||
| #define MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_BROADCAST_H_ | |||
| #define MINDSPORE_CORE_OPS_BROADCAST_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBroadcast = "Broadcast"; | |||
| class Broadcast : public PrimitiveC { | |||
| public: | |||
| Broadcast() : PrimitiveC(kNameBroadcast) {} | |||
| ~Broadcast() = default; | |||
| MS_DECLARE_PARENT(Broadcast, PrimitiveC); | |||
| void Init(int64_t root_rank, const std::string &group = "hccl_world_group"); | |||
| void set_root_rank(int64_t root_rank); | |||
| void Init(const int64_t root_rank, const std::string &group = "hccl_world_group"); | |||
| void set_root_rank(const int64_t root_rank); | |||
| void set_group(const std::string &group); | |||
| int64_t get_root_rank(); | |||
| int64_t get_root_rank() const; | |||
| std::string get_group() const; | |||
| }; | |||
| AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBroadcast = std::shared_ptr<Broadcast>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||
| #endif // MINDSPORE_CORE_OPS_BROADCAST_H_ | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include "ops/broadcast_to.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>(); | |||
| MS_EXCEPTION_IF_NULL(broad_cast_to); | |||
| auto prim_name = broad_cast_to->name(); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto input_x = broad_cast_to->get_shape(); | |||
| int64_t outer_dim_offset = input_x.size() - x_shape.size(); | |||
| CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name); | |||
| bool flag = true; | |||
| if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) { | |||
| flag = false; | |||
| } else { | |||
| flag = true; | |||
| } | |||
| if (flag == true) { | |||
| for (int64_t i = 0; i < (int64_t)input_x.size(); i++) { | |||
| if (input_x[i] == -1) { | |||
| if (i < outer_dim_offset) { | |||
| MS_EXCEPTION(ValueError) << " -1 in init shape is in an incompatible " | |||
| "location with given input tensor, -1 index in init shape: " | |||
| << i << " but -1 can only be in index" << x_shape.size() | |||
| << "onwards for this input."; | |||
| } | |||
| input_x[i] = x_shape[i - outer_dim_offset]; | |||
| } | |||
| } | |||
| } | |||
| std::reverse(input_x.begin(), input_x.end()); | |||
| return std::make_shared<abstract::Shape>(input_x); | |||
| } | |||
| TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)}; | |||
| CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name()); | |||
| auto infer_dtype = input_args[0]->BuildType()->type_id(); | |||
| return TypeIdToType(infer_dtype); | |||
| } | |||
| } // namespace | |||
| void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); } | |||
| void BroadcastTo::set_shape(const std::vector<int64_t> &shape) { | |||
| CheckAndConvertUtils::CheckInteger(kShapeSize, shape.size(), kGreaterThan, 0, name()); | |||
| AddAttr(kShape, MakeValue(shape)); | |||
| } | |||
| std::vector<int64_t> BroadcastTo::get_shape() const { | |||
| auto value_ptr = GetAttr(kShape); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args), | |||
| BroadcastToInferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,18 +14,19 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||
| #define MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_BROADCAST_TO_H_ | |||
| #define MINDSPORE_CORE_OPS_BROADCAST_TO_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/op_utils.h" | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameBroadcastTo = "BroadcastTo"; | |||
| class BroadcastTo : public PrimitiveC { | |||
| public: | |||
| @@ -41,6 +42,7 @@ AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const Prim | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimBroadcastToPtr = std::shared_ptr<BroadcastTo>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_ | |||
| #endif // MINDSPORE_CORE_OPS_BROADCAST_TO_H_ | |||
| @@ -14,8 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/cast.h" | |||
| #include "ops/cast.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| REGISTER_PRIMITIVE_C(kNameCast, Cast); | |||
| } | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,17 +14,18 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CAST_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CAST_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_CAST_H_ | |||
| #define MINDSPORE_CORE_OPS_CAST_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameCast = "Cast"; | |||
| class Cast : public PrimitiveC { | |||
| public: | |||
| @@ -35,5 +36,6 @@ class Cast : public PrimitiveC { | |||
| AbstractBasePtr CastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimCast = std::shared_ptr<Cast>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CAST_H_ | |||
| #endif // MINDSPORE_CORE_OPS_CAST_H_ | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ops/ceil.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil"); | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| auto infer_type = input_args[0]->BuildType(); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name()); | |||
| MS_EXCEPTION_IF_NULL(infer_type); | |||
| auto tensor_type = infer_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto data_type = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(data_type); | |||
| return std::make_shared<abstract::AbstractTensor>(data_type, x_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Ceil, prim::kPrimCeil, CeilInfer); | |||
| REGISTER_PRIMITIVE_C(kNameCeil, Ceil); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,26 +14,29 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CEIL_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CEIL_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_CEIL_H_ | |||
| #define MINDSPORE_CORE_OPS_CEIL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameCeil = "Ceil"; | |||
| class Ceil : public PrimitiveC { | |||
| public: | |||
| Ceil() : PrimitiveC(kNameCeil) { InitIOName({"x"}, {"y"}); } | |||
| ~Ceil() = default; | |||
| MS_DECLARE_PARENT(Ceil, PrimitiveC); | |||
| void init() {} | |||
| }; | |||
| AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimCeil = std::shared_ptr<Ceil>; | |||
| using PrimCeilPtr = std::shared_ptr<Ceil>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CEIL_H_ | |||
| #endif // MINDSPORE_CORE_OPS_CEIL_H_ | |||
| @@ -14,12 +14,13 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/clip.h" | |||
| #include "ops/clip.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void Clip::Init(const float max, const float min) { | |||
| this->set_max(max); | |||
| this->set_min(min); | |||
| @@ -39,4 +40,5 @@ float Clip::get_min() const { | |||
| return GetValue<float>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameClip, Clip); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -13,15 +13,16 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CLIP_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CLIP_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_CLIP_H_ | |||
| #define MINDSPORE_CORE_OPS_CLIP_H_ | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameClip = "Clip"; | |||
| class Clip : public PrimitiveC { | |||
| public: | |||
| @@ -36,6 +37,7 @@ class Clip : public PrimitiveC { | |||
| }; | |||
| using PrimClipPtr = std::shared_ptr<Clip>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CLIP_H_ | |||
| #endif // MINDSPORE_CORE_OPS_CLIP_H_ | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <map> | |||
| #include <string> | |||
| #include "ops/concat.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| void Concat::Init(const int64_t axis) { this->set_axis(axis); } | |||
| int64_t Concat::get_axis() const { | |||
| auto value_ptr = this->GetAttr(kAxis); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } | |||
| AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim = primitive->cast<PrimConcatPtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto prim_name = prim->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_tuple); | |||
| auto elements = input_tuple->elements(); | |||
| CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); | |||
| auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(element0); | |||
| auto element0_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); | |||
| auto element0_rank = SizeToLong(element0_shape.size()); | |||
| auto axis = prim->get_axis(); | |||
| CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, | |||
| prim_name); | |||
| axis = axis < 0 ? axis + element0_rank : axis; | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("element0", element0->BuildType()); | |||
| int64_t all_shp = element0_shape[axis]; | |||
| for (size_t i = 1; i < elements.size(); ++i) { | |||
| std::string elementi = "element" + std::to_string(i); | |||
| auto elementi_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), | |||
| prim_name); | |||
| for (int64_t j = 0; j < element0_rank; ++j) { | |||
| if (j != axis && elementi_shape[j] != element0_shape[j]) { | |||
| MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element."; | |||
| } | |||
| } | |||
| all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis]; | |||
| types.emplace(elementi, elements[i]->BuildType()); | |||
| } | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name); | |||
| auto ret_shape = element0_shape; | |||
| ret_shape[axis] = all_shp; | |||
| return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type), | |||
| std::make_shared<abstract::Shape>(ret_shape)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer); | |||
| REGISTER_PRIMITIVE_C(kNameConcat, Concat); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,29 +14,31 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CONCAT_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CONCAT_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_CONCAT_H_ | |||
| #define MINDSPORE_CORE_OPS_CONCAT_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameConcat = "Concat"; | |||
| class Concat : public PrimitiveC { | |||
| public: | |||
| Concat() : PrimitiveC(kNameConcat) {} | |||
| ~Concat() = default; | |||
| MS_DECLARE_PARENT(Concat, PrimitiveC); | |||
| void Init(int64_t axis = 0); | |||
| void set_axis(int64_t axis); | |||
| void Init(const int64_t axis = 0); | |||
| void set_axis(const int64_t axis); | |||
| int64_t get_axis() const; | |||
| }; | |||
| AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimConcatPtr = std::shared_ptr<Concat>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CONCAT_H_ | |||
| #endif // MINDSPORE_CORE_OPS_CONCAT_H_ | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/constant.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto x = input_args[0]->BuildShape(); | |||
| auto shape_element = x->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape_element); | |||
| return shape_element; | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Constant, prim::kPrimConstant, ConstantInfer); | |||
| REGISTER_PRIMITIVE_C(kNameConstant, Constant); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_CONSTANT_H_ | |||
| #define MINDSPORE_CORE_OPS_CONSTANT_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameConstant = "Constant"; | |||
| class Constant : public PrimitiveC { | |||
| public: | |||
| Constant() : PrimitiveC(kNameConstant) {} | |||
| ~Constant() = default; | |||
| MS_DECLARE_PARENT(Constant, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimConstantPtr = std::shared_ptr<Constant>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_CONSTANT_H_ | |||
| @@ -14,12 +14,30 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/constant_of_shape.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/constant_of_shape.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape"); | |||
| auto input_shape = | |||
| CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape"); | |||
| return std::make_shared<abstract::Shape>(input_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto constant_prim = primitive->cast<PrimConstantOfShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(constant_prim); | |||
| auto data_type = TypeId(constant_prim->get_data_type()); | |||
| return TypeIdToType(data_type); | |||
| } | |||
| } // namespace | |||
| void ConstantOfShape::Init(int64_t data_type, const std::vector<float> &value) { | |||
| this->set_data_type(data_type); | |||
| this->set_value(value); | |||
| @@ -38,5 +56,12 @@ std::vector<float> ConstantOfShape::get_value() const { | |||
| auto value_ptr = this->GetAttr(kValue); | |||
| return GetValue<std::vector<float>>(value_ptr); | |||
| } | |||
| AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ConstantOfShape, prim::kPrimConstantOfShape, ConstantOfShapeInfer); | |||
| REGISTER_PRIMITIVE_C(kNameConstantOfShape, ConstantOfShape); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,15 +14,16 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_ | |||
| #define MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameConstantOfShape = "ConstantOfShape"; | |||
| class ConstantOfShape : public PrimitiveC { | |||
| public: | |||
| @@ -36,7 +37,10 @@ class ConstantOfShape : public PrimitiveC { | |||
| std::vector<float> get_value() const; | |||
| }; | |||
| AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimConstantOfShapePtr = std::shared_ptr<ConstantOfShape>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_ | |||
| #endif // MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_ | |||
| @@ -14,19 +14,21 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/control_depend.h" | |||
| #include "ops/control_depend.h" | |||
| namespace mindspore { | |||
| void ControlDepend::Init(int64_t depend_mode) { this->set_depend_mode(depend_mode); } | |||
| namespace ops { | |||
| void ControlDepend::Init(const int64_t depend_mode) { this->set_depend_mode(depend_mode); } | |||
| void ControlDepend::set_depend_mode(int64_t depend_mode) { | |||
| CheckAndConvertUtils::CheckInRange(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name()); | |||
| void ControlDepend::set_depend_mode(const int64_t depend_mode) { | |||
| CheckAndConvertUtils::CheckInRange<int64_t>(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name()); | |||
| AddAttr(kDependMode, MakeValue(depend_mode)); | |||
| } | |||
| int64_t ControlDepend::get_depend_mode() { | |||
| int64_t ControlDepend::get_depend_mode() const { | |||
| auto value_ptr = GetAttr(kDependMode); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameControlDepend, ControlDepend); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -14,29 +14,29 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CONTROL_DEPEND_H_ | |||
| #define MINDSPORE_CORE_C_OPS_CONTROL_DEPEND_H_ | |||
| #ifndef MINDSPORE_CORE_OPS_CONTROL_DEPEND_H_ | |||
| #define MINDSPORE_CORE_OPS_CONTROL_DEPEND_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "c_ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameControlDepend = "ControlDepend"; | |||
| class ControlDepend : public PrimitiveC { | |||
| public: | |||
| ControlDepend() : PrimitiveC(kNameControlDepend) {} | |||
| ~ControlDepend() = default; | |||
| MS_DECLARE_PARENT(ControlDepend, PrimitiveC); | |||
| void Init(int64_t depend_mode); | |||
| void set_depend_mode(int64_t depend_mode); | |||
| int64_t get_depend_mode(); | |||
| void Init(const int64_t depend_mode); | |||
| void set_depend_mode(const int64_t depend_mode = 0); | |||
| int64_t get_depend_mode() const; | |||
| }; | |||
| AbstractBasePtr ControlDependInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimControlDepend = std::shared_ptr<ControlDepend>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CONTROl_DEPEND_H_ | |||
| #endif // MINDSPORE_CORE_OPS_CONTROl_DEPEND_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/conv2d.h" | |||
| #include "ops/conv2d.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| @@ -23,102 +23,29 @@ | |||
| #include "ir/dtype/tensor_type.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "ops/control_depend.h" | |||
| namespace mindspore { | |||
| Conv2D::Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); } | |||
| void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, | |||
| const std::string &pad_mode, const std::vector<int64_t> &pad, const std::vector<int64_t> &stride, | |||
| const std::vector<int64_t> &dilation, int64_t group) { | |||
| auto prim_name = this->name(); | |||
| this->AddAttr("data_format", MakeValue("NCHW")); | |||
| this->AddAttr("offset_a", MakeValue(static_cast<int64_t>(0))); | |||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | |||
| this->set_stride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true)); | |||
| this->set_dilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true)); | |||
| this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name); | |||
| if (pad_mode == "pad") { | |||
| for (auto item : pad) { | |||
| CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name); | |||
| } | |||
| } else { | |||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | |||
| } | |||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||
| this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name)); | |||
| this->set_out_channel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); | |||
| this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_kernel_size() const { | |||
| auto value_ptr = GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_dilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::string Conv2D::get_pad_mode() const { | |||
| auto value_ptr = this->GetAttr(kPadMode); | |||
| return GetValue<string>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_pad() const { | |||
| auto value_ptr = this->GetAttr(kPad); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_pad_list() const { | |||
| auto value_ptr = this->GetAttr(kPadList); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| int64_t Conv2D::get_mode() const { | |||
| auto value_ptr = this->GetAttr(kMode); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t Conv2D::get_group() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t Conv2D::get_output_channel() const { | |||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | |||
| } | |||
| void Conv2D::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| void Conv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||
| void Conv2D::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||
| void Conv2D::set_mode(int64_t mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||
| void Conv2D::set_group(int64_t group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| void Conv2D::set_out_channel(int64_t output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||
| void Conv2D::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||
| auto prim_name = conv_prim->name(); | |||
| CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||
| if (conv_prim->get_format() == NHWC) { | |||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]", | |||
| w_shape[1], conv_prim->name()); | |||
| auto out_channel = conv_prim->get_output_channel(); | |||
| auto out_channel = conv_prim->get_out_channel(); | |||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | |||
| std::vector<int64_t> temp_w; | |||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
| @@ -137,10 +64,10 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| int64_t w_out = -1; | |||
| std::vector<int64_t> pad_list(4, 0); | |||
| auto pad_mode = conv_prim->get_pad_mode(); | |||
| if (pad_mode == "valid") { | |||
| if (pad_mode == VALID) { | |||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | |||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | |||
| } else if (pad_mode == "same") { | |||
| } else if (pad_mode == SAME) { | |||
| h_out = ceil(x_shape[2] / stride_h); | |||
| w_out = ceil(x_shape[3] / stride_w); | |||
| @@ -153,7 +80,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| auto pad_left = floor(pad_needed_w / 2); | |||
| pad_list.emplace_back(pad_left); | |||
| pad_list.emplace_back(pad_needed_h - pad_left); | |||
| } else if (pad_mode == "pad") { | |||
| } else if (pad_mode == PAD) { | |||
| std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | |||
| auto pad_top = conv_prim->get_pad()[0]; | |||
| auto pad_bottom = conv_prim->get_pad()[1]; | |||
| @@ -165,13 +92,17 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| h_out = floor(h_out); | |||
| w_out = floor(w_out); | |||
| } | |||
| conv_prim->set_pad_list(pad_list); | |||
| conv_prim->set_pad(pad_list); | |||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | |||
| if (conv_prim->get_format() == NHWC) { | |||
| out_shape = {x_shape[0], h_out, w_out, out_channel}; | |||
| } | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); | |||
| CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -186,12 +117,121 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase | |||
| } | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| } // namespace | |||
| void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode, | |||
| const std::vector<int64_t> &pad, const std::vector<int64_t> &stride, | |||
| const std::vector<int64_t> &dilation, int64_t group, const Format &format) { | |||
| set_kernel_size(kernel_size); | |||
| set_stride(stride); | |||
| set_dilation(dilation); | |||
| set_pad(pad); | |||
| set_pad_mode(pad_mode); | |||
| set_mode(mode); | |||
| set_out_channel(out_channel); | |||
| set_group(group); | |||
| set_format(format); | |||
| } | |||
| void Conv2D::set_out_channel(int64_t out_channel) { | |||
| AddAttr(kOutChannel, | |||
| MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name()))); | |||
| } | |||
| void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||
| AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name()))); | |||
| } | |||
| void Conv2D::set_stride(const std::vector<int64_t> &stride) { | |||
| AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); | |||
| } | |||
| void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { | |||
| AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); | |||
| } | |||
| void Conv2D::set_pad_mode(const PadMode &pad_mode) { | |||
| std::vector<int64_t> pad = get_pad(); | |||
| if (pad_mode == PAD) { | |||
| for (auto item : pad) { | |||
| CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name()); | |||
| } | |||
| } else { | |||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name()); | |||
| } | |||
| int64_t swi = pad_mode; | |||
| AddAttr(kPadMode, MakeValue(swi)); | |||
| } | |||
| void Conv2D::set_pad(const std::vector<int64_t> &pad) { | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | |||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); | |||
| } | |||
| void Conv2D::set_mode(int64_t mode) { | |||
| AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name()))); | |||
| } | |||
| void Conv2D::set_group(int64_t group) { | |||
| AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name()))); | |||
| } | |||
| void Conv2D::set_format(const Format &format) { | |||
| int64_t f = format; | |||
| AddAttr(kFormat, MakeValue(f)); | |||
| } | |||
| int64_t Conv2D::get_out_channel() const { | |||
| auto value_ptr = GetAttr(kOutChannel); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_kernel_size() const { | |||
| auto value_ptr = GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_dilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| PadMode Conv2D::get_pad_mode() const { | |||
| auto value_ptr = GetAttr(kPadMode); | |||
| return PadMode(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| std::vector<int64_t> Conv2D::get_pad() const { | |||
| auto value_ptr = GetAttr(kPad); | |||
| return GetValue<std::vector<int64_t>>(value_ptr); | |||
| } | |||
| int64_t Conv2D::get_mode() const { | |||
| auto value_ptr = GetAttr(kMode); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| int64_t Conv2D::get_group() const { | |||
| auto value_ptr = GetAttr(kGroup); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| Format Conv2D::get_format() const { | |||
| auto value_ptr = GetAttr(kFormat); | |||
| return Format(GetValue<int64_t>(value_ptr)); | |||
| } | |||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args), | |||
| Conv2dInferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer); | |||
| REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||