|
|
@@ -530,7 +530,7 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr & |
|
|
ShapeVector x_min_shape = x->shape()->min_shape(); |
|
|
ShapeVector x_min_shape = x->shape()->min_shape(); |
|
|
ShapeVector x_max_shape = x->shape()->max_shape(); |
|
|
ShapeVector x_max_shape = x->shape()->max_shape(); |
|
|
std::set<std::string> available_data_format{"NCHW", "NHWC"}; |
|
|
std::set<std::string> available_data_format{"NCHW", "NHWC"}; |
|
|
auto data_format_ptr = primitive->GetAttr("data_format"); |
|
|
|
|
|
|
|
|
auto data_format_ptr = primitive->GetAttr("format"); |
|
|
std::string data_format = "NCHW"; |
|
|
std::string data_format = "NCHW"; |
|
|
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) { |
|
|
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) { |
|
|
data_format = data_format_ptr->cast<StringImmPtr>()->value(); |
|
|
data_format = data_format_ptr->cast<StringImmPtr>()->value(); |
|
|
|