Browse Source

!6917 [MSLITE] Remove cast operator which convert data from fp32 to fp16 in MindSpore models.

Merge pull request !6917 from wangshaocong/bugfix_master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8ac8ad5158
8 changed files with 124 additions and 18 deletions
  1. +4
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
  2. +1
    -0
      mindspore/lite/test/CMakeLists.txt
  3. +1
    -0
      mindspore/lite/test/models_mindspore.cfg
  4. +1
    -0
      mindspore/lite/tools/converter/CMakeLists.txt
  5. +6
    -0
      mindspore/lite/tools/converter/anf_transform.cc
  6. +1
    -14
      mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc
  7. +74
    -0
      mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc
  8. +36
    -0
      mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.h

+ 4
- 4
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc View File

@@ -36,15 +36,15 @@ int ArithmeticCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int ArithmeticCPUKernel::ReSize() {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
return ReSize();
}

int ArithmeticCPUKernel::ReSize() {
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();


+ 1
- 0
mindspore/lite/test/CMakeLists.txt View File

@@ -183,6 +183,7 @@ if(BUILD_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
)
endif()
### train


+ 1
- 0
mindspore/lite/test/models_mindspore.cfg View File

@@ -4,3 +4,4 @@ gate_u_net_small-1_110.mindir
shufflenetv2.mindir
inceptionv3.mindir
googlenet.mindir
resnext50.mindir

+ 1
- 0
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
)

add_subdirectory(../anf_importer anf_importer)


+ 6
- 0
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -27,6 +27,7 @@
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -72,6 +73,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
graph_pm->AddPass(weight_format_transform_pass);
}

if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
remove_unused_cast_pass->SetFmkType(config->fmk);
pm->AddPass(remove_unused_cast_pass);
}
pm->AddPass(std::make_shared<opt::ConstFoldPass>());
optimizer->AddPassManager(pm);
optimizer->AddPassManager(graph_pm);


+ 1
- 14
mindspore/lite/tools/converter/legacy_optimizer/graph/unused_node_remove_pass.cc View File

@@ -29,25 +29,12 @@

namespace mindspore {
namespace lite {
bool IsUnusedNode(const CNodeT &node) {
if (node.primitive->value.type == schema::PrimitiveType_TupleGetItem) {
return true;
}
if (node.primitive->value.type == schema::PrimitiveType_Cast) {
auto attr = reinterpret_cast<schema::CastT *>(node.primitive->value.value);
if (attr->srcT == kNumberTypeFloat32 && attr->dstT == kNumberTypeFloat16) {
return true;
}
}
return false;
}

STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
bool ifChanged = false;
for (size_t i = 0; i < graph->nodes.size(); i++) {
auto &node = graph->nodes.at(i);
if (IsUnusedNode(*node)) {
if (node->primitive->value.type == schema::PrimitiveType_TupleGetItem) {
ifChanged = true;
auto status = IsolateOneWayNode(graph, i);
if (status != RET_OK) {


+ 74
- 0
mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc View File

@@ -0,0 +1,74 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"

namespace mindspore::opt {
void RemoveUnusedCastOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }

bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != lite::converter::FmkType_MS) {
MS_LOG(ERROR) << "The framework type of model should be mindspore.";
return RET_ERROR;
}
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Cast) {
continue;
}
auto cast_cnode = node->cast<CNodePtr>();
auto abstract_base = cast_cnode->input(1)->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << cast_cnode->input(1)->fullname_with_scope();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, "
<< cast_cnode->input(1)->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto input_type = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(input_type != nullptr);
auto input_type_value = input_type->type_id();

if (cast_cnode->inputs().size() != lite::kMultiNum || !utils::isa<ValueNodePtr>(cast_cnode->input(2))) {
MS_LOG(ERROR) << "Second input of cast should be a ValueNode";
return RET_ERROR;
}
auto output_type = GetValueNode<NumberPtr>(cast_cnode->input(2));
if (output_type == nullptr) {
MS_LOG(ERROR) << "Second input of cast is nullptr";
return RET_ERROR;
}
auto output_type_value = output_type->type_id();
if ((input_type_value == kNumberTypeFloat32 && output_type_value == kNumberTypeFloat16) ||
(input_type_value == kNumberTypeFloat16 && output_type_value == kNumberTypeFloat32)) {
manager->Replace(node, cast_cnode->input(1));
}
}
return true;
}
} // namespace mindspore::opt

+ 36
- 0
mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_
#include <string>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"

using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class RemoveUnusedCastOpPass : public Pass {
public:
RemoveUnusedCastOpPass() : Pass("remove_unused_cast_pass") {}
~RemoveUnusedCastOpPass() override = default;
void SetFmkType(FmkType fmkType);
bool Run(const FuncGraphPtr &graph) override;

private:
FmkType fmk_type = lite::converter::FmkType_TF;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_CAST_PASS_H_

Loading…
Cancel
Save