
decouple anf and fb, add global format opt

pull/15419/head
xuanyue committed 4 years ago
commit a94efee756
28 changed files with 1064 additions and 1958 deletions
1. mindspore/lite/tools/anf_exporter/anf_exporter.cc (+38, -415)
2. mindspore/lite/tools/anf_exporter/anf_exporter.h (+6, -25)
3. mindspore/lite/tools/anf_exporter/fetch_content.cc (+437, -0)
4. mindspore/lite/tools/anf_exporter/fetch_content.h (+49, -0)
5. mindspore/lite/tools/converter/anf_transform.cc (+1, -3)
6. mindspore/lite/tools/converter/graphdef_transform.cc (+0, -6)
7. mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt (+0, -4)
8. mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc (+0, -125)
9. mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h (+0, -49)
10. mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt (+0, -5)
11. mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc (+0, -461)
12. mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h (+0, -76)
13. mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc (+0, -223)
14. mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.h (+0, -49)
15. mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc (+0, -193)
16. mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h (+0, -52)
17. mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc (+0, -52)
18. mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h (+0, -40)
19. mindspore/lite/tools/optimizer/common/format_utils.cc (+25, -22)
20. mindspore/lite/tools/optimizer/common/format_utils.h (+1, -1)
21. mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc (+70, -53)
22. mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h (+8, -3)
23. mindspore/lite/tools/optimizer/graph/node_infershape.cc (+96, -82)
24. mindspore/lite/tools/optimizer/graph/node_infershape.h (+3, -3)
25. mindspore/lite/tools/optimizer/graph/transpose_strategy.cc (+36, -3)
26. mindspore/lite/tools/optimizer/graph/transpose_strategy.h (+1, -0)
27. mindspore/lite/tools/optimizer/graph/unify_format_pass.cc (+290, -13)
28. mindspore/lite/tools/optimizer/graph/unify_format_pass.h (+3, -0)

mindspore/lite/tools/anf_exporter/anf_exporter.cc (+38, -415)

@@ -34,7 +34,6 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "src/tensor.h"
#include "src/common/utils.h"
#include "ops/op_utils.h"
#include "tools/common/graph_util.h"
@@ -80,159 +79,8 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
}
return cnodes;
}
STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
auto data_type = tensor_info->data_type();
if (data_type != kObjectTypeString) {
MS_LOG(ERROR) << "This function only used for string tensor.";
return RET_ERROR;
}
shape_vector->clear();
auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
std::string shape_str;
std::string shape_size_str;
*offset = 0;
size_t cnt = 0;
for (; *offset < tensor_info->Size(); (*offset)++) {
if (tensor_data[*offset] == ',') {
(*offset)++;
break;
}
shape_size_str.push_back(tensor_data[*offset]);
}
if (*offset == 0) {
MS_LOG(ERROR) << "string tensor's dim size not found.";
return RET_ERROR;
}
size_t shape_size = std::stoi(shape_size_str);
for (; *offset < tensor_info->Size(); (*offset)++) {
if (tensor_data[*offset] == ',') {
cnt++;
shape_vector->push_back(std::stoi(shape_str));
shape_str.clear();
} else {
shape_str.push_back(tensor_data[*offset]);
}
if (cnt == shape_size) {
(*offset)++;
break;
}
}
if (shape_vector->empty()) {
MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
return RET_ERROR;
}
return RET_OK;
}
schema::Format GetFormatByFmk(int32_t fmk_type) {
switch (fmk_type) {
case converter::FmkType_ONNX:
case lite::converter::FmkType_CAFFE:
case lite::converter::FmkType_MS:
return schema::Format_NCHW;
case lite::converter::FmkType_TF:
case lite::converter::FmkType_TFLITE:
return schema::Format_NHWC;
default:
MS_LOG(ERROR) << "don't support current fmk: " + fmk_type;
return static_cast<schema::Format>(fmk_type);
}
}

STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
return RET_PARAM_INVALID;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return RET_INPUT_TENSOR_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr);
*data_type = typePtr->type_id();
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return RET_PARAM_INVALID;
}
*shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
return RET_OK;
}
} // namespace

void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
bool has_make_tuple = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();

inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr input_node = cnode->input(i);
if (!input_node->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto make_tuple_node = utils::cast<CNodePtr>(input_node);
auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return;
}
if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
has_make_tuple = true;
for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
inputs.emplace_back(make_tuple_node->input(j));
}
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (has_make_tuple) {
cnode->set_inputs(inputs);
}
}

void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
bool has_depend = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();

inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr inputNode = cnode->input(i);
if (!inputNode->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto depend_node = utils::cast<CNodePtr>(inputNode);
auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return;
}
if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
has_depend = true;
bool mask_out = (depend_node->inputs().size() == 3);
for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
AnfNodePtr depend_input_node = depend_node->input(j);
if (depend_input_node->isa<CNode>()) {
inputs.emplace_back(depend_input_node);
if (mask_out) {
break;
}
}
}
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (has_depend) {
cnode->set_inputs(inputs);
}
}

int AnfExporter::SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node) {
@@ -653,283 +501,58 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
return RET_OK;
}

int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode,
const std::shared_ptr<PrimitiveC> &primitive_c,
int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *output_cnode) {
auto param_node = input_anode->cast<ParameterPtr>();
schema::CNodeT *op_node) {
auto param_node = cnode->input(index)->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr);
std::string input_name = param_node->fullname_with_scope();
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[param_node->name()]);
op_node->inputIndex.emplace_back(node_id_map_[param_node->name()]);
return RET_OK;
}
auto schema_tensor = std::make_unique<schema::TensorT>();
schema_tensor->format = GetFormatByFmk(meta_graphT->fmkType);
if (schema_tensor->format != schema::Format_NHWC && schema_tensor->format != schema::Format_NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format;
DataInfo data_info;
if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info) !=
RET_OK) {
MS_LOG(ERROR) << "parse const node failed.";
return RET_ERROR;
}

// attr weightFormat is only used by conv-like ops' second input
if (output_cnode->inputIndex.size() == 1 && primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
}
auto schema_tensor = std::make_unique<schema::TensorT>();
schema_tensor->format = static_cast<schema::Format>(data_info.format_);
schema_tensor->name = param_node->name();
ShapeVector shape_vector;
TypeId data_type;
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
if (status != RET_OK) {
MS_LOG(ERROR) << "get data type and shape from param node failed.";
return RET_ERROR;
}
schema_tensor->dataType = data_type;
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
size_t offset = 0;
if (!shape_vector.empty() && schema_tensor->dataType == kObjectTypeString) {
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
if (status != RET_OK) {
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
return RET_ERROR;
}
}
std::vector<int32_t> dims;
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
[](const int64_t &value) { return static_cast<int32_t>(value); });
schema_tensor->dims = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (schema_tensor->dataType == kObjectTypeTensorType && shape_vector.empty() &&
meta_graphT->fmkType == converter::FmkType_ONNX) {
schema_tensor->data.resize(0);
} else {
schema_tensor->data.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(schema_tensor->data.data(), schema_tensor->data.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
}
}
schema_tensor->name = input_name;
QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr
? nullptr
: primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
schema_tensor->dataType == kNumberTypeInt8) {
schema_tensor->enableHuffmanCode = true;
}
schema_tensor->dims = data_info.shape_;
schema_tensor->dataType = data_info.data_type_;
schema_tensor->data = data_info.data_;
schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;

node_id_map_[input_name] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
return RET_OK;
}

int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
auto valueAbstract = value_node->abstract();
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
return RET_ERROR;
}
auto typePtr = abstract_tensor->element()->GetTypeTrack();
(*schema_tensor)->dataType = typePtr->type_id();
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims;
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
[](const int64_t &value) { return static_cast<int32_t>(value); });
(*schema_tensor)->dims = dims;
if (train_flag_ && (*schema_tensor)->dims.empty()) (*schema_tensor)->dims = {1};
(*schema_tensor)->nodeType = NodeType_ValueNode;
auto data = value->cast<tensor::TensorPtr>();
(*schema_tensor)->data.resize(data->Size());
(*schema_tensor)->format = schema::Format_NHWC;

(*schema_tensor)->format = GetFormatByFmk(meta_graphT->fmkType);
if ((*schema_tensor)->format != schema::Format_NHWC && (*schema_tensor)->format != schema::Format_NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << (*schema_tensor)->format;
return RET_ERROR;
}

// process weight tensor
if (data->Size() > 0) {
if (memcpy_s((*schema_tensor)->data.data(), (*schema_tensor)->data.size(), data->data_c(), data->Size()) != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
return RET_ERROR;
}

if (primitive->GetAttr(opt::kWeightFormat) != nullptr) {
(*schema_tensor)->format = static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat)));
}
}

node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
return RET_OK;
}
int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
int ret;
// data of int64 is converted to int32 here.
(*schema_tensor)->dataType = kNumberTypeInt32;
(*schema_tensor)->dims = {1};
(*schema_tensor)->nodeType = NodeType_ValueNode;
int real_data = opt::CastToInt(value).front();
(*schema_tensor)->data.resize(sizeof(int32_t));
ret = memcpy_s((*schema_tensor)->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
return RET_ERROR;
}
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
return ret;
}
void AnfExporter::ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
auto valueAbstract = value_node->abstract();
auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract);
auto typePtr = abstractScalar->GetTypeTrack();
(*schema_tensor)->dataType = typePtr->type_id();
(*schema_tensor)->dims = {1};
(*schema_tensor)->nodeType = NodeType_ValueNode;
auto data = value->cast<mindspore::BoolImmPtr>();
(*schema_tensor)->data.emplace_back(data->value());
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
}
int AnfExporter::ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
auto data = value_node->value()->cast<NumberPtr>();
schema_tensor->data.resize(sizeof(int));
int number_type = data->number_type();
if (EOK != ::memcpy_s(schema_tensor->data.data(), sizeof(int), &number_type, sizeof(int))) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_MEMORY_FAILED;
}
schema_tensor->dataType = kNumberTypeInt32;
schema_tensor->dims = {1};
schema_tensor->nodeType = NodeType_ValueNode;
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(schema_tensor);
return RET_OK;
}
void AnfExporter::ProcessInt(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
(*schema_tensor)->dataType = kNumberTypeInt32;
(*schema_tensor)->dims = {1};
(*schema_tensor)->nodeType = NodeType_ValueNode;
(*schema_tensor)->data.emplace_back(kNumberTypeInt32);
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
}
int AnfExporter::ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
int ret = RET_OK;
auto valueAbstract = value_node->abstract();
auto abstractSequnce = utils::cast<abstract::AbstractSequeuePtr>(valueAbstract);
if (abstractSequnce->isa<abstract::AbstractTuple>()) {
auto abstractTuple = utils::cast<abstract::AbstractTuplePtr>(valueAbstract);
auto x_shape_data = abstractTuple->elements();
std::vector<int32_t> shape;
for (std::size_t i = 0; i < abstractTuple->size(); ++i) {
auto value_track = x_shape_data[i]->GetValueTrack();
MS_ASSERT(value_track != nullptr);
if (value_track->isa<Int32Imm>()) {
shape.push_back((GetValue<int>(value_track)));
} else if (value_track->isa<Int64Imm>()) {
shape.push_back((GetValue<int64_t>(value_track)));
} else {
MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << ".";
return RET_ERROR;
}
}
(*schema_tensor)->dataType = kNumberTypeInt32;
(*schema_tensor)->dims = {static_cast<int32_t>(shape.size())};
(*schema_tensor)->nodeType = NodeType_ValueNode;
(*schema_tensor)->data.resize(shape.size() * sizeof(int));
if (!shape.empty()) {
if (EOK != memcpy_s((*schema_tensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(),
shape.size() * sizeof(int32_t))) {
MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
return RET_MEMORY_FAILED;
}
}
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
}
return ret;
}

int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(value);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Input value is not a tensor";
return RET_INPUT_PARAM_INVALID;
}
auto ret = UpdateTensorTFromTensorInfo(tensor_info, schema_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << "UpdateTensorTFromTensorInfo failed";
return ret;
}
if (train_flag_ && (*schema_tensor)->dims.empty()) {
(*schema_tensor)->dims = {1};
}

node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
return ret;
}

int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
const std::shared_ptr<PrimitiveC> &primitive,
int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *output_cnode) {
auto value_node = input_anode->cast<ValueNodePtr>();
auto schema_tensor = std::make_unique<schema::TensorT>();
auto value = value_node->value();
int ret = RET_OK;

if (train_flag_) {
schema_tensor->name = value_node->fullname_with_scope();
}
if (value->isa<tensor::Tensor>()) {
ret = ProcessTensor(value_node, &schema_tensor, value, primitive, output_cnode, meta_graphT);
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
} else if (value->isa<mindspore::BoolImm>()) {
ProcessBoolImm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
} else if (value->isa<mindspore::Int>()) {
ProcessInt(value_node, &schema_tensor, output_cnode, meta_graphT);
} else if (value->isa<mindspore::ValueSequeue>()) {
ret = ProcessValueSequence(value_node, &schema_tensor, value, output_cnode, meta_graphT);
} else if (value->isa<Number>()) {
ret = ProcessNumber(value_node, schema_tensor.release(), output_cnode, meta_graphT);
} else if (value->isa<mindspore::tensor::Tensor>()) {
ret = ProcessTensorInfo(value_node, &schema_tensor, value, output_cnode, meta_graphT);
} else if (value->isa<FuncGraph>()) {
MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
schema::CNodeT *op_node) {
DataInfo data_info;
auto status = FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info);
if (status == RET_NO_CHANGE) {
return RET_OK;
} else if (value->isa<Monad>()) {
MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is Monad";
return RET_OK;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
}
return ret;
if (status != RET_OK) {
MS_LOG(ERROR) << "parse value node failed.";
return status;
}
auto schema_tensor = std::make_unique<schema::TensorT>();
schema_tensor->name = cnode->input(index)->fullname_with_scope();
schema_tensor->format = static_cast<schema::Format>(data_info.format_);
schema_tensor->dataType = data_info.data_type_;
schema_tensor->dims = data_info.shape_;
schema_tensor->data = data_info.data_;
node_id_map_[cnode->input(index)->fullname_with_scope()] = meta_graphT->allTensors.size();
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
return RET_OK;
}

int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
@@ -954,7 +577,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
return ret;
}
} else if (input_node->isa<Parameter>()) {
auto ret = ConvertInputParameter(input_node, primitive_c, meta_graphT, fb_node);
auto ret = ConvertInputParameter(cnode, i, primitive_c, meta_graphT, fb_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvertInputParameter failed";
return ret;
@@ -963,7 +586,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
is_graph_input = true;
}
} else if (input_node->isa<ValueNode>()) {
auto ret = ConvertInputValueNode(input_node, primitive_c, meta_graphT, fb_node);
auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvertInputValueNode failed";
return RET_ERROR;


mindspore/lite/tools/anf_exporter/anf_exporter.h (+6, -25)

@@ -24,8 +24,10 @@
#include "schema/inner/model_generated.h"
#include "ops/primitive_c.h"
#include "ir/func_graph.h"
#include "tools/anf_exporter/fetch_content.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::ops::PrimitiveC;

@@ -44,35 +46,14 @@ class AnfExporter {
schema::CNodeT *fb_node);
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *fb_node);
static void RemoveIfMakeTuple(const CNodePtr &cnode);
static void RemoveIfDepend(const CNodePtr &cnode);

protected:
int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);
int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode);
int ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
void ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
void ProcessInt(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node);
int ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node);
int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);


mindspore/lite/tools/anf_exporter/fetch_content.cc (+437, -0)

@@ -0,0 +1,437 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/anf_exporter/fetch_content.h"
#include <algorithm>
#include <string>
#include <vector>
#include <unordered_map>
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace lite {
namespace {
constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
static const std::unordered_map<int, int> TypeToTypeMap = {
{kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
auto data_type = tensor_info->data_type();
if (data_type != kObjectTypeString) {
MS_LOG(ERROR) << "This function only used for string tensor.";
return RET_ERROR;
}
shape_vector->clear();
auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
std::string shape_str;
std::string shape_size_str;
*offset = 0;
size_t cnt = 0;
for (; *offset < tensor_info->Size(); (*offset)++) {
if (tensor_data[*offset] == ',') {
(*offset)++;
break;
}
shape_size_str.push_back(tensor_data[*offset]);
}
if (*offset == 0) {
MS_LOG(ERROR) << "string tensor's dim size not found.";
return RET_ERROR;
}
size_t shape_size = std::stoi(shape_size_str);
for (; *offset < tensor_info->Size(); (*offset)++) {
if (tensor_data[*offset] == ',') {
cnt++;
shape_vector->push_back(std::stoi(shape_str));
shape_str.clear();
} else {
shape_str.push_back(tensor_data[*offset]);
}
if (cnt == shape_size) {
(*offset)++;
break;
}
}
if (shape_vector->empty()) {
MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
return RET_ERROR;
}
return RET_OK;
}
int GetFormatByFmk(int32_t fmk_type) {
switch (fmk_type) {
case converter::FmkType_ONNX:
case lite::converter::FmkType_CAFFE:
case lite::converter::FmkType_MS:
return mindspore::NCHW;
case lite::converter::FmkType_TF:
case lite::converter::FmkType_TFLITE:
return mindspore::NHWC;
default:
return -1;
}
}

STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
return RET_PARAM_INVALID;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return RET_INPUT_TENSOR_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr);
*data_type = typePtr->type_id();
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return RET_PARAM_INVALID;
}
*shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
return RET_OK;
}

int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
bool train_flag, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
auto valueAbstract = value_node->abstract();
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
return RET_ERROR;
}
auto typePtr = abstract_tensor->element()->GetTypeTrack();
data_info->data_type_ = typePtr->type_id();
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (train_flag && dims.empty()) {
data_info->shape_ = {1};
}
auto value = value_node->value();
MS_ASSERT(value != nullptr);
auto data = value->cast<tensor::TensorPtr>();
data_info->data_.resize(data->Size());
data_info->format_ = GetFormatByFmk(fmk_type);
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
return RET_ERROR;
}

// process weight tensor
if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
return RET_ERROR;
}
return RET_OK;
}

int FetchFromInt32OrInt64ImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
// data of int64 is converted to int32 here.
data_info->data_type_ = kNumberTypeInt32;
data_info->shape_ = {1};
data_info->data_.resize(sizeof(int32_t));
auto value = value_node->value();
MS_ASSERT(value != nullptr);
int real_data = opt::CastToInt(value).front();
if (memcpy_s(data_info->data_.data(), sizeof(int32_t), &real_data, sizeof(int32_t)) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_MEMORY_FAILED;
}
return RET_OK;
}

int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
data_info->data_type_ = kNumberTypeBool;
data_info->shape_ = {1};
data_info->data_.resize(sizeof(bool));
auto value = value_node->value();
MS_ASSERT(value != nullptr);
auto data = value->cast<mindspore::BoolImmPtr>();
auto data_value = data->value();
if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_MEMORY_FAILED;
}
return RET_OK;
}

int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
data_info->data_type_ = kNumberTypeInt32;
data_info->shape_ = {1};
data_info->data_.resize(sizeof(int));
auto data = value_node->value()->cast<NumberPtr>();
int number_type = data->number_type();
if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
number_type = TypeToTypeMap.at(number_type);
}
if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_MEMORY_FAILED;
}
return RET_OK;
}

int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
auto value = value_node->value();
MS_ASSERT(value != nullptr);
std::vector<int32_t> shape;
auto value_seq = value->cast<ValueSequeuePtr>();
MS_ASSERT(value_seq != nullptr);
if (!value_seq->value().empty()) {
if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
shape = GetValue<std::vector<int>>(value);
} else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
auto origin_value = GetValue<std::vector<int64_t>>(value);
std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
[](int64_t val) { return static_cast<int32_t>(val); });
} else {
MS_LOG(ERROR) << "Value type is ValueSequence is not integer.";
return RET_ERROR;
}
}
data_info->data_type_ = kNumberTypeInt32;
data_info->shape_ = {static_cast<int32_t>(shape.size())};
data_info->data_.resize(shape.size() * sizeof(int));
if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
shape.size() * sizeof(int32_t)) != EOK) {
MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
return RET_ERROR;
}
return RET_OK;
}
} // namespace

int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto param_node = cnode->input(index)->cast<ParameterPtr>();
data_info->format_ = GetFormatByFmk(fmk_type);
if (data_info->format_ < 0) {
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
return lite::RET_ERROR;
}
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
return RET_ERROR;
}

// attr weightFormat is only used by conv-like ops' second input
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
}
ShapeVector shape_vector;
TypeId data_type;
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
if (status != RET_OK) {
MS_LOG(ERROR) << "get data type and shape from param node failed.";
return RET_ERROR;
}
data_info->data_type_ = data_type;
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
size_t offset = 0;
if (!shape_vector.empty() && data_type == kObjectTypeString) {
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
if (status != RET_OK) {
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
return RET_ERROR;
}
}
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
}
}
QuantParamHolderPtr quant_param_holder =
prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) {
data_info->enable_huffman_code_ = true;
}
data_info->node_type_ = NodeType_ValueNode;
return RET_OK;
}

int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto value_node = cnode->input(index)->cast<ValueNodePtr>();
auto value = value_node->value();
int ret = RET_OK;
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (value->isa<tensor::Tensor>()) {
ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
}
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);
} else if (value->isa<mindspore::BoolImm>()) {
ret = FetchFromBoolImmValue(value_node, prim, data_info);
} else if (value->isa<mindspore::ValueSequeue>()) {
ret = FetchFromSequenceValue(value_node, prim, data_info);
} else if (value->isa<Number>()) {
ret = FetchFromNumberValue(value_node, prim, data_info);
} else if (value->isa<FuncGraph>()) {
MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
return RET_NO_CHANGE;
} else if (value->isa<Monad>()) {
MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
return RET_NO_CHANGE;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
}
data_info->node_type_ = NodeType_ValueNode;
return ret;
}

int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto abstract = opt::GetCNodeInputAbstract(cnode, index);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Abstract cnode is nullptr.";
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
MS_LOG(ERROR) << "Abstract should be anstract tensor.";
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(type_ptr != nullptr);
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
return RET_ERROR;
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->format_ = mindspore::NHWC;
data_info->data_type_ = type_ptr->type_id();
data_info->shape_ = dims;
data_info->node_type_ = NodeType_CNode;
if (type_ptr->type_id() == kObjectTypeTensorType) {
auto tensor_info = abstract_tensor->GetValueTrack();
if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
MS_LOG(ERROR) << "tensor info is invalid.";
return RET_ERROR;
}
auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
if (tensor_value->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_value->Size());
if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
return RET_ERROR;
}
}
}
return RET_OK;
}

void RemoveIfDepend(const CNodePtr &cnode) {
bool has_depend = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();

inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr inputNode = cnode->input(i);
if (!inputNode->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto depend_node = utils::cast<CNodePtr>(inputNode);
auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return;
}
if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
has_depend = true;
bool mask_out = (depend_node->inputs().size() == 3);
for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
AnfNodePtr depend_input_node = depend_node->input(j);
if (depend_input_node->isa<CNode>()) {
inputs.emplace_back(depend_input_node);
if (mask_out) {
break;
}
}
}
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (has_depend) {
cnode->set_inputs(inputs);
}
}

void RemoveIfMakeTuple(const CNodePtr &cnode) {
bool has_make_tuple = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();

inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr input_node = cnode->input(i);
if (!input_node->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto make_tuple_node = utils::cast<CNodePtr>(input_node);
auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "value node is invalid.";
return;
}
if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
has_make_tuple = true;
for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
inputs.emplace_back(make_tuple_node->input(j));
}
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (has_make_tuple) {
cnode->set_inputs(inputs);
}
}
} // namespace lite
} // namespace mindspore
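
Note on the string-tensor layout: GetShapeVectorFromStringTensor above reads the dimension count up to the first ',', then that many comma-terminated dimension values, and leaves *offset at the first payload byte. Below is a minimal standalone sketch of that assumed layout, illustrative only and not part of the patch; for example the bytes "2,4,3,..." describe shape {4, 3}.

#include <iostream>
#include <string>
#include <vector>

// Parses "<dim_count>,<dim_0>,<dim_1>,...," and reports where the payload starts.
std::vector<int> ParseStringTensorShape(const std::string &bytes, size_t *offset) {
  *offset = 0;
  std::string dim_count_str;
  while (*offset < bytes.size() && bytes[*offset] != ',') {
    dim_count_str.push_back(bytes[(*offset)++]);
  }
  ++(*offset);  // skip the ',' that terminates the dimension count
  size_t dim_count = std::stoul(dim_count_str);
  std::vector<int> shape;
  std::string dim_str;
  while (*offset < bytes.size() && shape.size() < dim_count) {
    if (bytes[*offset] == ',') {
      shape.push_back(std::stoi(dim_str));
      dim_str.clear();
    } else {
      dim_str.push_back(bytes[*offset]);
    }
    ++(*offset);
  }
  return shape;  // *offset now points at the first payload byte
}

int main() {
  size_t offset = 0;
  auto shape = ParseStringTensorShape("2,4,3,hello", &offset);
  std::cout << shape[0] << "x" << shape[1] << ", payload offset " << offset << std::endl;  // prints "4x3, payload offset 6"
  return 0;
}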

mindspore/lite/tools/anf_exporter/fetch_content.h (+49, -0)

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
#define MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_

#include <string>
#include <vector>
#include "ir/primitive.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "tools/converter/converter_flags.h"

namespace mindspore {
namespace lite {
struct DataInfo {
bool enable_huffman_code_;
int format_;
int data_type_;
int node_type_;
std::vector<int> shape_;
std::vector<uint8_t> data_;
DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0) {}
};
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
void RemoveIfDepend(const CNodePtr &cnode);

void RemoveIfMakeTuple(const CNodePtr &cnode);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
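
A rough usage sketch of the new helpers (illustrative only; BuildTensorFromParameter is a hypothetical wrapper, not part of this patch), mirroring how anf_exporter.cc now fills a schema::TensorT from a DataInfo:

#include <memory>
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "tools/anf_exporter/fetch_content.h"

namespace mindspore {
namespace lite {
// Fetches a parameter input's constant data and wraps it into a schema tensor,
// following the same sequence ConvertInputParameter performs in anf_exporter.cc.
std::unique_ptr<schema::TensorT> BuildTensorFromParameter(const CNodePtr &cnode, size_t index,
                                                          converter::FmkType fmk_type, bool train_flag) {
  DataInfo data_info;
  if (FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info) != RET_OK) {
    return nullptr;  // caller logs the failure and aborts the export
  }
  auto tensor = std::make_unique<schema::TensorT>();
  tensor->name = cnode->input(index)->fullname_with_scope();
  tensor->format = static_cast<schema::Format>(data_info.format_);
  tensor->dataType = data_info.data_type_;
  tensor->dims = data_info.shape_;
  tensor->data = data_info.data_;
  tensor->enableHuffmanCode = data_info.enable_huffman_code_;
  return tensor;
}
}  // namespace lite
}  // namespace mindspore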

mindspore/lite/tools/converter/anf_transform.cc (+1, -3)

@@ -168,9 +168,7 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
if (!config->trainModel) {
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
inne_context_ptr->Init();
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
}
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
update_conv2d_param_pass->SetFmkType(config->fmk);


mindspore/lite/tools/converter/graphdef_transform.cc (+0, -6)

@@ -21,15 +21,10 @@
#include "src/common/log_adapter.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
@@ -129,7 +124,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}


mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt (+0, -4)

@@ -1,12 +1,8 @@
file(GLOB FUSION_SRC
${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc
${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/mul_add_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc
)
set_property(SOURCE ${FUSION_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(fusion_mid OBJECT ${FUSION_SRC})


mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc (+0, -125)

@@ -1,125 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite {
#define kFormatTransMatchPathLen2 2
#define kFormatTransMatchPathLen3 3

STATUS FormatTransFusionPass::DefinePattern() {
// nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc
{
auto transpose1 = std::make_shared<PatternOp>();
transpose1->id = kFormatTransTranspose1;
transpose1->types = {PrimitiveType_Transpose};
auto transpose2 = std::make_shared<PatternOp>();
transpose2->id = kFormatTransTranspose2;
transpose2->types = {PrimitiveType_Transpose};

transpose2->left = transpose1;
auto pattern = std::make_unique<FusionPattern>(kNc2NhAndNh2NcFusionPattern);
if (pattern == nullptr) {
MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed";
return RET_ERROR;
}
pattern->AddPatternOp(transpose1);
pattern->AddPatternOp(transpose2);
pattern->Finish();
this->patterns.emplace_back(pattern.release());
}
// nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw
{
auto transpose1 = std::make_shared<PatternOp>();
transpose1->id = kFormatTransTranspose1;
transpose1->types = {PrimitiveType_Transpose};
auto passOp = std::make_shared<PatternOp>();
passOp->id = kFormatTransPassOp;
passOp->types = {PrimitiveType_QuantDTypeCast};
auto transpose2 = std::make_shared<PatternOp>();
transpose2->id = kFormatTransTranspose2;
transpose2->types = {PrimitiveType_Transpose};

passOp->left = transpose2;
transpose1->left = passOp;
auto pattern = std::make_unique<FusionPattern>(kNh2NcAndNc2NhPassFusionPattern);
if (pattern == nullptr) {
MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed";
return RET_ERROR;
}
pattern->AddPatternOp(transpose1);
pattern->AddPatternOp(passOp);
pattern->AddPatternOp(transpose2);
pattern->Finish();
this->patterns.emplace_back(pattern.release());
}
return RET_OK;
}

STATUS FormatTransFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); }

STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
MS_ASSERT(graph != nullptr);
if (matchedPath.size() != kFormatTransMatchPathLen2 && matchedPath.size() != kFormatTransMatchPathLen3) {
MS_LOG(ERROR) << "schema::Format-Transform-Fusion should have " << kFormatTransMatchPathLen2 << " or "
<< kFormatTransMatchPathLen3 << " NodeIndex in matchedPair";
return RET_PARAM_INVALID;
}

std::shared_ptr<Path> srcPath = matchedPath[kFormatTransTranspose1];
std::shared_ptr<Path> dstPath = matchedPath[kFormatTransTranspose2];
if (srcPath == nullptr || dstPath == nullptr) {
MS_LOG(ERROR) << "srcPath or dstPath is failed to get";
return RET_ERROR;
}
auto &srcNode = graph->nodes.at(srcPath->nodeIdx);
auto &dstNode = graph->nodes.at(dstPath->nodeIdx);
MS_ASSERT(srcNode != nullptr);
MS_ASSERT(dstNode != nullptr);
auto src_perm = GetTransposePerm(graph, srcNode);
auto dst_perm = GetTransposePerm(graph, dstNode);
bool isNc2NhAndNh2Nc = src_perm == nchw2nhwc_perm && dst_perm == nhwc2nchw_perm;
bool isNh2NcAndNc2Nh = src_perm == nhwc2nchw_perm && dst_perm == nchw2nhwc_perm;
if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) {
auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
return status;
}
status = IsolateOneWayNode(graph, dstPath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
return status;
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h (+0, -49)

@@ -1,49 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H
#define MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H

#include <memory>
#include <string>
#include <unordered_map>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1";
constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2";
constexpr const char *kFormatTransPassOp = "FormatTransPassOp";
constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern";
constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern";

class FormatTransFusionPass : public FusionPass {
public:
FormatTransFusionPass() = default;

~FormatTransFusionPass() override = default;

STATUS DefinePattern() override;

STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H

mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt (+0, -5)

@@ -1,17 +1,12 @@
file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/trans_format_insert_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/dropout_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc


mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc (+0, -461)

@@ -1,461 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include <string>
#include <memory>
#include <utility>
#include <vector>
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/common/node_util.h"
#include "src/common/log_adapter.h"
#include "src/common/common.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
#define kMinInputNum 1
#define kOutputNum 1

STATUS FormatTransPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
auto status = DoModelInputFormatTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status;
return status;
}
status = DoNodeInoutFormatTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status;
return status;
}
return RET_OK;
}

STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
FormatTransNodeType *after_node_type) {
if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc
if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*before_node_type = kNHWC2NCHW;
*after_node_type = kNCHW2NHWC;
return RET_OK;
} else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS ||
fmk_type_ == converter::FmkType_ONNX) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*before_node_type = kNCHW2NHWC;
*after_node_type = kNHWC2NCHW;
return RET_OK;
} else if (fmk_type_ == converter::FmkType_TF) {
if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) {
*before_node_type = kNCHW2NHWC;
*after_node_type = kNHWC2NCHW;
return RET_OK;
}
if (IsContain(GetNchwOpList(), GetCNodeTType(node))) {
*before_node_type = kNHWC2NCHW;
*after_node_type = kNCHW2NHWC;
return RET_OK;
}
return RET_NO_CHANGE;
}
MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_;
return RET_ERROR;
}

STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) {
return RET_OK;
}
MS_ASSERT(graph != nullptr);
// insert trans node in model input tensor
if (graph->nodes.empty()) {
return RET_OK;
}
// onnx input format may be nhwc
if (fmk_type_ == converter::FmkType_ONNX && graph->inputIndex.size() == 1) {
auto &input_tensor = graph->allTensors.at(graph->inputIndex[0]);
auto &input_dims = input_tensor->dims;
if (input_dims.size() == 4 && input_dims[3] != -1 && input_dims[1] == -1) {
return RET_OK;
}
}
auto graph_input_idxes = graph->inputIndex;
for (size_t i = 0; i < graph_input_idxes.size(); i++) {
bool transed = false;
auto input_idx = graph_input_idxes.at(i);
auto &tensor = graph->allTensors.at(input_idx);
if (tensor->dims.size() != kNCHWDimNumber) {
continue;
}

for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) {
if ((*iter)->inputIndex.at(input_index_idx) == input_idx) {
STATUS status = RET_OK;
iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed";
return status;
}
// set first tensor format to nhwc
auto &trans_node = *(iter - 1);
MS_ASSERT(trans_node != nullptr);
MS_ASSERT(trans_node->inputIndex.size() == 1);
auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front());
graph_in_tensor->format = schema::Format::Format_NHWC;
// assume the parser has not reformatted the shape
auto old_dims = graph_in_tensor->dims;
if (!transed) {
graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]};
transed = true;
}
}
}
}
}
return RET_OK;
}

// inference needed inputFormat:
// conv deconv depth dedepth
// fp32 NCHW NCHW NCHW NCHW
// uint8 NCHW ? NCHW ?
STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert trans nodes before and after ops calculated in nchw/nc4hw4
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
FormatTransNodeType before_node_type = kNCHW2NHWC;
FormatTransNodeType after_node_type = kNHWC2NCHW;
STATUS status = RET_OK;
status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type);
if (status == RET_NO_CHANGE) {
continue;
}
if (status != RET_OK) {
return status;
}
auto &node = *iter;
auto nodeName = node->name;
if (node->inputIndex.size() < kMinInputNum) {
MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
if (node->outputIndex.size() < kOutputNum) {
MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor";
return RET_ERROR;
}
void *attr = node->primitive->value.value;
if (node->primitive->value.type == schema::PrimitiveType_SpaceToDepth) {
reinterpret_cast<schema::SpaceToDepthT *>(attr)->format = schema::Format_NHWC;
}
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
}
auto spec_insert_indexes = GetExtNhwcIndexes();
auto op_type = GetCNodeTType(**iter);
if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) {
for (auto insert_index : spec_insert_indexes[op_type]) {
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
return RET_ERROR;
}
}
} else if (IsContain(GetNhwcAllInputOpList(), op_type)) {
auto input_size = node->inputIndex.size();
if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) {
if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) {
input_size = 1;
}
}
for (size_t i = 0; i < input_size; i++) {
iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
return RET_ERROR;
}
}
} else {
iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status);
}
iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
return RET_ERROR;
}
}
return RET_OK;
}

NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) {
MS_ASSERT((*exist_node_iter) != nullptr);
MS_ASSERT(graph != nullptr);
auto exist_node_name = (*exist_node_iter)->name;
std::string tile_name;
if (place == kBefore) {
tile_name = exist_node_name + "_pre";
} else {
tile_name = exist_node_name + "_post";
}
auto trans_node = std::make_unique<schema::CNodeT>();
trans_node->primitive = std::make_unique<schema::PrimitiveT>();
trans_node->primitive->value.type = schema::PrimitiveType_Transpose;
auto perm_tensor = std::make_unique<schema::TensorT>();
perm_tensor->dataType = kNumberTypeInt32;
perm_tensor->dims = {4};
std::vector<int> perm;
if (node_type == kNCHW2NHWC) {
trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++);
perm = {0, 2, 3, 1};
} else {
trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++);
perm = {0, 3, 1, 2};
}
size_t bytes = perm.size() * sizeof(int);
perm_tensor->data.resize(bytes);
if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
}
perm_tensor->name = trans_node->name + "_perm";

OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr<CNodeT> {
auto new_op_def = std::make_unique<schema::CNodeT>();
if (new_op_def == nullptr) {
MS_LOG(ERROR) << "new CNodeT failed";
return nullptr;
}
new_op_def->name = in_op_def->name;
new_op_def->quantType = in_op_def->quantType;
new_op_def->primitive = std::make_unique<schema::PrimitiveT>();
if (new_op_def->primitive == nullptr) {
MS_LOG(ERROR) << "new PrimitiveT failed";
return nullptr;
}
new_op_def->primitive->value.type = schema::PrimitiveType_Transpose;
return new_op_def;
};
int insert_num = 0;
auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
transpose_op_copyer);
size_t index = graph->allTensors.size();
graph->allTensors.push_back(std::move(perm_tensor));
for (int i = insert_num; i > 0; --i) {
(*(iter - i))->inputIndex.push_back(index);
}
return iter;
}

int FormatTransPass::GetFormat(const schema::CNodeT &node) {
switch (node.primitive->value.type) {
case schema::PrimitiveType_Conv2DFusion:
return node.primitive->value.AsConv2DFusion()->format;
case schema::PrimitiveType_Conv2dTransposeFusion:
return node.primitive->value.AsConv2dTransposeFusion()->format;
case schema::PrimitiveType_AvgPoolFusion:
return node.primitive->value.AsAvgPoolFusion()->format;
case schema::PrimitiveType_MaxPoolFusion:
return node.primitive->value.AsMaxPoolFusion()->format;
default:
return schema::Format_NHWC;
}
}

STATUS FormatTransPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
MS_ASSERT(node->primitive != nullptr);
auto type = node->primitive->value.type;
auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
if (input1_ndim != 4) {
if (node->inputIndex.size() > 1) {
auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size();
if (input2_ndim != 4 && input2_ndim != 0) {
MS_LOG(ERROR) << "change op axis only support 4 dims";
return RET_NOT_SUPPORT;
}
} else {
MS_LOG(DEBUG) << "change op axis only support 4 dims";
return RET_NOT_SUPPORT;
}
}
if (type == schema::PrimitiveType_Concat) {
MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
auto origin_axis = node->primitive->value.AsConcat()->axis;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsConcat() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis];
}
if (type == schema::PrimitiveType_Split) {
MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
auto origin_axis = node->primitive->value.AsSplit()->axis;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsSplit() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsSplit()->axis = axis_map[origin_axis];
}
if (type == schema::PrimitiveType_Crop) {
MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
auto origin_axis = node->primitive->value.AsCrop()->axis;
auto offsets = node->primitive->value.AsCrop()->offsets;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsCrop() == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
return RET_NULL_PTR;
}
// nchw->nhwc, offsets need to be padded with 0
if (axis_map[origin_axis] == 0) {
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
} else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
// origin_axis = 2 or origin_axis = 3
offsets.push_back(0);
} else if (axis_map[origin_axis] == -1) {
// origin_axis = 1
offsets = {offsets[1], offsets[2], offsets[0]};
} else {
// axis error
MS_LOG(ERROR) << "Crop error";
return RET_ERROR;
}
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
node->primitive->value.AsCrop()->offsets = offsets;
}
if (type == schema::PrimitiveType_SliceFusion || type == schema::PrimitiveType_StridedSlice) {
return ChangeOpSliceAndStridedSlice(graph, node);
}
return RET_OK;
}

void FormatTransPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
return;
}
auto axis_map = GetNc2NhAxisMap();
std::vector<int> cur_attr;
for (int dim = 0; dim < 4; ++dim) {
for (int index = 0; index < element_size; ++index) {
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
cur_attr.push_back(origin_attr[index]);
}
}
}
for (int index = 0; index < element_size; ++index) {
origin_attr[index] = cur_attr[index];
}
}

void FormatTransPass::TransformOpAxisAttr(int *origin_axis, int element_size) {
if (origin_axis == nullptr || element_size == 0) {
return;
}
auto axis_map = GetNc2NhAxisMap();
std::vector<int> new_axis;
for (int i = 0; i < element_size; ++i) {
int axis = axis_map[origin_axis[i]];
axis = axis < 0 ? axis + 4 : axis;
new_axis.push_back(axis);
}
std::sort(new_axis.begin(), new_axis.end());
for (int i = 0; i < element_size; ++i) {
origin_axis[i] = new_axis[i];
}
}

STATUS FormatTransPass::ChangeOpSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
auto attr = node->primitive->value.AsSliceFusion();
if (attr == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSliceFusion() is nullptr.";
return RET_NULL_PTR;
}
// transform attr
if (node->inputIndex.size() < 2) {
MS_LOG(ERROR) << "slice input is error";
return RET_ERROR;
}
for (size_t index = 1; index < node->inputIndex.size(); ++index) {
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
return RET_NOT_SUPPORT;
}
}
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
std::vector<int> axes;
auto axes_attr = attr->axes;
if (axes_attr.empty()) {
for (int index = 0; index < element_num; ++index) {
axes.push_back(index);
}
} else {
std::transform(axes_attr.begin(), axes_attr.end(), std::back_inserter(axes),
[](int64_t val) { return static_cast<int>(val); });
}
for (size_t index = 1; index < node->inputIndex.size(); ++index) {
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
reinterpret_cast<int *>(axes.data()), element_num);
}
TransformOpAxisAttr(axes.data(), element_num);
attr->axes.clear();
for (int i = 0; i < element_num; ++i) {
attr->axes.push_back(static_cast<int64_t>(axes[i]));
}
return RET_OK;
}

STATUS FormatTransPass::ChangeOpStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
// onnx input size is equal to 5 always.
if (node->inputIndex.size() != 5) {
return RET_NOT_SUPPORT;
}
if (node->inputIndex.size() == 5) {
for (int index = 1; index < 5; ++index) {
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
return RET_NOT_SUPPORT;
}
}
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
auto axes = graph->allTensors[node->inputIndex[3]]->data;
for (int index = 1; index < 5; ++index) {
if (index == 3) {
continue;
}
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
reinterpret_cast<int *>(axes.data()), element_num);
}
TransformOpAxisAttr(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[3]]->data.data()), element_num);
}
return RET_OK;
}

STATUS FormatTransPass::ChangeOpSliceAndStridedSlice(schema::MetaGraphT *graph,
const std::unique_ptr<schema::CNodeT> &node) {
auto type = node->primitive->value.type;
if (type == schema::PrimitiveType_StridedSlice) {
return ChangeOpStridedSlice(graph, node);
}
if (type == schema::PrimitiveType_SliceFusion) {
return ChangeOpSlice(graph, node);
}
return RET_ERROR;
}
} // namespace lite
} // namespace mindspore
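
A side note on the perm tensors built in InsertFormatTransNode above: a perm of {0, 2, 3, 1} maps an NCHW shape to NHWC, and {0, 3, 1, 2} maps it back. A small self-contained sketch of how such a perm rewrites a 4-D shape; ApplyPerm is a hypothetical helper, not part of the converter:

// Sketch: applying a transpose perm to a shape vector.
#include <cstdio>
#include <vector>

std::vector<int> ApplyPerm(const std::vector<int> &shape, const std::vector<int> &perm) {
  std::vector<int> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = shape[perm[i]];  // output dim i takes the value of input dim perm[i]
  }
  return out;
}

int main() {
  std::vector<int> nchw = {1, 3, 224, 224};
  auto nhwc = ApplyPerm(nchw, {0, 2, 3, 1});  // expected {1, 224, 224, 3}
  printf("%d %d %d %d\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]);
  return 0;
}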

+ 0
- 76
mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h View File

@@ -1,76 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H
#define MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H

#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"

namespace mindspore {
namespace lite {
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };

class FormatTransPass : public GraphPass {
public:
FormatTransPass() : id_(0) {}

~FormatTransPass() override = default;

STATUS Run(schema::MetaGraphT *graph) override;

void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; }

void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }

protected:
NodeIter InsertFormatTransNode(schema::MetaGraphT *in_op_def, NodeIter exist_node_iter, InsertPlace place,
size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code);

STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

private:
STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph);

STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph);

void TransformAttrByAxes(int *origin_attr, int *axes, int element_size);

void TransformOpAxisAttr(int *origin_axis, int element_size);

STATUS ChangeOpSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

STATUS ChangeOpStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

STATUS ChangeOpSliceAndStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

int GetFormat(const schema::CNodeT &);

STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
FormatTransNodeType *after_node_type);

protected:
size_t id_ = 0;
converter::FmkType fmk_type_ = converter::FmkType_TF;

private:
QuantType quant_type_ = QuantType_QUANT_NONE;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H

+ 0
- 223
mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc View File

@@ -1,223 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include <algorithm>
#include "third_party/securec/include/securec.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/node_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite {

STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
std::set<size_t> need_del_nodes;
std::set<size_t> need_trans_format_nodes;
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto type = node->primitive->value.type;
if (type != PrimitiveType_Transpose) {
continue;
}
if (GetTransposePerm(graph, node) != nchw2nhwc_perm) {
continue;
}
std::vector<size_t> pre_nh2nc_nodes;
std::vector<size_t> pre_not_trans_nodes;
auto status = FindPreNh2NcNodes(graph, iter - graph->nodes.begin(), &pre_nh2nc_nodes, &pre_not_trans_nodes);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
return status;
}
std::copy(pre_nh2nc_nodes.begin(), pre_nh2nc_nodes.end(), std::inserter(need_del_nodes, need_del_nodes.end()));
std::copy(pre_not_trans_nodes.begin(), pre_not_trans_nodes.end(),
std::inserter(need_trans_format_nodes, need_trans_format_nodes.end()));
if (!pre_nh2nc_nodes.empty()) {
need_del_nodes.insert(iter - graph->nodes.begin());
}
}
if (need_del_nodes.empty()) {
return RET_OK;
}
for (auto del_node_index : need_del_nodes) {
auto node_name = graph->nodes.at(del_node_index)->name;
auto status = IsolateOneWayNode(graph, del_node_index);
if (status != RET_OK) {
MS_LOG(ERROR) << "Isolate Node failed, node: " << node_name << ", error: " << status;
return status;
}
}

auto status = TransWeightToNhwc(graph, need_trans_format_nodes);
if (status != RET_OK) {
MS_LOG(ERROR) << "trans weight to nhwc failed";
return status;
}
return RET_OK;
}

STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector<int> &pad_dims) {
if (pad_dims.size() != 4) {
MS_LOG(ERROR) << "pad dims error";
return RET_ERROR;
}
auto batch = pad_dims[NCHW_N];
auto channel = pad_dims[NCHW_C];
auto area = pad_dims[NCHW_H] * pad_dims[NCHW_W];
auto size = batch * channel * area;
auto new_nhwc_data = new (std::nothrow) float[size];
if (new_nhwc_data == nullptr) {
MS_LOG(ERROR) << "create new nhwc data failed";
delete[] new_nhwc_data;
return RET_ERROR;
}
if (memset_s(new_nhwc_data, sizeof(float) * size, 0, sizeof(float) * size) != EOK) {
MS_LOG(ERROR) << "create new nhwc data failed";
delete[] new_nhwc_data;
return RET_ERROR;
}
auto nchw_data = reinterpret_cast<float *>(tensor->data.data());
// nchw to nhwc
for (auto i = 0; i < batch; i++) {
float *src_batch = nchw_data + i * channel * area;
float *dst_batch = new_nhwc_data + i * channel * area;
for (int j = 0; j < area; ++j) {
float *src_area = src_batch + j;
float *dst_area = dst_batch + j * channel;
for (int k = 0; k < channel; ++k) {
dst_area[k] = src_area[k * area];
}
}
}
if (memcpy_s(nchw_data, tensor->data.size(), new_nhwc_data, sizeof(float) * size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
delete[] new_nhwc_data;
return RET_ERROR;
}
delete[] new_nhwc_data;
return RET_OK;
}

STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes) {
MS_ASSERT(graph != nullptr);
if (pre_not_trans_nodes.empty()) {
return RET_OK;
}
for (auto index : pre_not_trans_nodes) {
auto &cur_node = graph->nodes.at(index);
// ops like concat and slice need their axis changed from nchw to nhwc
auto ret = ChangeOpAxis(graph, cur_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ChangeOpAxis error";
return ret;
}
auto node_input_indexs = cur_node->inputIndex;
for (auto input_index : node_input_indexs) {
// weight data needs to be transposed to nhwc layout
if (!IsContain(graph->inputIndex, input_index) &&
graph->allTensors.at(input_index)->nodeType == NodeType_ValueNode) {
auto &weight_tensor = graph->allTensors.at(input_index);
auto origin_dims = weight_tensor->dims;
weight_tensor->format = Format_NHWC;
if (origin_dims.size() > 4) {
MS_LOG(ERROR) << "tensor origin tensor size error";
return RET_ERROR;
}
if (origin_dims.empty()) {
continue;
}
auto pad_dims = origin_dims;
if (origin_dims.size() == 1) {
pad_dims = {1, 1, 1, origin_dims[0]};
} else if (origin_dims.size() == 2) {
pad_dims = {1, 1, origin_dims[0], origin_dims[1]};
} else if (origin_dims.size() == 3) {
pad_dims = {1, origin_dims[0], origin_dims[1], origin_dims[2]};
}
if (ConvertNcTensor2Nh(weight_tensor.get(), pad_dims) != RET_OK) {
MS_LOG(ERROR) << "Convert nchw to nhwc failed";
return RET_ERROR;
}
weight_tensor->dims = {pad_dims[NCHW_N], pad_dims[NCHW_H], pad_dims[NCHW_W], pad_dims[NCHW_C]};
}
}
}
return RET_OK;
}

STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index,
std::vector<size_t> *pre_nh2nc_nodes,
std::vector<size_t> *pre_not_trans_nodes) {
MS_ASSERT(graph != nullptr);
std::vector<size_t> bfs_queue = {nc2nh_index};
// find pre node nh2nc start nodes
while (!bfs_queue.empty()) {
auto cur_node_index = bfs_queue.back();
auto &cur_node = graph->nodes.at(cur_node_index);
bfs_queue.pop_back();
auto input_node_indexes = GetInputNodeIdx(*graph, *cur_node);
for (auto input_node_index : input_node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
auto node_type = pre_node->primitive->value.type;
if (node_type == schema::PrimitiveType_Transpose && GetTransposePerm(graph, pre_node) == nhwc2nchw_perm) {
if (!IsContain(*pre_nh2nc_nodes, input_node_index)) {
pre_nh2nc_nodes->emplace_back(input_node_index);
}
} else if (IsContain(GetInsertOpList(), node_type)) {
if (!IsContain(bfs_queue, input_node_index)) {
bfs_queue.emplace_back(input_node_index);
}
auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
if (pre_node_output_indexs.size() != 1) {
if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
for (auto pre_node_output_index : pre_node_output_indexs) {
MS_ASSERT(graph->nodes.size() > pre_node_output_index);
if (graph->nodes.at(pre_node_output_index)->primitive->value.type == schema::PrimitiveType_PadFusion) {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
}
}
} else {
pre_nh2nc_nodes->clear();
pre_not_trans_nodes->clear();
return RET_OK;
}
if (!IsContain(*pre_not_trans_nodes, cur_node_index) && cur_node_index != nc2nh_index) {
pre_not_trans_nodes->emplace_back(cur_node_index);
}
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
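
The relayout done by ConvertNcTensor2Nh above comes down to the index arithmetic below: element (n, c, h, w) of the NCHW buffer moves to position (n, h, w, c) of the NHWC buffer. A toy, self-contained check of that arithmetic (an assumed sketch, not the pass itself):

// Sketch: NCHW -> NHWC relayout on a toy 1x2x2x2 tensor.
#include <cassert>
#include <vector>

std::vector<float> Nchw2Nhwc(const std::vector<float> &src, int n, int c, int h, int w) {
  std::vector<float> dst(src.size());
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < h * w; ++j) {    // spatial position
      for (int k = 0; k < c; ++k) {      // channel
        dst[(i * h * w + j) * c + k] = src[(i * c + k) * h * w + j];
      }
    }
  }
  return dst;
}

int main() {
  // NCHW data of shape {1, 2, 2, 2}: channel 0 = {0,1,2,3}, channel 1 = {4,5,6,7}.
  std::vector<float> nchw = {0, 1, 2, 3, 4, 5, 6, 7};
  auto nhwc = Nchw2Nhwc(nchw, 1, 2, 2, 2);
  assert(nhwc[0] == 0 && nhwc[1] == 4);  // the first spatial position interleaves both channels
  return 0;
}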

+ 0
- 49
mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.h View File

@@ -1,49 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H
#define MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H

#include <unordered_map>
#include <set>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"

using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
class GlobalFormatTransformPass : public FormatTransPass {
public:
GlobalFormatTransformPass() = default;

~GlobalFormatTransformPass() override = default;

STATUS Run(MetaGraphT *graph) override;

protected:
STATUS TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes);

STATUS FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, std::vector<size_t> *to_do_insert_nodes,
std::vector<size_t> *pre_not_trans_nodes);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H

+ 0
- 193
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc View File

@@ -1,193 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <memory>
#include <vector>
#include <utility>
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/common/node_util.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite {
bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes, size_t *has_trans_count,
FormatTransNodeType *trans_type) {
for (auto input_node_index : node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
MS_ASSERT(pre_node->primitive != nullptr);
MS_ASSERT(pre_node->primitive->value != nullptr);
if (*trans_type == kNONE) {
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
auto perm = GetTransposePerm(graph, pre_node);
if (perm == nchw2nhwc_perm) {
*trans_type = kNCHW2NHWC;
} else if (perm == nhwc2nchw_perm) {
*trans_type = kNHWC2NCHW;
} else {
return false;
}
(*has_trans_count)++;
}
} else {
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
auto cur_type = kNONE;
auto perm = GetTransposePerm(graph, pre_node);
if (perm == nchw2nhwc_perm) {
cur_type = kNCHW2NHWC;
} else if (perm == nhwc2nchw_perm) {
cur_type = kNHWC2NCHW;
} else {
return false;
}
if (*trans_type != cur_type) {
return false;
} else {
(*has_trans_count)++;
}
}
}
}
return true;
}
bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
pre_type_ = kNONE;
size_t has_trans_count = 0;
if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) {
return false;
}
auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
post_type_ = kNONE;
if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) {
return false;
}
if (pre_type_ == kNONE && post_type_ == kNONE) {
return false;
}
auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size();
auto total_node_count = input_node_indexes.size() + output_size;
size_t half_count = total_node_count / 2;
if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) {
MS_ASSERT(node != nullptr);
MS_ASSERT(node->primitive != nullptr);
MS_ASSERT(node->primitive->value != nullptr);
MS_ASSERT(node->primitive->value.AsActivation() != nullptr);
if (node->primitive->value.AsActivation() != nullptr &&
node->primitive->value.AsActivation()->activation_type == schema::ActivationType_LEAKY_RELU) {
return has_trans_count >= half_count;
}
}
if (GetCNodeTType(*node) == schema::PrimitiveType_Split) {
return has_trans_count >= half_count;
}
return has_trans_count > half_count;
}
STATUS TransOpInsertPass::FindOutTransType() {
pre_insert_trans_type_ = kNHWC2NCHW;
post_insert_trans_type_ = kNHWC2NCHW;
if (pre_type_ == kNONE && post_type_ != kNONE) {
pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
} else if (pre_type_ != kNONE && post_type_ == kNONE) {
pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
} else if (pre_type_ == kNONE && post_type_ == kNONE) {
MS_ASSERT(false);
} else {
if (pre_type_ == post_type_) {
MS_LOG(ERROR) << "Unknown error";
return RET_ERROR;
}
pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
}
return RET_OK;
}

STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
bool changed = true;
int run_counts = 0;
std::vector<CNodeT *> has_insert_nodes;
while (changed && run_counts < 10) {
changed = false;
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (node == nullptr || node->primitive == nullptr) {
MS_LOG(ERROR) << "node or primitive null";
return RET_NULL_PTR;
}
auto type = node->primitive->value.type;
if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) {
continue;
}
auto node_name = node->name;
if (!CanFusion(graph, node)) {
continue;
}
auto ret = FindOutTransType();
if (ret != RET_OK) {
MS_LOG(ERROR) << "FindOutTransType error";
return ret;
}
ret = ChangeOpAxis(graph, node);
if (ret == RET_NOT_SUPPORT) {
MS_LOG(INFO) << "not support to ChangeOpAxis";
return RET_OK;
} else if (ret != RET_OK) {
MS_LOG(INFO) << "no need to ChangeOpAxis";
return ret;
}
has_insert_nodes.push_back(node.get());
STATUS status = RET_OK;
auto input_tensor_size = (*iter)->inputIndex.size();
for (size_t i = 0; i < input_tensor_size; i++) {
auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]);
if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) {
continue;
}
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
return status;
}
}
auto output_tensor_size = (*iter)->outputIndex.size();
for (size_t i = 0; i < output_tensor_size; i++) {
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed";
return status;
}
}
changed = true;
}
run_counts++;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
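
CanFusion above is essentially a majority vote: the transpose is only moved across an op when enough of its neighbouring nodes are already format transposes of the same direction (>= half for the relaxed cases such as Split and leaky-relu Activation, strictly more than half otherwise). A small sketch of that threshold test under the same assumptions; the helper name is hypothetical:

// Sketch: the majority test behind the fusion decision.
#include <cassert>
#include <cstddef>

bool ShouldMoveTranspose(size_t trans_neighbours, size_t input_count, size_t output_count, bool relaxed_op) {
  size_t total = input_count + (output_count == 0 ? 1 : output_count);
  size_t half = total / 2;
  return relaxed_op ? trans_neighbours >= half : trans_neighbours > half;
}

int main() {
  assert(ShouldMoveTranspose(2, 2, 1, false));   // 2 of 3 neighbours are transposes: move it
  assert(!ShouldMoveTranspose(1, 2, 1, false));  // only 1 of 3: leave the graph alone
  return 0;
}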

+ 0
- 52
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h View File

@@ -1,52 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H
#define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H

#include <memory>
#include <vector>
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"

namespace mindspore {
namespace lite {
class TransOpInsertPass : public FormatTransPass {
public:
TransOpInsertPass() : FormatTransPass() {}

~TransOpInsertPass() override = default;

STATUS Run(schema::MetaGraphT *graph) override;

private:
bool CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);

STATUS FindOutTransType();

private:
FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;
FormatTransNodeType pre_type_ = kNONE;
std::vector<int> pre_perm_;
FormatTransNodeType post_type_ = kNONE;
std::vector<int> post_perm_;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H

+ 0
- 52
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc View File

@@ -1,52 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
#include <vector>
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "tools/common/graph_util.h"
#include "src/tensor.h"

using mindspore::lite::Tensor;
namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite {
STATUS TransOpRemovePass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto type = node->primitive->value.type;
auto perm = GetTransposePerm(graph, node);
if (type == schema::PrimitiveType_Transpose && (perm == nchw2nhwc_perm || perm == nhwc2nchw_perm)) {
auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0));
// transposes on tensors with fewer than 4 dims can be deleted
if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) {
auto status = IsolateOneWayNode(graph, node.get(), true);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << node->name.c_str() << ", error: " << status;
return status;
}
}
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 0
- 40
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h View File

@@ -1,40 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H
#define MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H

#include <unordered_map>
#include <memory>
#include <string>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"

using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
class TransOpRemovePass : public GraphPass {
public:
TransOpRemovePass() = default;

~TransOpRemovePass() = default;

STATUS Run(MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H

+ 25
- 22
mindspore/lite/tools/optimizer/common/format_utils.cc View File

@@ -28,14 +28,15 @@
#include "ops/concat.h"
#include "ops/crop.h"
#include "ops/depth_to_space.h"
#include "ops/fused_batch_norm.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fused_batch_norm.h"
#include "ops/fusion/avg_pool_fusion.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
#include "ops/fusion/conv2d_backprop_filter_fusion.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/div_fusion.h"
#include "ops/fusion/max_pool_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/pow_fusion.h"
@@ -61,6 +62,7 @@
#include "ops/space_to_depth.h"
#include "ops/split.h"
#include "ops/strided_slice.h"
#include "tools/anf_exporter/fetch_content.h"

namespace mindspore {
namespace opt {
@@ -96,9 +98,9 @@ static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{

// ops whose input format is not fixed.
static const std::vector<std::string> DynamicFormatOpList = {
ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNamePowFusion, ops::kNameStridedSlice,
ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion, ops::kNameCrop,
ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast};
ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNameDivFusion, ops::kNamePowFusion,
ops::kNameStridedSlice, ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion,
ops::kNameCrop, ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast};

static const std::unordered_map<int, int> NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}};

@@ -120,33 +122,34 @@ Format GetFormat(const CNodePtr &cnode) {
return format;
}

STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm) {
STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
MS_ASSERT(perm_node != nullptr);
if (!utils::isa<ParameterPtr>(perm_node)) {
return lite::RET_OK;
if (cnode->size() != 3) {
MS_LOG(ERROR) << "transpose op input size must be three.";
return lite::RET_ERROR;
}
auto perm_param = perm_node->cast<ParameterPtr>();
if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
if (utils::isa<CNodePtr>(cnode->input(2))) {
return lite::RET_OK;
}
auto tensor_info = perm_param->default_param()->cast<tensor::TensorPtr>();
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "default param is not a tensor.";
return lite::RET_ERROR;
lite::DataInfo data_info;
int status;
if (utils::isa<ParameterPtr>(cnode->input(2))) {
status = lite::FetchDataFromParameterNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info);
} else {
status = lite::FetchDataFromValueNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info);
}
if (tensor_info->data_type() != kNumberTypeInt && tensor_info->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "data type is error, which is " << tensor_info->data_type();
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "fetch transpose perm data failed.";
return lite::RET_ERROR;
}
auto tensor_shape = tensor_info->shape();
if (tensor_shape.empty()) {
return lite::RET_OK;
}
if (tensor_shape.size() > 1) {
if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
data_info.shape_.size() != 1) {
MS_LOG(ERROR) << "transpose perm data is invalid.";
return lite::RET_ERROR;
}
perm->resize(tensor_shape[0]);
if (memcpy_s(perm->data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
perm->resize(data_info.shape_[0]);
if (!data_info.data_.empty() &&
memcpy_s(perm->data(), data_info.data_.size(), data_info.data_.data(), data_info.data_.size()) != EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
return lite::RET_ERROR;
}


+ 1
- 1
mindspore/lite/tools/optimizer/common/format_utils.h View File

@@ -38,7 +38,7 @@ const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap();
const std::unordered_map<int, int> &GetNC2NHAxisMap();
const std::vector<std::string> &GetDynamicFormatOpList();
Format GetFormat(const CNodePtr &cnode);
STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm);
STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm);
void RemoveIfMonad(const CNodePtr &cnode);
bool IsMonadNode(const AnfNodePtr &node);
} // namespace opt


+ 70
- 53
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc View File

@@ -15,12 +15,13 @@
*/

#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "tools/anf_exporter/fetch_content.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "src/common/common.h"
@@ -36,47 +37,79 @@ using mindspore::lite::Tensor;
namespace mindspore::opt {
namespace {
constexpr size_t INITIAL_SIZE = 1024;
std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
if (input_tensor != nullptr) {
for (auto &i : *input_tensor) {
delete i;
i = nullptr;
}
}
if (output_tensor != nullptr) {
for (auto &i : *output_tensor) {
delete i;
i = nullptr;
}
}
}

std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &cnode, lite::converter::FmkType fmk_type) {
MS_ASSERT(CNode != nullptr);
auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>();
auto tmp_fb_node = std::make_unique<schema::CNodeT>();
lite::AnfExporter anfExporter;
anfExporter.SetOpInputNode(CNode, tmp_meta_graph, tmp_fb_node.get());
std::vector<Tensor *> input_tensors;
for (auto input_index : tmp_fb_node->inputIndex) {
auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
auto tensor_shape = tensorT->dims;
auto lite_tensor = new (std::nothrow) Tensor(
TypeId(tensorT->dataType), tensor_shape, tensorT->format,
lite::TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "lite tensor is nullptr";
return input_tensors;
std::vector<Tensor *> tensors;
for (size_t i = 1; i < cnode->size(); ++i) {
int status;
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(i))) {
if (!cnode->input(i)->cast<ParameterPtr>()->has_default()) {
FreeTensors(&tensors, nullptr);
return {};
}
status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, false, &data_info);
} else if (utils::isa<ValueNodePtr>(cnode->input(i))) {
status = lite::FetchDataFromValueNode(cnode, i, fmk_type, false, &data_info);
} else {
MS_LOG(ERROR) << "input node is not const node.";
FreeTensors(&tensors, nullptr);
return {};
}
auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
// when tensorT is a graph input
if (lite_tensor_size <= 0) {
delete lite_tensor;
return input_tensors;
if (status == lite::RET_NO_CHANGE) {
continue;
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "parser const data failed.";
FreeTensors(&tensors, nullptr);
return {};
}
if (data_info.shape_.empty() && data_info.data_.empty()) {
FreeTensors(&tensors, nullptr);
MS_LOG(DEBUG) << "input node is graph input.";
return {};
}
auto tensor_data = new (std::nothrow) uint8_t[lite_tensor_size / sizeof(char)];
auto tensor = new (std::nothrow)
Tensor(TypeId(data_info.data_type_), data_info.shape_, schema::Format(data_info.format_),
lite::TensorCategory(0, data_info.shape_.size(), TypeId(data_info.data_type_), data_info.data_.size()));
if (tensor == nullptr) {
MS_LOG(ERROR) << "new a tensor is nullptr.";
FreeTensors(&tensors, nullptr);
return {};
}
if (data_info.data_.empty()) {
tensors.emplace_back(tensor);
continue;
}
auto tensor_data = tensor->MutableData();
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr";
delete lite_tensor;
return input_tensors;
MS_LOG(ERROR) << "malloc data failed.";
FreeTensors(&tensors, nullptr);
return {};
}
auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size);
if (ret != EOK) {
delete lite_tensor;
delete[](tensor_data);
MS_LOG(ERROR) << "memcpy error: " << ret;
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
if (memcpy_s(tensor_data, data_info.data_.size(), data_info.data_.data(), data_info.data_.size()) != EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
FreeTensors(&tensors, nullptr);
return {};
}
lite_tensor->set_data(tensor_data);
input_tensors.emplace_back(lite_tensor);
tensors.emplace_back(tensor);
}
return input_tensors;
return tensors;
}

ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
@@ -229,21 +262,6 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *>
}
return lite::RET_OK;
}

void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
if (input_tensor != nullptr) {
for (auto &i : *input_tensor) {
delete i;
i = nullptr;
}
}
if (output_tensor != nullptr) {
for (auto &i : *output_tensor) {
delete i;
i = nullptr;
}
}
}
} // namespace

const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@@ -263,9 +281,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
continue;
}
auto input_cnode = input_node->cast<CNodePtr>();
auto input_tensors = GetCNodeInputTensors(input_cnode);
if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
FreeTensors(&input_tensors, nullptr);
auto input_tensors = GetCNodeInputTensors(input_cnode, fmk_type_);
if (input_tensors.empty()) {
continue;
}
changed = true;
@@ -279,7 +296,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
FreeTensors(&input_tensors, &output_tensors);
return nullptr;
}
auto lite_kernel = GetLiteKernel(input_tensors, &output_tensors, input_cnode, context.get());
auto lite_kernel = GetLiteKernel(input_tensors, &output_tensors, input_cnode, context_.get());
if (lite_kernel == nullptr) {
FreeTensors(&input_tensors, &output_tensors);
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";


+ 8
- 3
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h View File

@@ -24,18 +24,23 @@
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_flags.h"

namespace mindspore {
namespace opt {
class ConstFoldPass : public PatternProcessPass {
public:
explicit ConstFoldPass(std::shared_ptr<lite::InnerContext> context_ptr = nullptr, bool multigraph = true)
: PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {}
explicit ConstFoldPass(lite::converter::FmkType fmk_type = lite::converter::FmkType_MS, bool multigraph = true)
: PatternProcessPass("constfold_pass", multigraph), fmk_type_(fmk_type) {
context_ = std::make_shared<lite::InnerContext>();
context_->Init();
}
~ConstFoldPass() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
std::shared_ptr<lite::InnerContext> context;
lite::converter::FmkType fmk_type_{lite::converter::FmkType_MS};
std::shared_ptr<lite::InnerContext> context_{nullptr};
};
} // namespace opt
} // namespace mindspore
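
For context on what the ConstFoldPass declared above does with the const tensors it now fetches via DataInfo: a node whose inputs are all constant is evaluated at convert time and replaced by a constant. A generic, self-contained sketch of that idea on a toy expression tree (not the converter's AnfNode/lite::Tensor machinery):

// Sketch: constant folding on a toy Add expression tree.
#include <cassert>
#include <memory>

struct Node {
  enum class Kind { kConst, kAdd } kind;
  float value = 0.0f;              // valid when kind == kConst
  std::shared_ptr<Node> lhs, rhs;  // valid when kind == kAdd
};

std::shared_ptr<Node> Fold(const std::shared_ptr<Node> &node) {
  if (node->kind == Node::Kind::kConst) {
    return node;
  }
  auto lhs = Fold(node->lhs);
  auto rhs = Fold(node->rhs);
  if (lhs->kind == Node::Kind::kConst && rhs->kind == Node::Kind::kConst) {
    return std::make_shared<Node>(Node{Node::Kind::kConst, lhs->value + rhs->value, nullptr, nullptr});
  }
  return std::make_shared<Node>(Node{Node::Kind::kAdd, 0.0f, lhs, rhs});
}

int main() {
  auto c1 = std::make_shared<Node>(Node{Node::Kind::kConst, 2.0f, nullptr, nullptr});
  auto c2 = std::make_shared<Node>(Node{Node::Kind::kConst, 3.0f, nullptr, nullptr});
  auto add = std::make_shared<Node>(Node{Node::Kind::kAdd, 0.0f, c1, c2});
  auto folded = Fold(add);
  assert(folded->kind == Node::Kind::kConst && folded->value == 5.0f);  // 2 + 3 folded at convert time
  return 0;
}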


+ 96
- 82
mindspore/lite/tools/optimizer/graph/node_infershape.cc View File

@@ -18,9 +18,9 @@
#include <algorithm>
#include <memory>
#include <vector>
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "src/common/utils.h"
#include "src/ops/populate/populate_register.h"
#include "src/ops/ops_utils.h"
#include "src/runtime/infer_manager.h"
@@ -67,8 +67,8 @@ bool DuceInferFlag(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inp
}
}
auto origin_inputs = cnode->inputs();
lite::AnfExporter::RemoveIfDepend(cnode);
lite::AnfExporter::RemoveIfMakeTuple(cnode);
lite::RemoveIfDepend(cnode);
lite::RemoveIfMakeTuple(cnode);
for (size_t i = 1; i < cnode->size(); ++i) {
if (!utils::isa<CNodePtr>(cnode->input(i))) {
continue;
@@ -241,8 +241,8 @@ STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
MS_ASSERT(cnode != nullptr);
MS_ASSERT(inputs != nullptr);
auto origin_inputs = cnode->inputs();
lite::AnfExporter::RemoveIfDepend(cnode);
lite::AnfExporter::RemoveIfMakeTuple(cnode);
lite::RemoveIfDepend(cnode);
lite::RemoveIfMakeTuple(cnode);
RemoveIfMonad(cnode);
std::vector<lite::Tensor *> const_inputs;
if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) {
@@ -288,28 +288,29 @@ STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
}

STATUS NodeInferShape::GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs) {
MS_ASSERT(cnode != nullptr);
auto origin_inputs = cnode->inputs();
std::vector<AnfNodePtr> const_inputs;
for (auto &input : origin_inputs) {
if (utils::isa<CNodePtr>(input)) {
MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr);
std::vector<lite::DataInfo> data_infos;
for (size_t i = 1; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
continue;
}
const_inputs.push_back(input);
}
cnode->set_inputs(const_inputs);
auto meta_graph = std::make_unique<schema::MetaGraphT>();
meta_graph->fmkType = fmk_type_;
auto fb_node = std::make_unique<schema::CNodeT>();
lite::AnfExporter anf_exporter;
anf_exporter.set_train_flag(train_flag_);
auto status = anf_exporter.SetOpInputNode(cnode, meta_graph, fb_node.get());
cnode->set_inputs(origin_inputs);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get const inputs failed.";
return status;
STATUS status;
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(i))) {
status = lite::FetchDataFromParameterNode(cnode, i, fmk_type_, train_flag_, &data_info);
} else {
status = lite::FetchDataFromValueNode(cnode, i, fmk_type_, train_flag_, &data_info);
}
if (status == lite::RET_NO_CHANGE) {
continue;
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "fetch const input data failed.";
return status;
}
data_infos.emplace_back(data_info);
}
return ConvertToLiteTensor(meta_graph, fb_node->inputIndex, const_ms_inputs);
return ConvertToLiteTensor(data_infos, const_ms_inputs);
}

STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs) {
@@ -319,29 +320,16 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite:
if (!utils::isa<CNodePtr>(cnode->input(i))) {
continue;
}
auto abstract = GetCNodeInputAbstract(cnode, i);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Abstract cnode is nullptr.";
return lite::RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
MS_LOG(ERROR) << "Abstract should be anstract tensor.";
return lite::RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr);
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
lite::DataInfo data_info;
if (lite::FetchDataFromCNode(cnode, i, fmk_type_, train_flag_, &data_info) != lite::RET_OK) {
MS_LOG(ERROR) << "parse cnode failed.";
return lite::RET_ERROR;
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
lite::Tensor *tensor = nullptr;
if (type_ptr->type_id() == kObjectTypeTensorType) {
tensor = GetCNodeTensorListVarInput(dims, abstract_tensor);
if (data_info.data_type_ == kObjectTypeTensorType) {
tensor = GetCNodeTensorListVarInput(data_info);
} else {
tensor = new (std::nothrow) lite::Tensor(TypeId(type_ptr->type_id()), dims);
tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
}
if (tensor == nullptr) {
MS_LOG(ERROR) << "new a lite tensor failed";
@@ -352,27 +340,16 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite:
return lite::RET_OK;
}

lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector<int> shape,
const abstract::AbstractTensorPtr &abstract_tensor) {
MS_ASSERT(abstract_tensor != nullptr);
auto tensor_list = new (std::nothrow) lite::TensorList(shape, {});
lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(const lite::DataInfo &data_info) {
auto tensor_list = new (std::nothrow) lite::TensorList(data_info.shape_, {});
if (tensor_list == nullptr) {
MS_LOG(ERROR) << "new a lite tensor list failed";
return nullptr;
}
auto tensor_info = abstract_tensor->GetValueTrack();
if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
delete tensor_list;
MS_LOG(ERROR) << "nsor list abstract is invalid.";
return nullptr;
}
auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
if (tensor_value->data_c() == nullptr) {
delete tensor_list;
MS_LOG(ERROR) << "cannot get tensor list abstract's info.";
return nullptr;
if (data_info.data_.empty()) {
return tensor_list;
}
auto status = tensor_list->Decode(static_cast<int *>(tensor_value->data_c()));
auto status = tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data()));
if (status != lite::RET_OK) {
delete tensor_list;
MS_LOG(ERROR) << "decode tensor list failed.";
@@ -384,41 +361,78 @@ lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector<int> shape,
STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(outputs != nullptr);
auto meta_graph = std::make_unique<schema::MetaGraphT>();
meta_graph->fmkType = fmk_type_;
auto fb_node = std::make_unique<schema::CNodeT>();
lite::AnfExporter anf_exporter;
anf_exporter.set_train_flag(train_flag_);
anf_exporter.SetOpOutputNode(cnode, meta_graph, fb_node.get());
return ConvertToLiteTensor(meta_graph, fb_node->outputIndex, outputs);
std::vector<lite::DataInfo> data_infos;
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
if (tuple == nullptr) {
MS_LOG(ERROR) << "tuple is nullptr";
return lite::RET_ERROR;
}
auto elements = tuple->elements();
for (size_t i = 0; i < elements.size(); i++) {
lite::DataInfo data_info;
data_info.node_type_ = lite::NodeType_CNode;
if (train_flag_) {
data_infos.emplace_back(data_info);
if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || CheckPrimitiveType(cnode, prim::kPrimAdam)) {
break;
}
} else {
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
MS_LOG(ERROR) << "abstract is not AbstractTensor";
return lite::RET_ERROR;
}
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}
data_info.data_type_ = type;
data_infos.emplace_back(data_info);
if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) {
break;
}
}
}
} else {
lite::DataInfo data_info;
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}
data_info.data_type_ = type;
data_info.node_type_ = lite::NodeType_CNode;
data_infos.emplace_back(data_info);
}
return ConvertToLiteTensor(data_infos, outputs);
}

STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::vector<uint32_t> &tensor_indexes,
STATUS NodeInferShape::ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos,
std::vector<lite::Tensor *> *tensors) {
MS_ASSERT(meta_graph != nullptr);
MS_ASSERT(tensors != nullptr);
for (auto index : tensor_indexes) {
auto tensor_t = meta_graph->allTensors.at(index).get();
auto tensor_shape = tensor_t->dims;
auto tensor_category = lite::TensorCategory(tensor_t->nodeType, tensor_t->dims.size(), TypeId(tensor_t->dataType),
tensor_t->data.size());
for (auto &data_info : data_infos) {
auto tensor_category = lite::TensorCategory(lite::NodeType(data_info.node_type_), data_info.shape_.size(),
TypeId(data_info.data_type_), data_info.data_.size());
lite::Tensor *tensor = nullptr;
if (tensor_t->dataType != kObjectTypeTensorType) {
tensor =
new (std::nothrow) lite::Tensor(TypeId(tensor_t->dataType), tensor_shape, tensor_t->format, tensor_category);
if (data_info.data_type_ != kObjectTypeTensorType) {
tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_,
(schema::Format)data_info.format_, tensor_category);
} else {
tensor = new (std::nothrow) lite::TensorList(tensor_shape, std::vector<int>(), tensor_category);
tensor = new (std::nothrow) lite::TensorList(data_info.shape_, std::vector<int>(), tensor_category);
}
if (tensor == nullptr) {
MS_LOG(ERROR) << "new a lite tensor failed";
return lite::RET_ERROR;
}
auto tensor_size = tensor_t->data.size() * sizeof(char);
auto tensor_size = data_info.data_.size();
if (tensor_size > 0) {
if (tensor_t->dataType == kObjectTypeTensorType) {
if (data_info.data_type_ == kObjectTypeTensorType) {
auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor);
if (tensor_list->Decode(reinterpret_cast<const int *>(tensor_t->data.data())) != RET_OK) {
if (tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data())) != RET_OK) {
MS_LOG(ERROR) << "Decode tensorlist data failed";
return RET_ERROR;
}
@@ -429,7 +443,7 @@ STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr<schema::MetaGra
delete tensor;
return lite::RET_ERROR;
}
if (memcpy_s(tensor_data, tensor_size, tensor_t->data.data(), tensor_size) != EOK) {
if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) {
delete tensor;
delete[](tensor_data);
MS_LOG(ERROR) << "memcpy error: ";

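A minimal stand-alone sketch of the DataInfo-to-tensor flow introduced above, using invented stand-in structs (DataInfoSketch/TensorSketch are illustrations only; the real fields live in tools/anf_exporter/fetch_content.h and src/tensor.h), to show the per-tensor record pattern that replaces the schema::MetaGraphT round trip:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Stand-in for lite::DataInfo; field names mirror the diff, definitions are invented.
struct DataInfoSketch {
  int data_type_ = 0;            // numeric type id
  std::vector<int> shape_;
  std::vector<uint8_t> data_;    // raw constant data, empty for variable tensors
};

// Stand-in for lite::Tensor, reduced to what the conversion loop touches.
struct TensorSketch {
  int data_type;
  std::vector<int> shape;
  std::vector<uint8_t> data;
};

// One DataInfo per tensor; constant bytes are copied only when present.
std::vector<TensorSketch> ConvertSketch(const std::vector<DataInfoSketch> &infos) {
  std::vector<TensorSketch> tensors;
  for (const auto &info : infos) {
    TensorSketch t{info.data_type_, info.shape_, {}};
    if (!info.data_.empty()) {
      t.data.resize(info.data_.size());
      std::memcpy(t.data.data(), info.data_.data(), info.data_.size());
    }
    tensors.push_back(std::move(t));
  }
  return tensors;
}

int main() {
  DataInfoSketch weight{1 /* stand-in dtype id */, {1, 2}, std::vector<uint8_t>(8, 0)};
  DataInfoSketch activation{1, {1, 4, 4, 3}, {}};
  auto tensors = ConvertSketch({weight, activation});
  std::cout << tensors.size() << " tensors, first holds " << tensors[0].data.size()
            << " bytes of constant data\n";
  return 0;
}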

+3 -3  mindspore/lite/tools/optimizer/graph/node_infershape.h

@@ -22,6 +22,7 @@
#include <string>
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "tools/anf_exporter/fetch_content.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"

@@ -44,10 +45,9 @@ class NodeInferShape {
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
lite::Tensor *GetCNodeTensorListVarInput(std::vector<int> shape, const abstract::AbstractTensorPtr &abstract_tensor);
lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info);
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs);
STATUS ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::vector<uint32_t> &tensor_indexes, std::vector<lite::Tensor *> *tensors);
STATUS ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vector<lite::Tensor *> *tensors);
STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs);
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor);


+36 -3  mindspore/lite/tools/optimizer/graph/transpose_strategy.cc

@@ -67,7 +67,7 @@ AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &fu
MS_LOG(ERROR) << "input node is invalid.";
return nullptr;
}
if (GetTransposePerm(input_cnode->input(kTransposePerm), &trans_perm) != lite::RET_OK) {
if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "transpose perm get failed.";
return nullptr;
}
@@ -142,8 +142,40 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const
return can_insert;
}

bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
if (shape.size() != 4) {
if (cnode->size() > 2) {
shape = node_infer_shape_.GetInputShape(cnode, 2);
if (shape.size() != 4 && !shape.empty()) {
return false;
}
} else {
return false;
}
}
if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kAxis) == nullptr) {
return false;
}
}
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
for (size_t i = 2; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
return false;
}
}
if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
return false;
}
}
return true;
}

STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
if (shape.size() != 4) {
if (cnode->size() > 2) {
@@ -180,6 +212,7 @@ STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNo
} else {
offsets.push_back(0);
}
crop_prim->set_axis(new_axis);
crop_prim->set_offsets(offsets);
}
if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
@@ -231,7 +264,7 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s
if (cnode == nullptr) {
return false;
}
if (GetTransposePerm(cnode->input(kTransposePerm), &perm) != lite::RET_OK) {
if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
return false;
}
if (perm == NH2NC) {


+1 -0  mindspore/lite/tools/optimizer/graph/transpose_strategy.h

@@ -44,6 +44,7 @@ class TransposeStrategy {
bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info,
TransTypePair *trans_insert_info);
STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);

private:
STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);


+290 -13  mindspore/lite/tools/optimizer/graph/unify_format_pass.cc

@@ -15,12 +15,14 @@
*/

#include "tools/optimizer/graph/unify_format_pass.h"
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/common/tensor_util.h"

using mindspore::lite::NCHW_SHAPE;
namespace mindspore {
@@ -37,6 +39,173 @@ bool IsSpecialType(const CNodePtr &cnode) {
}
return false;
}

STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
std::set<CNodePtr> *middle_nodes) {
MS_ASSERT(func_graph != nullptr && root_node != nullptr);
MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr);
std::queue<CNodePtr> queue_nodes;
queue_nodes.push(root_node);
std::queue<bool> is_pre_nodes;
is_pre_nodes.push(true);
while (!queue_nodes.empty()) {
auto cur_node = queue_nodes.front();
auto is_pre_node = is_pre_nodes.front();
queue_nodes.pop();
is_pre_nodes.pop();
if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) {
if (is_pre_node) {
in_nodes->insert(cur_node);
} else {
out_nodes->insert(cur_node);
continue;
}
}
if (middle_nodes->find(cur_node) != middle_nodes->end()) {
continue;
}
if (in_nodes->find(cur_node) == in_nodes->end()) {
middle_nodes->insert(cur_node);
// insert pre nodes.
auto origin_inputs = cur_node->inputs();
lite::RemoveIfDepend(cur_node);
for (size_t i = 1; i < cur_node->size(); ++i) {
if (!utils::isa<CNodePtr>(cur_node->input(i))) {
continue;
}
auto cur_node_input = cur_node->input(i)->cast<CNodePtr>();
if (middle_nodes->find(cur_node_input) != middle_nodes->end() ||
in_nodes->find(cur_node_input) != in_nodes->end()) {
continue;
}
queue_nodes.push(cur_node_input);
is_pre_nodes.push(true);
}
if (CheckIsAllInputsParam(cur_node)) {
in_nodes->insert(cur_node);
}
cur_node->set_inputs(origin_inputs);
}
// insert post nodes
auto cur_node_users = func_graph->manager()->node_users()[cur_node];
for (auto &cur_node_user : cur_node_users) {
if (!utils::isa<CNodePtr>(cur_node_user.first)) {
MS_LOG(ERROR) << "post node is not cnode.";
return lite::RET_ERROR;
}
auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
out_nodes->find(cur_node_post) != out_nodes->end()) {
continue;
}
queue_nodes.push(cur_node_post);
is_pre_nodes.push(false);
}
if (cur_node_users.empty()) {
out_nodes->insert(cur_node);
}
}
return lite::RET_OK;
}

bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes,
const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes) {
MS_ASSERT(func_graph != nullptr);
for (auto &in_cnode : in_nodes) {
std::vector<int> perm;
if (!CheckPrimitiveType(in_cnode, prim::kPrimTranspose) || GetTransposePerm(in_cnode, &perm) != lite::RET_OK ||
perm != NH2NC) {
return false;
}
}
for (auto &out_cnode : out_nodes) {
std::vector<int> perm;
if (!CheckPrimitiveType(out_cnode, prim::kPrimTranspose) || GetTransposePerm(out_cnode, &perm) != lite::RET_OK ||
perm != NC2NH) {
return false;
}
}
auto &dynamic_ops = GetDynamicFormatOpList();
TransposeStrategy transpose_strategy;
for (auto &middle_cnode : middle_nodes) {
if (IsSpecialType(middle_cnode)) {
continue;
}
auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
if (!lite::IsContain(dynamic_ops, middle_node_prim->name()) ||
!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) {
return false;
}
}
return true;
}

void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
bool train_flag) {
MS_ASSERT(cnode != nullptr);
if (utils::isa<CNodePtr>(cnode->input(index))) {
return;
}
lite::DataInfo data_info;
int status;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
} else {
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
}
if (status != lite::RET_OK) {
return;
}
if (data_info.shape_.empty() ||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
return;
}
std::vector<int> new_shape;
if (data_info.shape_.size() == 1) {
new_shape = {1, 1, 1, data_info.shape_[0]};
} else if (data_info.shape_.size() == 2) {
new_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
} else if (data_info.shape_.size() == 3) {
new_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[2]};
}
auto size = data_info.data_.size() / sizeof(float);
std::vector<float> new_data(size);
auto new_data_ptr = static_cast<float *>(new_data.data());
auto nchw_data = reinterpret_cast<float *>(data_info.data_.data());
// nchw to nhwc
auto batch = new_shape[lite::NCHW_N];
auto channel = new_shape[lite::NCHW_C];
auto area = new_shape[lite::NCHW_H] * new_shape[lite::NCHW_W];
for (auto i = 0; i < batch; i++) {
float *src_batch = nchw_data + i * channel * area;
float *dst_batch = new_data_ptr + i * channel * area;
for (int j = 0; j < area; ++j) {
      float *src_area = src_batch + j;
      float *dst_area = dst_batch + j * channel;
for (int k = 0; k < channel; ++k) {
dst_area[k] = src_area[k * area];
}
}
}
auto param_node = func_graph->add_parameter();
param_node->set_name(cnode->input(index)->fullname_with_scope());
std::vector<int64_t> shape_vec{new_shape[0], new_shape[2], new_shape[3], new_shape[1]};
auto tensor_info = lite::CreateTensorInfo(new_data.data(), size * sizeof(float), shape_vec, kNumberTypeFloat32);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return;
}
status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
return;
}
auto tr = func_graph->manager()->Transact();
tr.SetEdge(cnode, index, param_node);
tr.Commit();
return;
}
} // namespace
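ConvertNcTensor2Nh pads a 1-3 D constant to a 4-D NCHW shape and then physically permutes the float data to NHWC before re-attaching it as a new parameter. A stand-alone version of just the permutation (plain vectors, no graph plumbing), which also doubles as a check of the index arithmetic:

#include <cassert>
#include <iostream>
#include <vector>

// Physically reorder float data from NCHW to NHWC.
std::vector<float> NchwToNhwc(const std::vector<float> &src, int n, int c, int h, int w) {
  assert(static_cast<int>(src.size()) == n * c * h * w);
  std::vector<float> dst(src.size());
  const int area = h * w;
  for (int b = 0; b < n; ++b) {
    const float *src_batch = src.data() + b * c * area;
    float *dst_batch = dst.data() + b * c * area;
    for (int hw = 0; hw < area; ++hw) {   // spatial position
      for (int ch = 0; ch < c; ++ch) {    // channel
        dst_batch[hw * c + ch] = src_batch[ch * area + hw];
      }
    }
  }
  return dst;
}

int main() {
  // 1x2x2x2 NCHW tensor: channel 0 = {0,1,2,3}, channel 1 = {4,5,6,7}.
  std::vector<float> nchw = {0, 1, 2, 3, 4, 5, 6, 7};
  auto nhwc = NchwToNhwc(nchw, 1, 2, 2, 2);
  for (float v : nhwc) std::cout << v << ' ';  // 0 4 1 5 2 6 3 7
  std::cout << "\n";
  return 0;
}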

void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) {
@@ -79,7 +248,7 @@ bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNo
return false;
}
std::vector<int> post_perm;
if (GetTransposePerm(cnode->input(2), &post_perm) != lite::RET_OK) {
if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return false;
}
@@ -89,7 +258,7 @@ bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNo
if (pre_cnode == nullptr) {
return false;
}
if (GetTransposePerm(pre_cnode->input(2), &pre_perm) != lite::RET_OK) {
if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return false;
}
@@ -106,7 +275,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
return lite::RET_OK;
}
std::vector<int> cur_perm;
if (GetTransposePerm(cnode->input(2), &cur_perm) != lite::RET_OK) {
if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get transpose perm failed.";
return lite::RET_ERROR;
}
@@ -116,7 +285,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
std::vector<int> post_trans_perm;
auto post_trans_node = post_node->cast<CNodePtr>();
if (GetTransposePerm(post_trans_node->input(2), &post_trans_perm) != lite::RET_OK) {
if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get post transpose node perm failed.";
return lite::RET_ERROR;
}
@@ -218,7 +387,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
MS_ASSERT(trans_insert_info != nullptr);
TransTypePair trans_info;
auto origin_inputs = cnode->inputs();
lite::AnfExporter::RemoveIfMakeTuple(cnode);
lite::RemoveIfMakeTuple(cnode);
RemoveIfMonad(cnode);
if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
cnode->set_inputs(origin_inputs);
@@ -366,8 +535,8 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
prim->AddAttr(kTransDone, MakeValue<bool>(true));
TransTypePair trans_info;
GetTransNodeFormatType(cnode, &trans_info);
if (!need_reset_ && (trans_info.pre_ == kNONE || trans_info.post_ == kNONE)) {
if (TransTransFusion(func_graph, cnode)) {
if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) {
if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
return lite::RET_OK;
}
std::unordered_map<AnfNodePtr, AnfNodePtr> match;
@@ -401,6 +570,65 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
return lite::RET_OK;
}

STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
std::set<CNodePtr> middle_nodes;
std::set<CNodePtr> in_nodes;
std::set<CNodePtr> out_nodes;
auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &middle_nodes);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "find an area surrounded by transpose failed.";
return status;
}
for (auto &in_cnode : in_nodes) {
if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) {
visit_transposes->insert(in_cnode);
}
}
if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes)) {
return lite::RET_NO_CHANGE;
}
auto node_list = TopoSort(func_graph->get_return());
std::vector<CNodePtr> middle_ops_vec;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
if (middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) {
middle_ops_vec.push_back(node->cast<CNodePtr>());
middle_nodes.erase(node->cast<CNodePtr>());
}
}
for (auto &in_cnode : in_nodes) {
manager->Replace(in_cnode, in_cnode->input(1));
}
for (auto &out_cnode : out_nodes) {
manager->Replace(out_cnode, out_cnode->input(1));
}
for (auto &middle_cnode : middle_ops_vec) {
if (IsSpecialType(middle_cnode)) {
continue;
}
for (size_t i = 1; i < middle_cnode->size(); ++i) {
ConvertNcTensor2Nh(func_graph, middle_cnode, i, fmk_type_, train_flag_);
}
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
status = node_infer_shape_.InferShape(middle_cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::unordered_map<AnfNodePtr, AnfNodePtr> *match) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
@@ -482,8 +710,8 @@ void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPt
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_input = return_node->inputs();
lite::AnfExporter::RemoveIfDepend(return_node);
lite::AnfExporter::RemoveIfMakeTuple(return_node);
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
for (size_t i = 1; i < return_node->size(); ++i) {
if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
continue;
@@ -511,8 +739,8 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_inputs = return_node->inputs();
lite::AnfExporter::RemoveIfDepend(return_node);
lite::AnfExporter::RemoveIfMakeTuple(return_node);
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
AbstractBasePtrList abstract_list;
bool infer_done = true;
for (size_t i = 1; i < return_node->size(); ++i) {
@@ -679,6 +907,49 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
return true;
}

bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
std::set<CNodePtr> visit_transposes;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) {
continue;
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
return false;
}
(void)DecreaseTransposeForMultiOp(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
return false;
}
(void)DecreaseTransposeForMultiOp(sub_func_graph);
}
std::vector<int> perm;
if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
perm != NH2NC) {
continue;
}
auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "global optimizer failed.";
return false;
}
}
return true;
}

bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
@@ -774,11 +1045,17 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "run framework transpose unify failed.";
return false;
}
// if input's format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
// if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
if (!DecreaseTransposeForSingleOp(func_graph)) {
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
return false;
}
// if input format of several ops surrounded only by transpose op all can be NHWC,
// we can delete these transpose ops, and at the same time, transform these middle ops.
if (!DecreaseTransposeForMultiOp(func_graph)) {
MS_LOG(ERROR) << "run global trans insert optimizer failed.";
return false;
}
return true;
}
} // namespace opt


+3 -0  mindspore/lite/tools/optimizer/graph/unify_format_pass.h

@@ -48,6 +48,7 @@ class UnifyFormatPass : public Pass {
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
@@ -55,6 +56,8 @@ class UnifyFormatPass : public Pass {
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);

