
Add dropout pass, add clip pass, and fix the transpose fusion problem

tags/v1.1.0
gongdaguo, 5 years ago
parent commit 93a521d662
18 changed files with 410 additions and 26 deletions
  1. +29 -0  mindspore/lite/nnacl/clip.c
  2. +36 -0  mindspore/lite/nnacl/clip.h
  3. +5 -0  mindspore/lite/src/ops/primitive_c.cc
  4. +17 -0  mindspore/lite/src/populate_parameter.cc
  5. +1 -0  mindspore/lite/test/CMakeLists.txt
  6. +1 -1  mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc
  7. +1 -0  mindspore/lite/tools/converter/CMakeLists.txt
  8. +4 -0  mindspore/lite/tools/converter/anf_transform.cc
  9. +2 -0  mindspore/lite/tools/converter/graphdef_transform.cc
  10. +11 -0  mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc
  11. +1 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
  12. +119 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc
  13. +36 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h
  14. +3 -3  mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h
  15. +10 -19  mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc
  16. +94 -0  mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc
  17. +37 -0  mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h
  18. +3 -3  mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h

+ 29 - 0  mindspore/lite/nnacl/clip.c

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "nnacl/clip.h"
#include "nnacl/errorcode.h"

int ClipFp32(const float *src, int length, float *dst, float min_val, float max_val) {
  if (max_val <= min_val) {
    return NNACL_ERR;
  }
  int i = 0;
  for (; i < length; ++i) {
    dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? max_val : src[i]);
  }
  return NNACL_OK;
}
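
A minimal usage sketch of the kernel above (the driver is illustrative and assumes the nnacl headers are on the include path):

#include <cstdio>
#include "nnacl/clip.h"

int main() {
  const float src[5] = {-2.0f, 0.0f, 3.0f, 6.5f, 9.0f};
  float dst[5] = {0.0f};
  // Clamps every element into [0, 6]; ClipFp32 returns NNACL_ERR when max <= min.
  if (ClipFp32(src, 5, dst, 0.0f, 6.0f) != NNACL_OK) {
    return 1;
  }
  for (int i = 0; i < 5; ++i) {
    printf("%.1f ", dst[i]);  // prints: 0.0 0.0 3.0 6.0 6.0
  }
  return 0;
}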

+ 36 - 0  mindspore/lite/nnacl/clip.h

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_CLIP_H_
#define MINDSPORE_LITE_NNACL_CLIP_H_

#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/fixed_point.h"

typedef struct ClipParameter {
  OpParameter op_parameter_;
  float min_val_;
  float max_val_;
} ClipParameter;

#ifdef __cplusplus
extern "C" {
#endif
int ClipFp32(const float *src, int length, float *dst, float min_val, float max_val);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_CLIP_H_

+ 5 - 0  mindspore/lite/src/ops/primitive_c.cc

@@ -128,6 +128,7 @@
#include "src/ops/lsh_projection.h"
#include "src/ops/hashtable_lookup.h"
#include "src/ops/skip_gram.h"
#include "src/ops/clip.h"
#include "src/ops/custom_predict.h"
#include "src/ops/custom_normalize.h"
#include "src/ops/custom_extract_features.h"
@@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
      return new HashtableLookup(primitive);
    case schema::PrimitiveType_SkipGram:
      return new SkipGram(primitive);
    case schema::PrimitiveType_Clip:
      return new Clip(primitive);
    case schema::PrimitiveType_CustomPredict:
      return new CustomPredict(primitive);
    case schema::PrimitiveType_CustomNormalize:
@@ -964,6 +967,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
      return NewPrimitiveC<HashtableLookup>(primitive);
    case schema::PrimitiveType_SkipGram:
      return NewPrimitiveC<SkipGram>(primitive);
    case schema::PrimitiveType_Clip:
      return NewPrimitiveC<Clip>(primitive);
    case schema::PrimitiveType_CustomPredict:
      return NewPrimitiveC<CustomPredict>(primitive);
    case schema::PrimitiveType_CustomNormalize:


+ 17 - 0  mindspore/lite/src/populate_parameter.cc

@@ -59,6 +59,7 @@
#include "src/ops/argmax.h"
#include "src/ops/argmin.h"
#include "src/ops/cast.h"
#include "src/ops/clip.h"
#include "src/ops/reshape.h"
#include "src/ops/scale.h"
#include "src/ops/concat.h"
@@ -139,6 +140,7 @@
#include "nnacl/fp32/topk.h"
#include "nnacl/reduce_parameter.h"
#include "nnacl/fp32/activation.h"
#include "nnacl/clip.h"
#include "nnacl/fp32/arithmetic.h"
#include "nnacl/fp32/batchnorm.h"
#include "nnacl/power.h"
@@ -624,6 +626,20 @@ OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *prim
  return reinterpret_cast<OpParameter *>(act_param);
}

OpParameter *PopulateClipParameter(const mindspore::lite::PrimitiveC *primitive) {
  ClipParameter *act_param = reinterpret_cast<ClipParameter *>(malloc(sizeof(ClipParameter)));
  if (act_param == nullptr) {
    MS_LOG(ERROR) << "malloc ClipParameter failed.";
    return nullptr;
  }
  memset(act_param, 0, sizeof(ClipParameter));
  act_param->op_parameter_.type_ = primitive->Type();
  auto activation = reinterpret_cast<mindspore::lite::Clip *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  act_param->min_val_ = activation->GetMin();
  act_param->max_val_ = activation->GetMax();
  return reinterpret_cast<OpParameter *>(act_param);
}

OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive) {
  BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
  if (batch_norm_param == nullptr) {
@@ -1645,6 +1661,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
  populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
  populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
  populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter;
  populate_parameter_funcs_[schema::PrimitiveType_Clip] = PopulateClipParameter;
  populate_parameter_funcs_[schema::PrimitiveType_Conv2D] = PopulateConvParameter;
  populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter;
  populate_parameter_funcs_[schema::PrimitiveType_Mean] = PopulateMeanParameter;

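The registry above maps each schema::PrimitiveType to the function that builds its OpParameter. A standalone sketch of that dispatch pattern, with simplified structs and a hypothetical type id in place of the real schema definitions:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>

// Simplified stand-ins; the real definitions live in nnacl/op_base.h and nnacl/clip.h.
struct OpParameter { int type_; };
struct ClipParameter { OpParameter op_parameter_; float min_val_; float max_val_; };

using PopulateFunc = OpParameter *(*)(int type);

OpParameter *PopulateClip(int type) {
  auto *param = static_cast<ClipParameter *>(malloc(sizeof(ClipParameter)));
  if (param == nullptr) return nullptr;
  memset(param, 0, sizeof(ClipParameter));
  param->op_parameter_.type_ = type;
  param->min_val_ = 0.0f;  // illustrative bounds; the real code reads GetMin()/GetMax()
  param->max_val_ = 6.0f;
  return reinterpret_cast<OpParameter *>(param);
}

int main() {
  constexpr int kTypeClip = 42;  // hypothetical id, not the real schema::PrimitiveType_Clip value
  std::map<int, PopulateFunc> registry = {{kTypeClip, PopulateClip}};
  OpParameter *param = registry.at(kTypeClip)(kTypeClip);
  printf("type=%d max=%.1f\n", param->type_, reinterpret_cast<ClipParameter *>(param)->max_val_);
  free(param);
  return 0;
}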

+ 1 - 0  mindspore/lite/test/CMakeLists.txt

@@ -185,6 +185,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
)
endif()


+ 1 - 1  mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc

@@ -14,12 +14,12 @@
* limitations under the License.
*/

#include "tools/anf_importer/import_from_meta_graphT.h"
#include <vector>
#include <algorithm>
#include "schema/inner/model_generated.h"
#include "frontend/operator/ops.h"
#include "src/param_value_lite.h"
#include "tools/anf_importer/import_from_meta_graphT.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "tools/common/tensor_util.h"


+ 1 - 0  mindspore/lite/tools/converter/CMakeLists.txt

@@ -62,6 +62,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
)



+ 4 - 0  mindspore/lite/tools/converter/anf_transform.cc

@@ -27,6 +27,7 @@
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
@@ -45,6 +46,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
  auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);

  // for now - training does not support fused operations
  if (config != nullptr && config->trainModel == false) {
@@ -79,6 +81,8 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
    pm->AddPass(remove_unused_cast_pass);
  }
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
  optimizer->AddPassManager(convert_pm);
  optimizer->AddPassManager(pm);
  optimizer->AddPassManager(graph_pm);
  auto new_graph = optimizer->Optimize(old_graph);


+ 2 - 0  mindspore/lite/tools/converter/graphdef_transform.cc

@@ -32,6 +32,7 @@
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"

@@ -62,6 +63,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
  {
    Optimizer unusedOpRemoveOptimizer;
    unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
    unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
    unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
    status = unusedOpRemoveOptimizer.Run(graphDefT);
    if (status != RET_OK && status != RET_NO_CHANGE) {

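The hunk above registers DropoutNodeRemovePass in the unused-op pipeline. A standalone sketch of the AddPass/Run pattern it relies on, using stand-in interfaces rather than the converter's real Optimizer (which also distinguishes RET_NO_CHANGE):

#include <cstdio>
#include <memory>
#include <vector>

struct MetaGraph { int dropout_nodes = 1; };  // toy graph state

struct Pass {
  virtual ~Pass() = default;
  virtual int Run(MetaGraph *graph) = 0;  // 0 stands in for RET_OK
};

struct DropoutNodeRemove : Pass {
  int Run(MetaGraph *graph) override {
    graph->dropout_nodes = 0;  // bypass and drop every Dropout node
    return 0;
  }
};

struct Optimizer {
  std::vector<std::unique_ptr<Pass>> passes;
  void AddPass(Pass *pass) { passes.emplace_back(pass); }
  int Run(MetaGraph *graph) {
    for (auto &pass : passes) {
      int status = pass->Run(graph);
      if (status != 0) return status;  // passes run in registration order
    }
    return 0;
  }
};

int main() {
  MetaGraph graph;
  Optimizer optimizer;
  optimizer.AddPass(new DropoutNodeRemove());
  printf("status=%d, dropouts left=%d\n", optimizer.Run(&graph), graph.dropout_nodes);
  return 0;
}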

+ 11 - 0  mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc

@@ -118,6 +118,17 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s
  std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2};
  if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) ||
      (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) {
    if (formatTransPath->nodeIdx < transposePath->nodeIdx) {
      if (graph->allTensors.at(formatTransNode->inputIndex[0])->format !=
          graph->allTensors.at(transposeNode->outputIndex[0])->format) {
        return RET_OK;
      }
    } else {
      if (graph->allTensors.at(transposeNode->inputIndex[0])->format !=
          graph->allTensors.at(formatTransNode->outputIndex[0])->format) {
        return RET_OK;
      }
    }
    auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status;

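The guard added here matters because the node pair only cancels when the two permutations compose to the identity and the surrounding tensor formats already agree. A standalone sketch of the cancellation, using the convention output[i] = input[perm[i]]:

#include <array>
#include <cstdio>

// Applying `second` after `first` composes as combined[i] = first[second[i]].
std::array<int, 4> Compose(const std::array<int, 4> &first, const std::array<int, 4> &second) {
  std::array<int, 4> combined{};
  for (int i = 0; i < 4; ++i) {
    combined[i] = first[second[i]];
  }
  return combined;
}

int main() {
  const std::array<int, 4> nchw2nhwc = {0, 2, 3, 1};  // effect of a Nchw2Nhwc node on dims
  const std::array<int, 4> nhwc2nchw = {0, 3, 1, 2};  // the transpose perm the pass matches
  const auto combined = Compose(nchw2nhwc, nhwc2nchw);
  for (int i = 0; i < 4; ++i) {
    printf("%d ", combined[i]);  // 0 1 2 3: identity, so the pair can be fused away
  }
  printf("\n");
  return 0;
}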

+ 1 - 0  mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt

@@ -6,6 +6,7 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/dropout_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc


+ 119 - 0  mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc

@@ -0,0 +1,119 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include <queue>
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {

STATUS IsolateDropoutNode(schema::MetaGraphT *graphT, size_t nodeIdx) {
  MS_ASSERT(graphT != nullptr);
  if (graphT->nodes.size() <= nodeIdx) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }

  CNodeT *node = graphT->nodes.at(nodeIdx).get();
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
  if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 2) {
    MS_LOG(ERROR) << "Only nodes with no more than one input and two outputs are supported";
    return RET_ERROR;
  }
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
    return RET_ERROR;
  }
  if (outputTensorIdxes.size() == 2) {
    auto outDataTensorIdx = outputTensorIdxes.at(1);
    auto &gOutTensorIdx = graphT->outputIndex;
    for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
      if (*iter == outDataTensorIdx) {
        MS_LOG(ERROR) << "Unsupported Dropout: " << node->name.c_str() << " with mask output.";
        return RET_ERROR;
      }
    }
    auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 1);
    if (postNodeIdxes.size() != 0) {
      MS_LOG(ERROR) << "Unsupported Dropout: " << node->name.c_str() << " with mask output.";
      return RET_ERROR;
    }
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  if (!outputTensorIdxes.empty()) {
    auto outDataTensorIdx = outputTensorIdxes.front();
    MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
    MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
    auto &gOutTensorIdx = graphT->outputIndex;
    for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
      if (*iter == outDataTensorIdx) {
        *iter = inDataTensorIdx;
        break;
      }
    }
    // find the postNodes and redirect their inputs to the dropout's input tensor
    auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
    for (auto postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_ASSERT(postNode != nullptr);
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == outDataTensorIdx) {
          *iter = inDataTensorIdx;
          break;
        }
      }
    }
  }

  // all of the node's output tensors are now unused; remove them
  auto status = RemoveTensor(graphT, outputTensorIdxes);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << " failed";
    return RET_ERROR;
  }

  node->inputIndex.clear();
  node->outputIndex.clear();
  return RET_OK;
}

STATUS DropoutNodeRemovePass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  bool ifChanged = false;
  for (size_t i = 0; i < graph->nodes.size(); i++) {
    auto &node = graph->nodes.at(i);
    if (node->primitive->value.type == schema::PrimitiveType_Dropout) {
      ifChanged = true;
      auto status = IsolateDropoutNode(graph, i);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "IsolateDropoutNode failed, subGraph: " << graph->name << ", node: " << node->name
                      << ", error: " << status;
        return status;
      }
    }
  }
  return ifChanged ? RET_OK : RET_NO_CHANGE;
}
} // namespace lite
} // namespace mindspore
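
The core of IsolateDropoutNode is rewiring: every reader of the dropout's data output (graph outputs and downstream nodes) is redirected to the dropout's input tensor. A standalone sketch over a simplified tensor-index graph, not the real schema::MetaGraphT:

#include <cstdio>
#include <vector>

struct Node {
  std::vector<size_t> inputs;   // tensor indices read
  std::vector<size_t> outputs;  // tensor indices written
};

// Redirect all references to out_tensor so they read in_tensor instead.
void Bypass(std::vector<Node> *nodes, std::vector<size_t> *graph_outputs,
            size_t in_tensor, size_t out_tensor) {
  for (auto &idx : *graph_outputs) {
    if (idx == out_tensor) idx = in_tensor;
  }
  for (auto &node : *nodes) {
    for (auto &idx : node.inputs) {
      if (idx == out_tensor) idx = in_tensor;
    }
  }
}

int main() {
  // tensor 0 -> Dropout -> tensor 1 -> Consumer -> tensor 2
  std::vector<Node> nodes = {{{0}, {1}}, {{1}, {2}}};
  std::vector<size_t> graph_outputs = {2};
  Bypass(&nodes, &graph_outputs, /*in_tensor=*/0, /*out_tensor=*/1);
  printf("consumer now reads tensor %zu\n", nodes[1].inputs[0]);  // 0: dropout bypassed
  return 0;
}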

+ 36 - 0  mindspore/lite/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_DROPOUT_NODE_REMOVE_PASS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_DROPOUT_NODE_REMOVE_PASS_H_

#include <unordered_map>
#include "tools/converter/optimizer.h"

namespace mindspore {
namespace lite {
class DropoutNodeRemovePass : public GraphPass {
 public:
  DropoutNodeRemovePass() = default;

  ~DropoutNodeRemovePass() override = default;

  STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_DROPOUT_NODE_REMOVE_PASS_H_

+ 3 - 3  mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h

@@ -14,8 +14,8 @@
* limitations under the License.
*/

-#ifndef MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H
-#define MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H
+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H

#include <unordered_map>
#include "tools/converter/optimizer.h"
@@ -33,4 +33,4 @@ class UnusedNodeRemovePass : public GraphPass {
} // namespace lite
} // namespace mindspore

-#endif  // MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H

+ 10 - 19  mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc

@@ -40,26 +40,17 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
      min = onnx_node_attr.f();
    }
  }
-  if (min == 0 && max == 6) {
-    std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    attr->type = schema::ActivationType_RELU6;
-    op->primitive->value.type = schema::PrimitiveType_Activation;
-    op->primitive->value.value = attr.release();
-  } else {
-    std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    attr->max = max;
-    attr->min = min;
-    op->primitive->value.type = schema::PrimitiveType_Clip;
-    op->primitive->value.value = attr.release();
+  std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  attr->max = max;
+  attr->min = min;
+  op->primitive->value.type = schema::PrimitiveType_Clip;
+  op->primitive->value.value = attr.release();

return RET_OK;
}



+ 94 - 0  mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc

@@ -0,0 +1,94 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include <vector>
#include <memory>
#include "tools/optimizer/common/gllo_utils.h"
#include "src/ops/primitive_c.h"
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "securec/include/securec.h"

using mindspore::lite::PrimitiveC;
namespace mindspore::opt {
namespace {
constexpr size_t kClipMinIndex = 2;
constexpr size_t kClipMaxIndex = 3;
} // namespace

bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  auto node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    if (opt::GetCNodeType(node) != schema::PrimitiveType_Clip) {
      continue;
    }
    auto clip_cnode = node->cast<CNodePtr>();
    MS_ASSERT(clip_cnode->inputs().size() > kClipMinIndex);
    MS_ASSERT(clip_cnode->inputs().size() > kClipMaxIndex);

    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(clip_cnode->input(0));
    auto primT = primitive_c->GetPrimitiveT();
    float max = primT->value.AsClip()->max;
    float min = primT->value.AsClip()->min;
    if ((min == -1) && (max == -1)) {
      if (clip_cnode->size() != 4) {
        MS_LOG(ERROR) << "Clip param invalid";
        return false;
      }
      auto min_param_value = GetLiteParamValue(clip_cnode->input(kClipMinIndex));
      auto max_param_value = GetLiteParamValue(clip_cnode->input(kClipMaxIndex));
      if ((min_param_value->tensor_type() != mindspore::kNumberTypeFloat32) ||
          (max_param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
        MS_LOG(ERROR) << "Clip param type invalid";
        return false;
      }
      min = *reinterpret_cast<float *>(min_param_value->tensor_addr());
      max = *reinterpret_cast<float *>(max_param_value->tensor_addr());
    }
    auto manager = graph->manager();

    // replace the Clip node with an equivalent Activation node (ReLU6 or hard-tanh)
    auto primitive = std::make_unique<schema::PrimitiveT>();
    MS_ASSERT(primitive != nullptr);
    primitive->value.type = schema::PrimitiveType_Activation;
    auto prim2 = new schema::ActivationT;
    MS_ASSERT(prim2 != nullptr);
    if (min == 0 && max == 6) {
      prim2->type = schema::ActivationType_RELU6;
    } else {
      prim2->type = schema::ActivationType_HARD_TANH;
      prim2->min_val = min;
      prim2->max_val = max;
    }
    primitive->value.value = prim2;
    auto primitiveCValue = PrimitiveC::Create(primitive.release());
    MS_ASSERT(primitiveCValue != nullptr);
    auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveCValue));
    std::vector<AnfNodePtr> op_inputs = {value_node};
    op_inputs.push_back(clip_cnode->input(1));
    auto new_cnode = graph->NewCNode(op_inputs);
    new_cnode->set_fullname_with_scope(node->fullname_with_scope());
    manager->Replace(node, new_cnode);
  }
  return false;
}
} // namespace mindspore::opt
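
The replacement is sound because Clip and a hard-tanh-style activation parameterized by min_val/max_val compute the same element-wise function, with ReLU6 as the [0, 6] special case. A standalone sketch of that equivalence, using plain functions rather than the runtime kernels:

#include <algorithm>
#include <cstdio>

float Clip(float x, float min_val, float max_val) {
  return std::min(std::max(x, min_val), max_val);
}

// What a hard-tanh activation with explicit bounds computes.
float HardTanh(float x, float min_val, float max_val) {
  return x < min_val ? min_val : (x > max_val ? max_val : x);
}

// The min == 0 && max == 6 special case the pass maps to ActivationType_RELU6.
float Relu6(float x) { return Clip(x, 0.0f, 6.0f); }

int main() {
  for (float x : {-2.0f, 0.5f, 7.0f}) {
    printf("x=%.1f clip=%.1f hard_tanh=%.1f relu6=%.1f\n",
           x, Clip(x, 0.0f, 6.0f), HardTanh(x, 0.0f, 6.0f), Relu6(x));
  }
  return 0;
}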

+ 37 - 0  mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "src/param_value_lite.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class ClipConvertActivationPass : public Pass {
 public:
  ClipConvertActivationPass() : Pass("clip_convert_activation_pass") {}
  ~ClipConvertActivationPass() override = default;
  // void SetQuantType(QuantType type);
  // void SetFmkType(FmkType fmkType);
  bool Run(const FuncGraphPtr &graph) override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_

+ 3 - 3  mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h

@@ -14,8 +14,8 @@
* limitations under the License.
*/

-#ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
-#define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
@@ -44,4 +44,4 @@ class WeightFormatHardCodePass : public Pass {
  FmkType fmk_type = lite::converter::FmkType_TF;
};
} // namespace mindspore::opt
-#endif  // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
