| @@ -24,21 +24,25 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { | |||
| return RET_ERROR; | |||
| } | |||
| subgraph->name_ = sub_graph.name()->c_str(); | |||
| MS_ASSERT(sub_graph.inputIndices() != nullptr); | |||
| auto in_count = sub_graph.inputIndices()->size(); | |||
| for (uint32_t i = 0; i < in_count; ++i) { | |||
| subgraph->input_indices_.push_back(size_t(sub_graph.inputIndices()->GetAs<uint32_t>(i))); | |||
| subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i)); | |||
| } | |||
| MS_ASSERT(sub_graph.outputIndices() != nullptr); | |||
| auto out_count = sub_graph.outputIndices()->size(); | |||
| for (uint32_t i = 0; i < out_count; ++i) { | |||
| subgraph->output_indices_.push_back(size_t(sub_graph.outputIndices()->GetAs<uint32_t>(i))); | |||
| subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i)); | |||
| } | |||
| MS_ASSERT(sub_graph.nodeIndices() != nullptr); | |||
| auto node_count = sub_graph.nodeIndices()->size(); | |||
| for (uint32_t i = 0; i < node_count; ++i) { | |||
| subgraph->node_indices_.push_back(size_t(sub_graph.nodeIndices()->GetAs<uint32_t>(i))); | |||
| subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i)); | |||
| } | |||
| auto tensor_count = sub_graph.nodeIndices()->size(); | |||
| MS_ASSERT(sub_graph.tensorIndices() != nullptr); | |||
| auto tensor_count = sub_graph.tensorIndices()->size(); | |||
| for (uint32_t i = 0; i < tensor_count; ++i) { | |||
| subgraph->tensor_indices_.push_back(size_t(sub_graph.tensorIndices()->GetAs<uint32_t>(i))); | |||
| subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i)); | |||
| } | |||
| model->sub_graphs_.push_back(subgraph); | |||
| return RET_OK; | |||
| @@ -860,9 +860,15 @@ ThreadPool *CreateThreadPool(int thread_num, int mode) { | |||
| if (ret != RET_TP_OK) { | |||
| LOG_ERROR("create thread %d failed", i); | |||
| DestroyThreadPool(thread_pool); | |||
| thread_pool = NULL; | |||
| return NULL; | |||
| } | |||
| } | |||
| if (thread_pool == NULL) { | |||
| LOG_ERROR("create thread failed"); | |||
| DestroyThreadPool(thread_pool); | |||
| return NULL; | |||
| } | |||
| return thread_pool; | |||
| } | |||
| @@ -109,7 +109,7 @@ void *WorkspacePool::AllocWorkSpaceMem(size_t size) { | |||
| } | |||
| } | |||
| allocList.emplace_back(alloc); | |||
| return alloc.second; | |||
| return alloc.second != nullptr ? alloc.second : nullptr; | |||
| } | |||
| void WorkspacePool::FreeWorkSpaceMem(const void *ptr) { | |||
| @@ -120,6 +120,10 @@ int Benchmark::ReadInputFile() { | |||
| return RET_ERROR; | |||
| } | |||
| auto input_data = cur_tensor->MutableData(); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data is nullptr."; | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(input_data, bin_buf, tensor_data_size); | |||
| } | |||
| delete[] bin_buf; | |||
| @@ -232,7 +236,7 @@ int Benchmark::CompareOutput() { | |||
| } | |||
| float mean_bias; | |||
| if (total_size != 0) { | |||
| mean_bias = total_bias / total_size * 100; | |||
| mean_bias = total_bias / float_t(total_size) * 100; | |||
| } else { | |||
| mean_bias = 0; | |||
| } | |||
| @@ -286,21 +290,26 @@ int Benchmark::CompareStringData(const std::string &name, tensor::MSTensor *tens | |||
| int Benchmark::CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias, | |||
| int *total_size) { | |||
| float bias = 0; | |||
| auto mutableData = tensor->MutableData(); | |||
| if (mutableData == nullptr) { | |||
| MS_LOG(ERROR) << "mutableData is nullptr."; | |||
| return RET_ERROR; | |||
| } | |||
| switch (msCalibDataType) { | |||
| case TypeId::kNumberTypeFloat: { | |||
| bias = CompareData<float>(name, tensor->shape(), tensor->MutableData()); | |||
| bias = CompareData<float>(name, tensor->shape(), mutableData); | |||
| break; | |||
| } | |||
| case TypeId::kNumberTypeInt8: { | |||
| bias = CompareData<int8_t>(name, tensor->shape(), tensor->MutableData()); | |||
| bias = CompareData<int8_t>(name, tensor->shape(), mutableData); | |||
| break; | |||
| } | |||
| case TypeId::kNumberTypeUInt8: { | |||
| bias = CompareData<uint8_t>(name, tensor->shape(), tensor->MutableData()); | |||
| bias = CompareData<uint8_t>(name, tensor->shape(), mutableData); | |||
| break; | |||
| } | |||
| case TypeId::kNumberTypeInt32: { | |||
| bias = CompareData<int32_t>(name, tensor->shape(), tensor->MutableData()); | |||
| bias = CompareData<int32_t>(name, tensor->shape(), mutableData); | |||
| break; | |||
| } | |||
| default: | |||
| @@ -420,6 +429,10 @@ int Benchmark::PrintInputData() { | |||
| } | |||
| size_t print_num = std::min(input->ElementsNum(), 20); | |||
| const void *in_data = input->MutableData(); | |||
| if (in_data == nullptr) { | |||
| MS_LOG(ERROR) << "in_data is nullptr."; | |||
| return RET_ERROR; | |||
| } | |||
| for (size_t j = 0; j < print_num; j++) { | |||
| if (tensor_data_type == TypeId::kNumberTypeFloat32 || tensor_data_type == TypeId::kNumberTypeFloat) { | |||
| @@ -723,7 +736,7 @@ int Benchmark::PrintResult(const std::vector<std::string> &title, | |||
| } | |||
| columns.push_back(iter.first); | |||
| len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / flags_->loop_count_); | |||
| len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / float_t(flags_->loop_count_)); | |||
| if (len > columnLenMax.at(1)) { | |||
| columnLenMax.at(1) = len + 4; | |||
| } | |||
| @@ -760,9 +773,9 @@ int Benchmark::PrintResult(const std::vector<std::string> &title, | |||
| printf("%s\t", printBuf.c_str()); | |||
| } | |||
| printf("\n"); | |||
| for (size_t i = 0; i < rows.size(); i++) { | |||
| for (auto &row : rows) { | |||
| for (int j = 0; j < 5; j++) { | |||
| auto printBuf = rows[i][j]; | |||
| auto printBuf = row[j]; | |||
| printBuf.resize(columnLenMax.at(j), ' '); | |||
| printf("%s\t", printBuf.c_str()); | |||
| } | |||
| @@ -772,7 +785,7 @@ int Benchmark::PrintResult(const std::vector<std::string> &title, | |||
| } | |||
| Benchmark::~Benchmark() { | |||
| for (auto iter : this->benchmark_data_) { | |||
| for (const auto &iter : this->benchmark_data_) { | |||
| delete (iter.second); | |||
| } | |||
| this->benchmark_data_.clear(); | |||
| @@ -88,24 +88,24 @@ class MS_API BenchmarkFlags : public virtual FlagParser { | |||
| std::string model_file_; | |||
| std::string in_data_file_; | |||
| std::vector<std::string> input_data_list_; | |||
| InDataType in_data_type_; | |||
| InDataType in_data_type_ = kBinary; | |||
| std::string in_data_type_in_ = "bin"; | |||
| int cpu_bind_mode_ = 1; | |||
| // MarkPerformance | |||
| int loop_count_; | |||
| int num_threads_; | |||
| bool enable_fp16_; | |||
| int warm_up_loop_count_; | |||
| bool time_profiling_; | |||
| int loop_count_ = 10; | |||
| int num_threads_ = 2; | |||
| bool enable_fp16_ = false; | |||
| int warm_up_loop_count_ = 3; | |||
| bool time_profiling_ = false; | |||
| // MarkAccuracy | |||
| std::string benchmark_data_file_; | |||
| std::string benchmark_data_type_; | |||
| float accuracy_threshold_; | |||
| std::string benchmark_data_type_ = "FLOAT"; | |||
| float accuracy_threshold_ = 0.5; | |||
| // Resize | |||
| std::string resize_dims_in_ = ""; | |||
| std::string resize_dims_in_; | |||
| std::vector<std::vector<int>> resize_dims_; | |||
| std::string device_; | |||
| std::string device_ = "CPU"; | |||
| }; | |||
| class MS_API Benchmark { | |||
| @@ -149,7 +149,7 @@ class MS_API Benchmark { | |||
| // tensorData need to be converter first | |||
| template <typename T> | |||
| float CompareData(const std::string &nodeName, std::vector<int> msShape, const void *tensor_data) { | |||
| float CompareData(const std::string &nodeName, const std::vector<int> &msShape, const void *tensor_data) { | |||
| const T *msTensorData = static_cast<const T *>(tensor_data); | |||
| auto iter = this->benchmark_data_.find(nodeName); | |||
| if (iter != this->benchmark_data_.end()) { | |||
| @@ -33,9 +33,9 @@ struct Nothing {}; | |||
| class FlagParser { | |||
| public: | |||
| FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", ""); } | |||
| FlagParser() { AddFlag(&FlagParser::help, helpStr, "print usage message", ""); } | |||
| virtual ~FlagParser() {} | |||
| virtual ~FlagParser() = default; | |||
| // only support read flags from command line | |||
| virtual Option<std::string> ParseFlags(int argc, const char *const *argv, bool supportUnknown = false, | |||
| @@ -60,7 +60,7 @@ class FlagParser { | |||
| // Option-type fields | |||
| template <typename Flags, typename T> | |||
| void AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo); | |||
| bool help; | |||
| bool help{}; | |||
| protected: | |||
| template <typename Flags> | |||
| @@ -70,14 +70,15 @@ class FlagParser { | |||
| std::string binName; | |||
| Option<std::string> usageMsg; | |||
| std::string helpStr = "help"; | |||
| private: | |||
| struct FlagInfo { | |||
| std::string flagName; | |||
| bool isRequired; | |||
| bool isBoolean; | |||
| bool isRequired = false; | |||
| bool isBoolean = false; | |||
| std::string helpInfo; | |||
| bool isParsed; | |||
| bool isParsed = false; | |||
| std::function<Option<Nothing>(FlagParser *, const std::string &)> parse; | |||
| }; | |||
| @@ -218,7 +219,7 @@ void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std:: | |||
| return; | |||
| } | |||
| Flags *flag = dynamic_cast<Flags *>(this); | |||
| auto *flag = dynamic_cast<Flags *>(this); | |||
| if (flag == nullptr) { | |||
| return; | |||
| } | |||
| @@ -228,7 +229,10 @@ void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std:: | |||
| // flagItem is as a output parameter | |||
| ConstructFlag(t1, flagName, helpInfo, &flagItem); | |||
| flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> { | |||
| Flags *flag = dynamic_cast<Flags *>(base); | |||
| auto *flag = dynamic_cast<Flags *>(base); | |||
| if (flag == nullptr) { | |||
| return Option<Nothing>(None()); | |||
| } | |||
| if (base != nullptr) { | |||
| Option<T1> ret = Option<T1>(GenericParseValue<T1>(value)); | |||
| if (ret.IsNone()) { | |||
| @@ -267,7 +271,7 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const | |||
| return; | |||
| } | |||
| Flags *flag = dynamic_cast<Flags *>(this); | |||
| auto *flag = dynamic_cast<Flags *>(this); | |||
| if (flag == nullptr) { | |||
| MS_LOG(ERROR) << "dynamic_cast failed"; | |||
| return; | |||
| @@ -278,7 +282,7 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const | |||
| ConstructFlag(t, flagName, helpInfo, &flagItem); | |||
| flagItem.isRequired = false; | |||
| flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option<Nothing> { | |||
| Flags *flag = dynamic_cast<Flags *>(base); | |||
| auto *flag = dynamic_cast<Flags *>(base); | |||
| if (base != nullptr) { | |||
| Option<T> ret = Option<std::string>(GenericParseValue<T>(value)); | |||
| if (ret.IsNone()) { | |||
| @@ -605,10 +605,6 @@ std::string GetModelName(const std::string &modelFile) { | |||
| std::string modelName = modelFile; | |||
| modelName = modelName.substr(modelName.find_last_of('/') + 1); | |||
| modelName = modelName.substr(0, modelName.find_last_of('.')); | |||
| srand((unsigned)time(NULL)); | |||
| modelName = modelName + std::to_string(rand()); | |||
| return modelName; | |||
| } | |||
| } // namespace lite | |||
| @@ -101,10 +101,11 @@ STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &p | |||
| } | |||
| fcAttr->hasBias = true; | |||
| fcAttr->axis = 1; | |||
| MS_ASSERT(matMulNode->primitive != nullptr); | |||
| MS_ASSERT(matMulNode->primitive->value != nullptr); | |||
| MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr); | |||
| transA = matMulNode->primitive->value.AsMatMul()->transposeA; | |||
| transB = matMulNode->primitive->value.AsMatMul()->transposeB; | |||
| MS_ASSERT(matMulNode->primitive->value.value != nullptr); | |||
| matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection; | |||
| matMulNode->primitive->value.value = fcAttr.release(); | |||
| @@ -146,6 +146,9 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt | |||
| int shape_size = graph->allTensors.at(addBiasIndex)->dims.size(); | |||
| scaleParam->axis = 0 - shape_size; | |||
| mulNode->inputIndex.push_back(addBiasIndex); | |||
| MS_ASSERT(addNode->primitive != nullptr); | |||
| MS_ASSERT(addNode->primitive->value != nullptr); | |||
| MS_ASSERT(addNode->primitive->value.AsAdd() != nullptr); | |||
| auto activationType = addNode->primitive->value.AsAdd()->activationType; | |||
| if (activationType == ActivationType_RELU || activationType == ActivationType_RELU6 || | |||
| activationType == ActivationType_NO_ACTIVATION) { | |||
| @@ -159,6 +162,9 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt | |||
| } else { | |||
| // repace addnode as activation | |||
| std::unique_ptr<ActivationT> activationParam(new ActivationT()); | |||
| MS_ASSERT(addNode->primitive != nullptr); | |||
| MS_ASSERT(addNode->primitive->value != nullptr); | |||
| MS_ASSERT(addNode->primitive->value.AsAdd() != nullptr); | |||
| activationParam->type = addNode->primitive->value.AsAdd()->activationType; | |||
| addNode->primitive->value.type = schema::PrimitiveType_Activation; | |||
| addNode->primitive->value.value = activationParam.release(); | |||
| @@ -91,6 +91,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| MS_ASSERT(node->primitive->value != nullptr); | |||
| MS_ASSERT(node->primitive->value.AsActivation() != nullptr); | |||
| if (node->primitive->value.AsActivation() != nullptr && | |||
| node->primitive->value.AsActivation()->type == schema::ActivationType_LEAKY_RELU) { | |||
| return has_trans_count >= half_count; | |||
| @@ -198,6 +200,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni | |||
| MS_LOG(ERROR) << "node or primitive null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_ASSERT(node->primitive->value != nullptr); | |||
| auto type = node->primitive->value.type; | |||
| auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size(); | |||
| if (input1_ndim != 4) { | |||
| @@ -213,6 +216,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni | |||
| } | |||
| } | |||
| if (type == PrimitiveType_Concat) { | |||
| MS_ASSERT(node->primitive->value.AsConcat() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsConcat()->axis; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsConcat() == nullptr) { | |||
| @@ -222,6 +226,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni | |||
| node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; | |||
| } | |||
| if (type == PrimitiveType_Split) { | |||
| MS_ASSERT(node->primitive->value.AsSplit() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsSplit()->splitDim; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsSplit() == nullptr) { | |||
| @@ -231,6 +236,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni | |||
| node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis]; | |||
| } | |||
| if (type == PrimitiveType_Crop) { | |||
| MS_ASSERT(node->primitive->value.AsCrop() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsCrop()->axis; | |||
| auto offsets = node->primitive->value.AsCrop()->offsets; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| @@ -76,6 +76,10 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { | |||
| } | |||
| weight->data.resize(count * sizeof(float)); | |||
| const float *data_ptr = proto.data().data(); | |||
| if (data_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "data_ptr is nullptr"; | |||
| return nullptr; | |||
| } | |||
| if (::memcpy_s(weight->data.data(), count * sizeof(float), (uint8_t *)data_ptr, count * sizeof(float)) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy failed"; | |||
| return nullptr; | |||
| @@ -157,6 +157,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | |||
| if (iter != (*nodeIter).attribute().end()) { | |||
| MS_ASSERT(iter->ints() != nullptr); | |||
| MS_ASSERT(iter->ints().begin() != nullptr); | |||
| MS_ASSERT(iter->ints().end() != nullptr); | |||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } | |||
| attr->channelOut = dims[0]; | |||
| @@ -40,6 +40,7 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| auto onnx_node_attr = onnx_node.attribute(); | |||
| for (int i = 0; i < onnx_node_attr.size(); ++i) { | |||
| MS_ASSERT(onnx_node_attr.at(i) != nullptr); | |||
| if (onnx_node_attr.at(i).name() == "axis") { | |||
| attr->axis = onnx_node_attr.at(i).i(); | |||
| } else if (onnx_node_attr.at(i).name() == "p") { | |||
| @@ -40,6 +40,7 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| auto onnx_node_attr = onnx_node.attribute(); | |||
| int32_t size = 0; | |||
| for (int i = 0; i < onnx_node_attr.size(); ++i) { | |||
| MS_ASSERT(onnx_node_attr.at(i) != nullptr); | |||
| if (onnx_node_attr.at(i).name() == "alpha") { | |||
| attr->alpha = onnx_node_attr.at(i).f(); | |||
| } else if (onnx_node_attr.at(i).name() == "beta") { | |||
| @@ -51,6 +52,11 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| attr->depth_radius = size / 2; | |||
| } | |||
| } | |||
| if (size == 0) { | |||
| MS_LOG(ERROR) << "Divide-by-zero error."; | |||
| return RET_ERROR; | |||
| } | |||
| attr->alpha /= size; | |||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||
| @@ -240,6 +240,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| lite_primitive->InferShape(input_tensors, output_tensors); | |||
| auto primitive = lite_primitive.get(); | |||
| MS_ASSERT(primitive != nullptr); | |||
| MS_ASSERT(primitive->Type() != nullptr); | |||
| auto parameter = | |||
| lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); | |||
| @@ -67,8 +67,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||
| } | |||
| // transform node means scale,bn | |||
| auto transform_node = node->cast<CNodePtr>(); | |||
| if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK || | |||
| CheckLeastInputSize(transform_node, 2) != lite::RET_OK) { | |||
| if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK || CheckLeastInputSize(transform_node, 2) != lite::RET_OK) { | |||
| return nullptr; | |||
| } | |||
| @@ -93,6 +92,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||
| auto trans_bias = new (std::nothrow) float[kernel_nums]; | |||
| if (trans_bias == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| delete[] trans_scale; | |||
| delete[] trans_bias; | |||
| return nullptr; | |||
| } | |||
| @@ -234,8 +234,11 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne | |||
| return; | |||
| } | |||
| delete[] tmp_weight_data; | |||
| if (tmp_weight_data != nullptr) { | |||
| delete[] tmp_weight_data; | |||
| } | |||
| } | |||
| const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, | |||
| const float *trans_scale, const float *trans_bias) const { | |||
| MS_ASSERT(bias_data != nullptr); | |||
| @@ -56,6 +56,8 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(primT->value != nullptr); | |||
| MS_ASSERT(primT->value.AsTranspose() != nullptr); | |||
| std::vector<int32_t> perm = primT->value.AsTranspose()->perm; | |||
| if (perm == kPermNCHW) { | |||
| manager->Replace(transpose_cnode, transpose_cnode->input(1)); | |||
| @@ -77,6 +79,8 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(primT->value != nullptr); | |||
| MS_ASSERT(primT->value.AsTranspose() != nullptr); | |||
| std::vector<int32_t> perm = primT->value.AsTranspose()->perm; | |||
| if (perm == kPermNHWC) { | |||
| manager->Replace(transpose_cnode, transpose_cnode->input(1)); | |||