| @@ -278,6 +278,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||||
| MS_LOG(ERROR) << "Add fusion pass failed."; | MS_LOG(ERROR) << "Add fusion pass failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| status = AddGraphPass(optimizer, config); | status = AddGraphPass(optimizer, config); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Add graph pass failed."; | MS_LOG(ERROR) << "Add graph pass failed."; | ||||
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_deconv_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| #include "tools/converter/parser/tf/tf_util.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF DeConvParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::DeConv2DT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->group = 1; | |||||
| attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); | |||||
| std::vector<int64_t> dilations(2); | |||||
| auto status = ParseDilations(tf_op, attr->format, &dilations); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| attr->dilateH = dilations[0]; | |||||
| attr->dilateW = dilations[1]; | |||||
| std::vector<int64_t> strides(2); | |||||
| status = ParseStrides(tf_op, attr->format, &strides); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| attr->strideH = strides[0]; | |||||
| attr->strideW = strides[1]; | |||||
| auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||||
| if (weight_node != nullptr) { | |||||
| std::vector<int64_t> kernels(4); | |||||
| status = ParseKernels(*weight_node, attr->format, &kernels); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| attr->kernelH = kernels[0]; | |||||
| attr->kernelW = kernels[1]; | |||||
| attr->channelIn = kernels[2]; | |||||
| attr->channelOut = kernels[3]; | |||||
| } else { | |||||
| attr->kernelH = -1; | |||||
| attr->kernelW = -1; | |||||
| attr->channelIn = -1; | |||||
| attr->channelOut = -1; | |||||
| MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed"; | |||||
| } | |||||
| status = ParsePadMode(tf_op, &attr->padMode); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| status = AddOpInput(tf_op, 2, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); // weights | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tf_deconv_parser("Conv2DBackpropInput", new TFDeconvParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_conv_base_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFDeconvParser : public TFConvBaseParser { | |||||
| public: | |||||
| TFDeconvParser() = default; | |||||
| ~TFDeconvParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DECONV_PARSER_H_ | |||||
| @@ -221,16 +221,22 @@ void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const Pa | |||||
| delete[] tmp_weight_data; | delete[] tmp_weight_data; | ||||
| return; | return; | ||||
| } | } | ||||
| auto group = primc->GetGroup(); | |||||
| auto cin_group = weight_tensor->tensor_shape()[0] / group; | |||||
| int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3]; | |||||
| int cout_size = kernel_num * area_size; | |||||
| for (int k = 0; k < cin_group; ++k) { | |||||
| for (int i = 0; i < kernel_num; ++i) { | |||||
| auto row_addr = weight_data + k * cout_size + i * area_size; | |||||
| auto new_row_addr = tmp_weight_data + k * cout_size + i * area_size; | |||||
| for (int j = 0; j < area_size; j++) { | |||||
| new_row_addr[j] = row_addr[j] * trans_scale[i]; | |||||
| if (this->fmk_type_ == lite::converter::FmkType_TF) { | |||||
| for (int i = 0; i < weight_shape_size; i++) { | |||||
| tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num]; | |||||
| } | |||||
| } else { | |||||
| auto group = primc->GetGroup(); | |||||
| auto cin_group = weight_tensor->tensor_shape()[0] / group; | |||||
| int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3]; | |||||
| int cout_size = kernel_num * area_size; | |||||
| for (int k = 0; k < cin_group; ++k) { | |||||
| for (int i = 0; i < kernel_num; ++i) { | |||||
| auto row_addr = weight_data + k * cout_size + i * area_size; | |||||
| auto new_row_addr = tmp_weight_data + k * cout_size + i * area_size; | |||||
| for (int j = 0; j < area_size; j++) { | |||||
| new_row_addr[j] = row_addr[j] * trans_scale[i]; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -193,6 +193,8 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const AnfNodePtr &conv_node, | |||||
| param_value->set_format(schema::Format::Format_HWCK); | param_value->set_format(schema::Format::Format_HWCK); | ||||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| param_value->set_format(schema::Format::Format_HWKC); | param_value->set_format(schema::Format::Format_HWKC); | ||||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||||
| param_value->set_format(schema::Format::Format_HWCK); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | ||||
| << ", node: " << conv_node->fullname_with_scope(); | << ", node: " << conv_node->fullname_with_scope(); | ||||
| @@ -248,7 +250,7 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (status != lite::RET_OK) { | if (status != lite::RET_OK) { | ||||
| MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope(); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||