From: @liubuyu Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54tags/v1.1.0
| @@ -49,6 +49,8 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| return out->Broaden(); | return out->Broaden(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // the output is useless, so we dont have to focus on the output shape | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[1]); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[3]); | |||||
| auto dx = args_spec_list[1]->Broaden(); | |||||
| auto dscale = args_spec_list[2]->Broaden(); | |||||
| auto dbias = args_spec_list[3]->Broaden(); | |||||
| AbstractBasePtrList rets = {dx, dscale, dbias}; | |||||
| return std::make_shared<AbstractTuple>(rets); | |||||
| } | |||||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: three tensors(doutput, input, filters). | // Inputs: three tensors(doutput, input, filters). | ||||
| @@ -101,6 +101,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimPooling, {InferImplPooling, true}}, | {prim::kPrimPooling, {InferImplPooling, true}}, | ||||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | ||||
| {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | ||||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | |||||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | ||||
| {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, | {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, | ||||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | ||||
| @@ -140,6 +140,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive | |||||
| inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool"); | inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool"); | ||||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | ||||
| inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm"); | inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm"); | ||||
| inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam"); | |||||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | ||||
| inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | ||||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | ||||