|
|
|
@@ -75,20 +75,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit |
|
|
|
auto prim_name = primitive->name(); |
|
|
|
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); |
|
|
|
|
|
|
|
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); |
|
|
|
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; |
|
|
|
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); |
|
|
|
if (format == NHWC) { |
|
|
|
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; |
|
|
|
} |
|
|
|
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); |
|
|
|
auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name); |
|
|
|
auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name); |
|
|
|
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); |
|
|
|
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; |
|
|
|
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; |
|
|
|
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape]; |
|
|
|
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape]; |
|
|
|
|
|
|
|
std::vector<int64_t> input_shape_norm; |
|
|
|
if (format == NCHW) { |
|
|
|
input_shape_norm = |
|
|
|
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); |
|
|
|
input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; |
|
|
|
} else { |
|
|
|
input_shape_norm.push_back(input_x[0]); |
|
|
|
input_shape_norm.push_back(input_x[3]); |
|
|
|
|