|
|
|
@@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr |
|
|
|
return out->Broaden(); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// the output is useless, so we dont have to focus on the output shape |
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[1]); |
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[2]); |
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[3]); |
|
|
|
|
|
|
|
auto dx = args_spec_list[1]->Broaden(); |
|
|
|
auto dscale = args_spec_list[2]->Broaden(); |
|
|
|
auto dbias = args_spec_list[3]->Broaden(); |
|
|
|
|
|
|
|
AbstractBasePtrList rets = {dx, dscale, dbias}; |
|
|
|
return std::make_shared<AbstractTuple>(rets); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: three tensors(doutput, input, filters). |
|
|
|
|