Browse Source

modify icsl static check

tags/v1.1.0
yvette 5 years ago
parent
commit
2e7fef33e5
17 changed files with 102 additions and 44 deletions
  1. +9
    -5
      mindspore/lite/src/model_common.cc
  2. +6
    -0
      mindspore/lite/src/runtime/thread_pool.c
  3. +1
    -1
      mindspore/lite/src/runtime/workspace_pool.cc
  4. +22
    -9
      mindspore/lite/tools/benchmark/benchmark.cc
  5. +11
    -11
      mindspore/lite/tools/benchmark/benchmark.h
  6. +14
    -10
      mindspore/lite/tools/common/flag_parser.h
  7. +0
    -4
      mindspore/lite/tools/common/graph_util.cc
  8. +2
    -1
      mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc
  9. +6
    -0
      mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc
  10. +6
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
  11. +4
    -0
      mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc
  12. +3
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
  13. +1
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc
  14. +6
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc
  15. +1
    -0
      mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
  16. +6
    -3
      mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
  17. +4
    -0
      mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc

+ 9
- 5
mindspore/lite/src/model_common.cc View File

@@ -24,21 +24,25 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
return RET_ERROR;
}
subgraph->name_ = sub_graph.name()->c_str();
MS_ASSERT(sub_graph.inputIndices() != nullptr);
auto in_count = sub_graph.inputIndices()->size();
for (uint32_t i = 0; i < in_count; ++i) {
subgraph->input_indices_.push_back(size_t(sub_graph.inputIndices()->GetAs<uint32_t>(i)));
subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i));
}
MS_ASSERT(sub_graph.outputIndices() != nullptr);
auto out_count = sub_graph.outputIndices()->size();
for (uint32_t i = 0; i < out_count; ++i) {
subgraph->output_indices_.push_back(size_t(sub_graph.outputIndices()->GetAs<uint32_t>(i)));
subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i));
}
MS_ASSERT(sub_graph.nodeIndices() != nullptr);
auto node_count = sub_graph.nodeIndices()->size();
for (uint32_t i = 0; i < node_count; ++i) {
subgraph->node_indices_.push_back(size_t(sub_graph.nodeIndices()->GetAs<uint32_t>(i)));
subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i));
}
auto tensor_count = sub_graph.nodeIndices()->size();
MS_ASSERT(sub_graph.tensorIndices() != nullptr);
auto tensor_count = sub_graph.tensorIndices()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
subgraph->tensor_indices_.push_back(size_t(sub_graph.tensorIndices()->GetAs<uint32_t>(i)));
subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
}
model->sub_graphs_.push_back(subgraph);
return RET_OK;


+ 6
- 0
mindspore/lite/src/runtime/thread_pool.c View File

@@ -860,9 +860,15 @@ ThreadPool *CreateThreadPool(int thread_num, int mode) {
if (ret != RET_TP_OK) {
LOG_ERROR("create thread %d failed", i);
DestroyThreadPool(thread_pool);
thread_pool = NULL;
return NULL;
}
}
if (thread_pool == NULL) {
LOG_ERROR("create thread failed");
DestroyThreadPool(thread_pool);
return NULL;
}
return thread_pool;
}



+ 1
- 1
mindspore/lite/src/runtime/workspace_pool.cc View File

@@ -109,7 +109,7 @@ void *WorkspacePool::AllocWorkSpaceMem(size_t size) {
}
}
allocList.emplace_back(alloc);
return alloc.second;
return alloc.second != nullptr ? alloc.second : nullptr;
}

void WorkspacePool::FreeWorkSpaceMem(const void *ptr) {


+ 22
- 9
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -120,6 +120,10 @@ int Benchmark::ReadInputFile() {
return RET_ERROR;
}
auto input_data = cur_tensor->MutableData();
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data is nullptr.";
return RET_ERROR;
}
memcpy(input_data, bin_buf, tensor_data_size);
}
delete[] bin_buf;
@@ -232,7 +236,7 @@ int Benchmark::CompareOutput() {
}
float mean_bias;
if (total_size != 0) {
mean_bias = total_bias / total_size * 100;
mean_bias = total_bias / float_t(total_size) * 100;
} else {
mean_bias = 0;
}
@@ -286,21 +290,26 @@ int Benchmark::CompareStringData(const std::string &name, tensor::MSTensor *tens
int Benchmark::CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias,
int *total_size) {
float bias = 0;
auto mutableData = tensor->MutableData();
if (mutableData == nullptr) {
MS_LOG(ERROR) << "mutableData is nullptr.";
return RET_ERROR;
}
switch (msCalibDataType) {
case TypeId::kNumberTypeFloat: {
bias = CompareData<float>(name, tensor->shape(), tensor->MutableData());
bias = CompareData<float>(name, tensor->shape(), mutableData);
break;
}
case TypeId::kNumberTypeInt8: {
bias = CompareData<int8_t>(name, tensor->shape(), tensor->MutableData());
bias = CompareData<int8_t>(name, tensor->shape(), mutableData);
break;
}
case TypeId::kNumberTypeUInt8: {
bias = CompareData<uint8_t>(name, tensor->shape(), tensor->MutableData());
bias = CompareData<uint8_t>(name, tensor->shape(), mutableData);
break;
}
case TypeId::kNumberTypeInt32: {
bias = CompareData<int32_t>(name, tensor->shape(), tensor->MutableData());
bias = CompareData<int32_t>(name, tensor->shape(), mutableData);
break;
}
default:
@@ -420,6 +429,10 @@ int Benchmark::PrintInputData() {
}
size_t print_num = std::min(input->ElementsNum(), 20);
const void *in_data = input->MutableData();
if (in_data == nullptr) {
MS_LOG(ERROR) << "in_data is nullptr.";
return RET_ERROR;
}

for (size_t j = 0; j < print_num; j++) {
if (tensor_data_type == TypeId::kNumberTypeFloat32 || tensor_data_type == TypeId::kNumberTypeFloat) {
@@ -723,7 +736,7 @@ int Benchmark::PrintResult(const std::vector<std::string> &title,
}
columns.push_back(iter.first);

len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / flags_->loop_count_);
len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / float_t(flags_->loop_count_));
if (len > columnLenMax.at(1)) {
columnLenMax.at(1) = len + 4;
}
@@ -760,9 +773,9 @@ int Benchmark::PrintResult(const std::vector<std::string> &title,
printf("%s\t", printBuf.c_str());
}
printf("\n");
for (size_t i = 0; i < rows.size(); i++) {
for (auto &row : rows) {
for (int j = 0; j < 5; j++) {
auto printBuf = rows[i][j];
auto printBuf = row[j];
printBuf.resize(columnLenMax.at(j), ' ');
printf("%s\t", printBuf.c_str());
}
@@ -772,7 +785,7 @@ int Benchmark::PrintResult(const std::vector<std::string> &title,
}

Benchmark::~Benchmark() {
for (auto iter : this->benchmark_data_) {
for (const auto &iter : this->benchmark_data_) {
delete (iter.second);
}
this->benchmark_data_.clear();


+ 11
- 11
mindspore/lite/tools/benchmark/benchmark.h View File

@@ -88,24 +88,24 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
std::string model_file_;
std::string in_data_file_;
std::vector<std::string> input_data_list_;
InDataType in_data_type_;
InDataType in_data_type_ = kBinary;
std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 1;
// MarkPerformance
int loop_count_;
int num_threads_;
bool enable_fp16_;
int warm_up_loop_count_;
bool time_profiling_;
int loop_count_ = 10;
int num_threads_ = 2;
bool enable_fp16_ = false;
int warm_up_loop_count_ = 3;
bool time_profiling_ = false;
// MarkAccuracy
std::string benchmark_data_file_;
std::string benchmark_data_type_;
float accuracy_threshold_;
std::string benchmark_data_type_ = "FLOAT";
float accuracy_threshold_ = 0.5;
// Resize
std::string resize_dims_in_ = "";
std::string resize_dims_in_;
std::vector<std::vector<int>> resize_dims_;

std::string device_;
std::string device_ = "CPU";
};

class MS_API Benchmark {
@@ -149,7 +149,7 @@ class MS_API Benchmark {

// tensorData need to be converter first
template <typename T>
float CompareData(const std::string &nodeName, std::vector<int> msShape, const void *tensor_data) {
float CompareData(const std::string &nodeName, const std::vector<int> &msShape, const void *tensor_data) {
const T *msTensorData = static_cast<const T *>(tensor_data);
auto iter = this->benchmark_data_.find(nodeName);
if (iter != this->benchmark_data_.end()) {


+ 14
- 10
mindspore/lite/tools/common/flag_parser.h View File

@@ -33,9 +33,9 @@ struct Nothing {};

class FlagParser {
public:
FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", ""); }
FlagParser() { AddFlag(&FlagParser::help, helpStr, "print usage message", ""); }

virtual ~FlagParser() {}
virtual ~FlagParser() = default;

// only support read flags from command line
virtual Option<std::string> ParseFlags(int argc, const char *const *argv, bool supportUnknown = false,
@@ -60,7 +60,7 @@ class FlagParser {
// Option-type fields
template <typename Flags, typename T>
void AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo);
bool help;
bool help{};

protected:
template <typename Flags>
@@ -70,14 +70,15 @@ class FlagParser {

std::string binName;
Option<std::string> usageMsg;
std::string helpStr = "help";

private:
struct FlagInfo {
std::string flagName;
bool isRequired;
bool isBoolean;
bool isRequired = false;
bool isBoolean = false;
std::string helpInfo;
bool isParsed;
bool isParsed = false;
std::function<Option<Nothing>(FlagParser *, const std::string &)> parse;
};

@@ -218,7 +219,7 @@ void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::
return;
}

Flags *flag = dynamic_cast<Flags *>(this);
auto *flag = dynamic_cast<Flags *>(this);
if (flag == nullptr) {
return;
}
@@ -228,7 +229,10 @@ void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::
// flagItem is as a output parameter
ConstructFlag(t1, flagName, helpInfo, &flagItem);
flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
Flags *flag = dynamic_cast<Flags *>(base);
auto *flag = dynamic_cast<Flags *>(base);
if (flag == nullptr) {
return Option<Nothing>(None());
}
if (base != nullptr) {
Option<T1> ret = Option<T1>(GenericParseValue<T1>(value));
if (ret.IsNone()) {
@@ -267,7 +271,7 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const
return;
}

Flags *flag = dynamic_cast<Flags *>(this);
auto *flag = dynamic_cast<Flags *>(this);
if (flag == nullptr) {
MS_LOG(ERROR) << "dynamic_cast failed";
return;
@@ -278,7 +282,7 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const
ConstructFlag(t, flagName, helpInfo, &flagItem);
flagItem.isRequired = false;
flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option<Nothing> {
Flags *flag = dynamic_cast<Flags *>(base);
auto *flag = dynamic_cast<Flags *>(base);
if (base != nullptr) {
Option<T> ret = Option<std::string>(GenericParseValue<T>(value));
if (ret.IsNone()) {


+ 0
- 4
mindspore/lite/tools/common/graph_util.cc View File

@@ -605,10 +605,6 @@ std::string GetModelName(const std::string &modelFile) {
std::string modelName = modelFile;
modelName = modelName.substr(modelName.find_last_of('/') + 1);
modelName = modelName.substr(0, modelName.find_last_of('.'));

srand((unsigned)time(NULL));
modelName = modelName + std::to_string(rand());

return modelName;
}
} // namespace lite


+ 2
- 1
mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc View File

@@ -101,10 +101,11 @@ STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &p
}
fcAttr->hasBias = true;
fcAttr->axis = 1;
MS_ASSERT(matMulNode->primitive != nullptr);
MS_ASSERT(matMulNode->primitive->value != nullptr);
MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr);
transA = matMulNode->primitive->value.AsMatMul()->transposeA;
transB = matMulNode->primitive->value.AsMatMul()->transposeB;
MS_ASSERT(matMulNode->primitive->value.value != nullptr);
matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection;
matMulNode->primitive->value.value = fcAttr.release();



+ 6
- 0
mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc View File

@@ -146,6 +146,9 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
int shape_size = graph->allTensors.at(addBiasIndex)->dims.size();
scaleParam->axis = 0 - shape_size;
mulNode->inputIndex.push_back(addBiasIndex);
MS_ASSERT(addNode->primitive != nullptr);
MS_ASSERT(addNode->primitive->value != nullptr);
MS_ASSERT(addNode->primitive->value.AsAdd() != nullptr);
auto activationType = addNode->primitive->value.AsAdd()->activationType;
if (activationType == ActivationType_RELU || activationType == ActivationType_RELU6 ||
activationType == ActivationType_NO_ACTIVATION) {
@@ -159,6 +162,9 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
} else {
// repace addnode as activation
std::unique_ptr<ActivationT> activationParam(new ActivationT());
MS_ASSERT(addNode->primitive != nullptr);
MS_ASSERT(addNode->primitive->value != nullptr);
MS_ASSERT(addNode->primitive->value.AsAdd() != nullptr);
activationParam->type = addNode->primitive->value.AsAdd()->activationType;
addNode->primitive->value.type = schema::PrimitiveType_Activation;
addNode->primitive->value.value = activationParam.release();


+ 6
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc View File

@@ -91,6 +91,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) {
MS_ASSERT(node != nullptr);
MS_ASSERT(node->primitive != nullptr);
MS_ASSERT(node->primitive->value != nullptr);
MS_ASSERT(node->primitive->value.AsActivation() != nullptr);
if (node->primitive->value.AsActivation() != nullptr &&
node->primitive->value.AsActivation()->type == schema::ActivationType_LEAKY_RELU) {
return has_trans_count >= half_count;
@@ -198,6 +200,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
MS_LOG(ERROR) << "node or primitive null";
return RET_NULL_PTR;
}
MS_ASSERT(node->primitive->value != nullptr);
auto type = node->primitive->value.type;
auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
if (input1_ndim != 4) {
@@ -213,6 +216,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
}
}
if (type == PrimitiveType_Concat) {
MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
auto origin_axis = node->primitive->value.AsConcat()->axis;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsConcat() == nullptr) {
@@ -222,6 +226,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
}
if (type == PrimitiveType_Split) {
MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
auto axis_map = GetNc2NhAxisMap();
if (node->primitive->value.AsSplit() == nullptr) {
@@ -231,6 +236,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
}
if (type == PrimitiveType_Crop) {
MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
auto origin_axis = node->primitive->value.AsCrop()->axis;
auto offsets = node->primitive->value.AsCrop()->offsets;
auto axis_map = GetNc2NhAxisMap();


+ 4
- 0
mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc View File

@@ -76,6 +76,10 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) {
}
weight->data.resize(count * sizeof(float));
const float *data_ptr = proto.data().data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "data_ptr is nullptr";
return nullptr;
}
if (::memcpy_s(weight->data.data(), count * sizeof(float), (uint8_t *)data_ptr, count * sizeof(float)) != EOK) {
MS_LOG(ERROR) << "memcpy failed";
return nullptr;


+ 3
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc View File

@@ -157,6 +157,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(),
[](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
if (iter != (*nodeIter).attribute().end()) {
MS_ASSERT(iter->ints() != nullptr);
MS_ASSERT(iter->ints().begin() != nullptr);
MS_ASSERT(iter->ints().end() != nullptr);
dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
}
attr->channelOut = dims[0];


+ 1
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc View File

@@ -40,6 +40,7 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N

auto onnx_node_attr = onnx_node.attribute();
for (int i = 0; i < onnx_node_attr.size(); ++i) {
MS_ASSERT(onnx_node_attr.at(i) != nullptr);
if (onnx_node_attr.at(i).name() == "axis") {
attr->axis = onnx_node_attr.at(i).i();
} else if (onnx_node_attr.at(i).name() == "p") {


+ 6
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc View File

@@ -40,6 +40,7 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
auto onnx_node_attr = onnx_node.attribute();
int32_t size = 0;
for (int i = 0; i < onnx_node_attr.size(); ++i) {
MS_ASSERT(onnx_node_attr.at(i) != nullptr);
if (onnx_node_attr.at(i).name() == "alpha") {
attr->alpha = onnx_node_attr.at(i).f();
} else if (onnx_node_attr.at(i).name() == "beta") {
@@ -51,6 +52,11 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
attr->depth_radius = size / 2;
}
}

if (size == 0) {
MS_LOG(ERROR) << "Divide-by-zero error.";
return RET_ERROR;
}
attr->alpha /= size;

op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;


+ 1
- 0
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc View File

@@ -240,6 +240,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
lite_primitive->InferShape(input_tensors, output_tensors);
auto primitive = lite_primitive.get();
MS_ASSERT(primitive != nullptr);
MS_ASSERT(primitive->Type() != nullptr);
auto parameter =
lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);



+ 6
- 3
mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc View File

@@ -67,8 +67,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
}
// transform node means scale,bn
auto transform_node = node->cast<CNodePtr>();
if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK ||
CheckLeastInputSize(transform_node, 2) != lite::RET_OK) {
if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK || CheckLeastInputSize(transform_node, 2) != lite::RET_OK) {
return nullptr;
}

@@ -93,6 +92,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
auto trans_bias = new (std::nothrow) float[kernel_nums];
if (trans_bias == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr";
delete[] trans_scale;
delete[] trans_bias;
return nullptr;
}
@@ -234,8 +234,11 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
return;
}

delete[] tmp_weight_data;
if (tmp_weight_data != nullptr) {
delete[] tmp_weight_data;
}
}

const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
const float *trans_scale, const float *trans_bias) const {
MS_ASSERT(bias_data != nullptr);


+ 4
- 0
mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc View File

@@ -56,6 +56,8 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC";
return RET_ERROR;
}
MS_ASSERT(primT->value != nullptr);
MS_ASSERT(primT->value.AsTranspose() != nullptr);
std::vector<int32_t> perm = primT->value.AsTranspose()->perm;
if (perm == kPermNCHW) {
manager->Replace(transpose_cnode, transpose_cnode->input(1));
@@ -77,6 +79,8 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT";
return RET_ERROR;
}
MS_ASSERT(primT->value != nullptr);
MS_ASSERT(primT->value.AsTranspose() != nullptr);
std::vector<int32_t> perm = primT->value.AsTranspose()->perm;
if (perm == kPermNHWC) {
manager->Replace(transpose_cnode, transpose_cnode->input(1));


Loading…
Cancel
Save