@@ -278,6 +278,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
    MS_LOG(ERROR) << "Add fusion pass failed.";
    return nullptr;
  }
  status = AddGraphPass(optimizer, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Add graph pass failed.";
@@ -0,0 +1,107 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_deconv_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/parser/tf/tf_util.h"

namespace mindspore {
namespace lite {
STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                             std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF DeConvParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC or output_size is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::DeConv2DT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }
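  // Conv2DBackpropInput carries no group attribute in TF, so the converted DeConv2D defaults to group = 1;
  // the data format (NHWC/NCHW) is taken from the NodeDef itself.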
  attr->group = 1;
  attr->format = TensorFlowUtils::ParseNodeFormat(tf_op);
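  // Dilations and strides are read from the NodeDef attrs by the shared base-parser helpers and
  // returned in (H, W) order according to the data format parsed above.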
  std::vector<int64_t> dilations(2);
  auto status = ParseDilations(tf_op, attr->format, &dilations);
  if (status != RET_OK) {
    return status;
  }
  attr->dilateH = dilations[0];
  attr->dilateW = dilations[1];

  std::vector<int64_t> strides(2);
  status = ParseStrides(tf_op, attr->format, &strides);
  if (status != RET_OK) {
    return status;
  }
  attr->strideH = strides[0];
  attr->strideW = strides[1];
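  // If the filter (input 1) resolves to a constant node, read the kernel and channel dims from it now;
  // otherwise mark them as -1 so that later passes can fill them in once the weight tensor is known.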
  auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
  if (weight_node != nullptr) {
    std::vector<int64_t> kernels(4);
    status = ParseKernels(*weight_node, attr->format, &kernels);
    if (status != RET_OK) {
      return status;
    }
    attr->kernelH = kernels[0];
    attr->kernelW = kernels[1];
    attr->channelIn = kernels[2];
    attr->channelOut = kernels[3];
  } else {
    attr->kernelH = -1;
    attr->kernelW = -1;
    attr->channelIn = -1;
    attr->channelOut = -1;
    MS_LOG(WARNING) << "parsing of kernelH/W and channelIn/Out is delayed";
  }
  status = ParsePadMode(tf_op, &attr->padMode);
  if (status != RET_OK) {
    return status;
  }

  primitive->value.type = schema::PrimitiveType_DeConv2D;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
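  // TF Conv2DBackpropInput inputs are (0) output shape, (1) filter, (2) out_backprop; the converter
  // records input 2 as the data tensor first and input 1 as the weight tensor.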
  status = AddOpInput(tf_op, 2, inputs);
  if (status != RET_OK) {
    return status;
  }
  status = AddOpInput(tf_op, 1, inputs);  // weights
  return status;
}
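// Register the parser so TF Conv2DBackpropInput nodes are dispatched to TFDeconvParser during conversion.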
TFNodeRegistrar g_tf_deconv_parser("Conv2DBackpropInput", new TFDeconvParser());
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_conv_base_parser.h"

namespace mindspore {
namespace lite {
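// Parses TensorFlow Conv2DBackpropInput nodes into MindSpore Lite DeConv2D primitives, reusing the
// dilation/stride/kernel/pad-mode helpers shared with the other conv parsers in TFConvBaseParser.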
class TFDeconvParser : public TFConvBaseParser {
 public:
  TFDeconvParser() = default;
  ~TFDeconvParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_
@@ -221,16 +221,22 @@ void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const Pa
    delete[] tmp_weight_data;
    return;
  }
  auto group = primc->GetGroup();
  auto cin_group = weight_tensor->tensor_shape()[0] / group;
  int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3];
  int cout_size = kernel_num * area_size;
  for (int k = 0; k < cin_group; ++k) {
    for (int i = 0; i < kernel_num; ++i) {
      auto row_addr = weight_data + k * cout_size + i * area_size;
      auto new_row_addr = tmp_weight_data + k * cout_size + i * area_size;
      for (int j = 0; j < area_size; j++) {
        new_row_addr[j] = row_addr[j] * trans_scale[i];
  if (this->fmk_type_ == lite::converter::FmkType_TF) {
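    // For TF-format weights the scaled (output) channel is the innermost dimension, so element i of the
    // flat buffer maps to scale index i % kernel_num.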
    for (int i = 0; i < weight_shape_size; i++) {
      tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
    }
  } else {
    auto group = primc->GetGroup();
    auto cin_group = weight_tensor->tensor_shape()[0] / group;
    int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3];
    int cout_size = kernel_num * area_size;
    for (int k = 0; k < cin_group; ++k) {
      for (int i = 0; i < kernel_num; ++i) {
        auto row_addr = weight_data + k * cout_size + i * area_size;
        auto new_row_addr = tmp_weight_data + k * cout_size + i * area_size;
        for (int j = 0; j < area_size; j++) {
          new_row_addr[j] = row_addr[j] * trans_scale[i];
        }
      }
    }
  }
@@ -193,6 +193,8 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const AnfNodePtr &conv_node,
        param_value->set_format(schema::Format::Format_HWCK);
      } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
        param_value->set_format(schema::Format::Format_HWKC);
      } else if (op_type == schema::PrimitiveType_DeConv2D) {
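        // Conv2DBackpropInput reuses the forward-convolution filter layout, so TF DeConv2D weights
        // are tagged HWCK just like Conv2D.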
        param_value->set_format(schema::Format::Format_HWCK);
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
                      << ", node: " << conv_node->fullname_with_scope();
@@ -248,7 +250,7 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
      return false;
    }
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope();
      MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
      return false;
    }
  }