diff --git a/mindspore/lite/nnacl/clip.c b/mindspore/lite/nnacl/clip.c new file mode 100644 index 0000000000..473637ab03 --- /dev/null +++ b/mindspore/lite/nnacl/clip.c @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/clip.h" +#include "nnacl/errorcode.h" + +int ClipFp32(const float *src, int length, float *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? max_val : src[i]); + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/clip.h b/mindspore/lite/nnacl/clip.h new file mode 100644 index 0000000000..9e04a7b429 --- /dev/null +++ b/mindspore/lite/nnacl/clip.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_CLIP_H_ +#define MINDSPORE_LITE_NNACL_CLIP_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/quantization/fixed_point.h" + +typedef struct ClipParameter { + OpParameter op_parameter_; + float min_val_; + float max_val_; +} ClipParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int ClipFp32(const float *src, int length, float *dst, float min_val, float max_val); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_CLIP_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 87662121f4..efb592bad5 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -128,6 +128,7 @@ #include "src/ops/lsh_projection.h" #include "src/ops/hashtable_lookup.h" #include "src/ops/skip_gram.h" +#include "src/ops/clip.h" #include "src/ops/custom_predict.h" #include "src/ops/custom_normalize.h" #include "src/ops/custom_extract_features.h" @@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new HashtableLookup(primitive); case schema::PrimitiveType_SkipGram: return new SkipGram(primitive); + case schema::PrimitiveType_Clip: + return new Clip(primitive); case schema::PrimitiveType_CustomPredict: return new CustomPredict(primitive); case schema::PrimitiveType_CustomNormalize: @@ -964,6 +967,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { return NewPrimitiveC(primitive); case schema::PrimitiveType_SkipGram: return NewPrimitiveC(primitive); + case schema::PrimitiveType_Clip: + return NewPrimitiveC(primitive); case schema::PrimitiveType_CustomPredict: return NewPrimitiveC(primitive); case schema::PrimitiveType_CustomNormalize: diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 59983c1974..2eae2c6c0c 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -59,6 +59,7 @@ #include "src/ops/argmax.h" #include "src/ops/argmin.h" #include "src/ops/cast.h" +#include "src/ops/clip.h" #include "src/ops/reshape.h" #include "src/ops/scale.h" #include "src/ops/concat.h" @@ -139,6 +140,7 @@ #include "nnacl/fp32/topk.h" #include "nnacl/reduce_parameter.h" #include "nnacl/fp32/activation.h" +#include "nnacl/clip.h" #include "nnacl/fp32/arithmetic.h" #include "nnacl/fp32/batchnorm.h" #include "nnacl/power.h" @@ -624,6 +626,20 @@ OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *prim return reinterpret_cast(act_param); } +OpParameter *PopulateClipParameter(const mindspore::lite::PrimitiveC *primitive) { + ClipParameter *act_param = reinterpret_cast(malloc(sizeof(ClipParameter))); + if (act_param == nullptr) { + MS_LOG(ERROR) << "malloc ClipParameter failed."; + return nullptr; + } + memset(act_param, 0, sizeof(ClipParameter)); + act_param->op_parameter_.type_ = primitive->Type(); + auto activation = reinterpret_cast(const_cast(primitive)); + act_param->min_val_ = activation->GetMin(); + act_param->max_val_ = activation->GetMax(); + return reinterpret_cast(act_param); +} + OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive) { BatchNormParameter *batch_norm_param = reinterpret_cast(malloc(sizeof(BatchNormParameter))); if (batch_norm_param == nullptr) { @@ -1645,6 +1661,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter; populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter; + populate_parameter_funcs_[schema::PrimitiveType_Clip] = PopulateClipParameter; populate_parameter_funcs_[schema::PrimitiveType_Conv2D] = PopulateConvParameter; populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter; populate_parameter_funcs_[schema::PrimitiveType_Mean] = PopulateMeanParameter; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index f8bb8555bd..863f61e862 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -185,6 +185,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc + ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc ) endif() diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc index 4322071095..bb40b536ea 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -14,12 +14,12 @@ * limitations under the License. */ +#include "tools/anf_importer/import_from_meta_graphT.h" #include #include #include "schema/inner/model_generated.h" #include "frontend/operator/ops.h" #include "src/param_value_lite.h" -#include "tools/anf_importer/import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" #include "tools/common/tensor_util.h" diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index faddfef30a..c61e23b38a 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -62,6 +62,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/quant_dtype_cast_fusion.cc ../optimizer/graph/weight_format_transform_pass.cc ../optimizer/graph/weight_format_hardcode_pass.cc + ../optimizer/graph/clip_convert_activation_pass.cc ../optimizer/graph/unused_cast_node_remove_pass.cc ) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 1438a90468..487b514730 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -27,6 +27,7 @@ #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h" +#include "tools/optimizer/graph/clip_convert_activation_pass.h" #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" @@ -45,6 +46,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver auto optimizer = std::make_shared(); auto pm = std::make_shared("anf fusion pass manager", false); auto graph_pm = std::make_shared("anf graph pass manager", true); + auto convert_pm = std::make_shared("anf graph convert pass manager", true); // for now - trainning is not supporting fuse operations if (config != nullptr && config->trainModel == false) { @@ -79,6 +81,8 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver pm->AddPass(remove_unused_cast_pass); } pm->AddPass(std::make_shared()); + convert_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(convert_pm); optimizer->AddPassManager(pm); optimizer->AddPassManager(graph_pm); auto new_graph = optimizer->Optimize(old_graph); diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 9a6b0d4451..1ff58281a6 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -32,6 +32,7 @@ #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" +#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" #include "tools/converter/quantizer/aware_quantizer.h" @@ -62,6 +63,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { { Optimizer unusedOpRemoveOptimizer; unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); + unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass()); unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); status = unusedOpRemoveOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc index 1315f83d4c..f4ec0bb792 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc @@ -118,6 +118,17 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s std::vector nhwc2nchwPerm = {0, 3, 1, 2}; if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) || (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) { + if (formatTransPath->nodeIdx < transposePath->nodeIdx) { + if (graph->allTensors.at(formatTransNode->inputIndex[0])->format != + graph->allTensors.at(transposeNode->outputIndex[0])->format) { + return RET_OK; + } + } else { + if (graph->allTensors.at(transposeNode->inputIndex[0])->format != + graph->allTensors.at(formatTransNode->outputIndex[0])->format) { + return RET_OK; + } + } auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx); if (status != RET_OK) { MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index 68f4f1d782..b5a2895590 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -6,6 +6,7 @@ file(GLOB GRAPH_PASS ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/dropout_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc new file mode 100644 index 0000000000..6b6ec109b7 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" +#include +#include "src/common/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { + +STATUS IsolateDropoutNode(schema::MetaGraphT *graphT, size_t nodeIdx) { + MS_ASSERT(graphT != nullptr); + if (graphT->nodes.size() <= nodeIdx) { + MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx; + return RET_PARAM_INVALID; + } + + CNodeT *node = graphT->nodes.at(nodeIdx).get(); + auto inputTensorIdxes = node->inputIndex; + auto outputTensorIdxes = node->outputIndex; + auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx); + if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 2) { + MS_LOG(ERROR) << "Only support node who has no more than one input and two output"; + return RET_ERROR; + } + if (inputTensorIdxes.empty()) { + MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor"; + return RET_ERROR; + } + if (outputTensorIdxes.size() == 2) { + auto outDataTensorIdx = outputTensorIdxes.at(1); + auto &gOutTensorIdx = graphT->outputIndex; + for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) { + if (*iter == outDataTensorIdx) { + MS_LOG(ERROR) << "Unsupported Dropout: " << node->name.c_str() << " with mask output."; + return RET_ERROR; + } + } + auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 1); + if (postNodeIdxes.size() != 0) { + MS_LOG(ERROR) << "Unsupported Dropout: " << node->name.c_str() << " with mask output."; + return RET_ERROR; + } + } + auto inDataTensorIdx = inputTensorIdxes.front(); + if (!outputTensorIdxes.empty()) { + auto outDataTensorIdx = outputTensorIdxes.front(); + MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); + MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr); + auto &gOutTensorIdx = graphT->outputIndex; + for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + // find poseNode + auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + } + } + + // now all node's outputTensors are useless + // remove all node's outputTensors + auto status = RemoveTensor(graphT, outputTensorIdxes); + if (status != RET_OK) { + MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed"; + return RET_ERROR; + } + + node->inputIndex.clear(); + node->outputIndex.clear(); + return RET_OK; +} + +STATUS DropoutNodeRemovePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + bool ifChanged = false; + for (size_t i = 0; i < graph->nodes.size(); i++) { + auto &node = graph->nodes.at(i); + if (node->primitive->value.type == schema::PrimitiveType_Dropout) { + ifChanged = true; + auto status = IsolateDropoutNode(graph, i); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateDropoutNode failed, subGraph: " << graph->name << ", node: " << node->name + << ", error: " << status; + return status; + } + } + } + return ifChanged ? RET_OK : RET_NO_CHANGE; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h new file mode 100644 index 0000000000..1006951070 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_DROPOUT_NODE_REMOVE_PASS_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_DROPOUT_NODE_REMOVE_PASS_H_ + +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class DropoutNodeRemovePass : public GraphPass { + public: + DropoutNodeRemovePass() = default; + + ~DropoutNodeRemovePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_DROPOUT_NODE_REMOVE_PASS_H_ diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h index 3539568b3a..647ea39e49 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H -#define MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_UNUSED_NODE_REMOVE_PASS_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_UNUSED_NODE_REMOVE_PASS_H #include #include "tools/converter/optimizer.h" @@ -33,4 +33,4 @@ class UnusedNodeRemovePass : public GraphPass { } // namespace lite } // namespace mindspore -#endif // MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_UNUSED_NODE_REMOVE_PASS_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc index 398b5af837..ef1cea6e55 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -40,26 +40,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod min = onnx_node_attr.f(); } } - if (min == 0 && max == 6) { - std::unique_ptr attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - attr->type = schema::ActivationType_RELU6; - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - } else { - std::unique_ptr attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - attr->max = max; - attr->min = min; - op->primitive->value.type = schema::PrimitiveType_Clip; - op->primitive->value.value = attr.release(); + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; } + attr->max = max; + attr->min = min; + op->primitive->value.type = schema::PrimitiveType_Clip; + op->primitive->value.value = attr.release(); + return RET_OK; } diff --git a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc new file mode 100644 index 0000000000..786710ab30 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/graph/clip_convert_activation_pass.h" +#include +#include +#include "tools/optimizer/common/gllo_utils.h" +#include "src/ops/primitive_c.h" +#include "schema/inner/model_generated.h" +#include "src/tensor.h" +#include "tools/converter/quantizer/quant_cast.h" +#include "src/common/log_adapter.h" +#include "securec/include/securec.h" + +using mindspore::lite::PrimitiveC; +namespace mindspore::opt { +namespace { +constexpr size_t kClipMinIndex = 2; +constexpr size_t kClipMaxIndex = 3; +} // namespace + +bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) { + MS_ASSERT(graph != nullptr); + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + if (opt::GetCNodeType(node) != schema::PrimitiveType_Clip) { + continue; + } + auto clip_cnode = node->cast(); + MS_ASSERT(clip_cnode->inputs().size() > kClipMinIndex); + MS_ASSERT(clip_cnode->inputs().size() > kClipMaxIndex); + + auto primitive_c = GetValueNode>(clip_cnode->input(0)); + auto primT = primitive_c->GetPrimitiveT(); + float max = primT->value.AsClip()->max; + float min = primT->value.AsClip()->min; + if ((min == -1) && (max == -1)) { + if (clip_cnode->size() != 4) { + MS_LOG(ERROR) << "Clip param invalid"; + return false; + } + auto min_param_value = GetLiteParamValue(clip_cnode->input(kClipMinIndex)); + auto max_param_value = GetLiteParamValue(clip_cnode->input(kClipMaxIndex)); + if ((min_param_value->tensor_type() != mindspore::kNumberTypeFloat32) || + (max_param_value->tensor_type() != mindspore::kNumberTypeFloat32)) { + MS_LOG(ERROR) << "Clip param type invalid"; + return false; + } + min = *reinterpret_cast(min_param_value->tensor_addr()); + max = *reinterpret_cast(max_param_value->tensor_addr()); + } + auto manager = graph->manager(); + + // relu node + auto primitive = std::make_unique(); + MS_ASSERT(primitive != nullptr); + primitive->value.type = schema::PrimitiveType_Activation; + auto prim2 = new schema::ActivationT; + MS_ASSERT(prim2 != nullptr); + if (min == 0 && max == 6) { + prim2->type = schema::ActivationType_RELU6; + } else { + prim2->type = schema::ActivationType_HARD_TANH; + prim2->min_val = min; + prim2->max_val = max; + } + primitive->value.value = prim2; + auto primitiveCValue = PrimitiveC::Create(primitive.release()); + MS_ASSERT(primitiveCValue != nullptr); + auto value_node = NewValueNode(std::shared_ptr(primitiveCValue)); + std::vector op_inputs = {value_node}; + op_inputs.push_back(clip_cnode->input(1)); + auto new_cnode = graph->NewCNode(op_inputs); + new_cnode->set_fullname_with_scope(node->fullname_with_scope()); + manager->Replace(node, new_cnode); + } + return false; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h new file mode 100644 index 0000000000..3886743ac5 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_ +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" +#include "src/param_value_lite.h" + +using mindspore::lite::converter::FmkType; +using mindspore::schema::QuantType; +namespace mindspore::opt { +class ClipConvertActivationPass : public Pass { + public: + ClipConvertActivationPass() : Pass("clip_convert_activation_pass") {} + ~ClipConvertActivationPass() override = default; + // void SetQuantType(QuantType type); + // void SetFmkType(FmkType fmkType); + bool Run(const FuncGraphPtr &graph) override; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h index f9d698dd7e..16ab442b1f 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ -#define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ #include #include "schema/inner/model_generated.h" #include "tools/converter/converter_flags.h" @@ -44,4 +44,4 @@ class WeightFormatHardCodePass : public Pass { FmkType fmk_type = lite::converter::FmkType_TF; }; } // namespace mindspore::opt -#endif // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_