|
|
|
@@ -242,8 +242,14 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv |
|
|
|
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: two tensors(y_backprop, x). |
|
|
|
CheckArgsSize(primitive->name(), args_spec_list, 2); |
|
|
|
return args_spec_list[1]->Broaden(); |
|
|
|
const std::string op_name = primitive->name(); |
|
|
|
CheckArgsSize(op_name, args_spec_list, 2); |
|
|
|
auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); |
|
|
|
auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
(void)CheckDtypeSame(op_name, out, dout); |
|
|
|
(void)CheckShapeSame(op_name, out, dout); |
|
|
|
|
|
|
|
return out->Broaden(); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
|