Browse Source

!9596 add infer function for fused sparse adam

From: @liubuyu
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d88aa05859
4 changed files with 19 additions and 0 deletions
  1. +2
    -0
      mindspore/core/abstract/infer_functions.h
  2. +15
    -0
      mindspore/core/abstract/prim_nn.cc
  3. +1
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  4. +1
    -0
      mindspore/core/base/core_ops.h

+ 2
- 0
mindspore/core/abstract/infer_functions.h View File

@@ -49,6 +49,8 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 15
- 0
mindspore/core/abstract/prim_nn.cc View File

@@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr
return out->Broaden();
}

AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// the output is useless, so we dont have to focus on the output shape
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
MS_EXCEPTION_IF_NULL(args_spec_list[3]);

auto dx = args_spec_list[1]->Broaden();
auto dscale = args_spec_list[2]->Broaden();
auto dbias = args_spec_list[3]->Broaden();

AbstractBasePtrList rets = {dx, dscale, dbias};
return std::make_shared<AbstractTuple>(rets);
}

AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors(doutput, input, filters).


+ 1
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -101,6 +101,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -140,6 +140,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive
inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");


Loading…
Cancel
Save