Browse Source

!5995 Add check for relugrad

Merge pull request !5995 from riemann_penn/Add_check_for_relugrad
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e7ce5b0ae1
1 changed files with 8 additions and 2 deletions
  1. +8
    -2
      mindspore/core/abstract/prim_nn.cc

+ 8
- 2
mindspore/core/abstract/prim_nn.cc View File

@@ -242,8 +242,14 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors(y_backprop, x).
CheckArgsSize(primitive->name(), args_spec_list, 2);
return args_spec_list[1]->Broaden();
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
(void)CheckDtypeSame(op_name, out, dout);
(void)CheckShapeSame(op_name, out, dout);

return out->Broaden();
}

AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


Loading…
Cancel
Save