Browse Source

clean_code_8

tags/v1.1.0
yefeng 5 years ago
parent
commit
bd7ecf9a1a
100 changed files with 266 additions and 263 deletions
  1. +11
    -11
      mindspore/lite/src/common/file_utils.cc
  2. +2
    -2
      mindspore/lite/src/common/log_adapter.cc
  3. +6
    -6
      mindspore/lite/src/common/utils.cc
  4. +2
    -2
      mindspore/lite/src/common/utils.h
  5. +1
    -1
      mindspore/lite/src/kernel_registry.cc
  6. +7
    -5
      mindspore/lite/src/lite_kernel.cc
  7. +8
    -8
      mindspore/lite/src/lite_kernel.h
  8. +1
    -1
      mindspore/lite/src/lite_session.cc
  9. +2
    -2
      mindspore/lite/src/model_common.h
  10. +1
    -1
      mindspore/lite/src/ops/adam.cc
  11. +2
    -2
      mindspore/lite/src/ops/addn.cc
  12. +1
    -1
      mindspore/lite/src/ops/apply_momentum.cc
  13. +2
    -2
      mindspore/lite/src/ops/argmax.cc
  14. +2
    -2
      mindspore/lite/src/ops/argmin.cc
  15. +4
    -4
      mindspore/lite/src/ops/arithmetic.cc
  16. +2
    -2
      mindspore/lite/src/ops/arithmetic_self.cc
  17. +1
    -1
      mindspore/lite/src/ops/assign.cc
  18. +1
    -1
      mindspore/lite/src/ops/assign_add.cc
  19. +2
    -2
      mindspore/lite/src/ops/audio_spectrogram.cc
  20. +3
    -3
      mindspore/lite/src/ops/batch_to_space.cc
  21. +1
    -1
      mindspore/lite/src/ops/bias_grad.cc
  22. +1
    -1
      mindspore/lite/src/ops/binary_cross_entropy.cc
  23. +1
    -1
      mindspore/lite/src/ops/binary_cross_entropy_grad.cc
  24. +3
    -3
      mindspore/lite/src/ops/bn_grad.cc
  25. +2
    -2
      mindspore/lite/src/ops/broadcast_to.cc
  26. +2
    -2
      mindspore/lite/src/ops/cast.cc
  27. +2
    -2
      mindspore/lite/src/ops/concat.cc
  28. +2
    -2
      mindspore/lite/src/ops/constant_of_shape.cc
  29. +5
    -5
      mindspore/lite/src/ops/conv2d.cc
  30. +3
    -3
      mindspore/lite/src/ops/conv2d_grad_filter.cc
  31. +3
    -3
      mindspore/lite/src/ops/conv2d_grad_input.cc
  32. +2
    -2
      mindspore/lite/src/ops/crop.cc
  33. +2
    -2
      mindspore/lite/src/ops/custom_extract_features.cc
  34. +1
    -1
      mindspore/lite/src/ops/custom_normalize.cc
  35. +2
    -2
      mindspore/lite/src/ops/custom_predict.cc
  36. +4
    -4
      mindspore/lite/src/ops/deconv2d.cc
  37. +2
    -2
      mindspore/lite/src/ops/dedepthwise_conv2d.cc
  38. +3
    -3
      mindspore/lite/src/ops/depth_to_space.cc
  39. +5
    -5
      mindspore/lite/src/ops/depthwise_conv2d.cc
  40. +5
    -5
      mindspore/lite/src/ops/detection_post_process.cc
  41. +3
    -3
      mindspore/lite/src/ops/dropout.cc
  42. +2
    -2
      mindspore/lite/src/ops/dropout_grad.cc
  43. +2
    -2
      mindspore/lite/src/ops/embedding_lookup.cc
  44. +1
    -1
      mindspore/lite/src/ops/equal.cc
  45. +2
    -2
      mindspore/lite/src/ops/expand_dims.cc
  46. +2
    -2
      mindspore/lite/src/ops/fft_imag.cc
  47. +2
    -2
      mindspore/lite/src/ops/fft_real.cc
  48. +2
    -2
      mindspore/lite/src/ops/fill.cc
  49. +2
    -2
      mindspore/lite/src/ops/flatten.cc
  50. +2
    -2
      mindspore/lite/src/ops/flatten_grad.cc
  51. +2
    -2
      mindspore/lite/src/ops/full_connection.cc
  52. +2
    -2
      mindspore/lite/src/ops/fused_batchnorm.cc
  53. +2
    -2
      mindspore/lite/src/ops/gather.cc
  54. +2
    -2
      mindspore/lite/src/ops/gather_nd.cc
  55. +1
    -1
      mindspore/lite/src/ops/greater.cc
  56. +1
    -1
      mindspore/lite/src/ops/greater_equal.cc
  57. +1
    -1
      mindspore/lite/src/ops/group_conv2d_grad_input.cc
  58. +2
    -2
      mindspore/lite/src/ops/hashtable_lookup.cc
  59. +2
    -2
      mindspore/lite/src/ops/layer_norm.cc
  60. +1
    -1
      mindspore/lite/src/ops/less.cc
  61. +1
    -1
      mindspore/lite/src/ops/less_equal.cc
  62. +1
    -1
      mindspore/lite/src/ops/lsh_projection.cc
  63. +2
    -2
      mindspore/lite/src/ops/lstm.cc
  64. +2
    -2
      mindspore/lite/src/ops/matmul.cc
  65. +1
    -1
      mindspore/lite/src/ops/maximum_grad.cc
  66. +2
    -2
      mindspore/lite/src/ops/mean.cc
  67. +2
    -2
      mindspore/lite/src/ops/mfcc.cc
  68. +2
    -2
      mindspore/lite/src/ops/nchw2nhwc.cc
  69. +2
    -2
      mindspore/lite/src/ops/nhwc2nchw.cc
  70. +1
    -1
      mindspore/lite/src/ops/non_max_suppression.cc
  71. +1
    -1
      mindspore/lite/src/ops/not_equal.cc
  72. +2
    -2
      mindspore/lite/src/ops/one_hot.cc
  73. +1
    -1
      mindspore/lite/src/ops/oneslike.cc
  74. +2
    -2
      mindspore/lite/src/ops/pad.cc
  75. +2
    -2
      mindspore/lite/src/ops/pooling.cc
  76. +1
    -1
      mindspore/lite/src/ops/pooling_grad.cc
  77. +2
    -2
      mindspore/lite/src/ops/power.cc
  78. +19
    -18
      mindspore/lite/src/ops/primitive_c.cc
  79. +19
    -19
      mindspore/lite/src/ops/primitive_c.h
  80. +2
    -2
      mindspore/lite/src/ops/prior_box.cc
  81. +2
    -2
      mindspore/lite/src/ops/quant_dtype_cast.cc
  82. +2
    -2
      mindspore/lite/src/ops/range.cc
  83. +2
    -2
      mindspore/lite/src/ops/rank.cc
  84. +4
    -4
      mindspore/lite/src/ops/reduce.cc
  85. +4
    -4
      mindspore/lite/src/ops/reshape.cc
  86. +4
    -4
      mindspore/lite/src/ops/resize.cc
  87. +3
    -3
      mindspore/lite/src/ops/return.cc
  88. +2
    -2
      mindspore/lite/src/ops/reverse_sequence.cc
  89. +2
    -2
      mindspore/lite/src/ops/rfft.cc
  90. +2
    -2
      mindspore/lite/src/ops/roi_pooling.cc
  91. +2
    -2
      mindspore/lite/src/ops/scatter_nd.cc
  92. +1
    -1
      mindspore/lite/src/ops/sgd.cc
  93. +2
    -2
      mindspore/lite/src/ops/shape.cc
  94. +1
    -1
      mindspore/lite/src/ops/skip_gram.cc
  95. +5
    -5
      mindspore/lite/src/ops/slice.cc
  96. +2
    -2
      mindspore/lite/src/ops/softmax.cc
  97. +2
    -2
      mindspore/lite/src/ops/softmax_cross_entropy.cc
  98. +3
    -3
      mindspore/lite/src/ops/space_to_batch.cc
  99. +3
    -3
      mindspore/lite/src/ops/space_to_batch_nd.cc
  100. +3
    -3
      mindspore/lite/src/ops/space_to_depth.cc

+ 11
- 11
mindspore/lite/src/common/file_utils.cc View File

@@ -28,15 +28,15 @@ char *ReadFile(const char *file, size_t *size) {
return nullptr;
}
MS_ASSERT(size != nullptr);
std::string realPath = RealPath(file);
std::ifstream ifs(realPath);
std::string real_path = RealPath(file);
std::ifstream ifs(real_path);
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << realPath << " is not exist";
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
return nullptr;
}

if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << realPath << " open failed";
MS_LOG(ERROR) << "file: " << real_path << " open failed";
return nullptr;
}

@@ -44,7 +44,7 @@ char *ReadFile(const char *file, size_t *size) {
*size = ifs.tellg();
std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
if (buf == nullptr) {
MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
ifs.close();
return nullptr;
}
@@ -65,21 +65,21 @@ std::string RealPath(const char *path) {
MS_LOG(ERROR) << "path is too long";
return "";
}
auto resolvedPath = std::make_unique<char[]>(PATH_MAX);
if (resolvedPath == nullptr) {
MS_LOG(ERROR) << "new resolvedPath failed";
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "new resolved_path failed";
return "";
}
#ifdef _WIN32
char *real_path = _fullpath(resolvedPath.get(), path, 1024);
char *real_path = _fullpath(resolved_path.get(), path, 1024);
#else
char *real_path = realpath(path, resolvedPath.get());
char *real_path = realpath(path, resolved_path.get());
#endif
if (real_path == nullptr || strlen(real_path) == 0) {
MS_LOG(ERROR) << "file path is not valid : " << path;
return "";
}
std::string res = resolvedPath.get();
std::string res = resolved_path.get();
return res;
}
} // namespace lite


+ 2
- 2
mindspore/lite/src/common/log_adapter.cc View File

@@ -26,7 +26,7 @@
namespace mindspore {
constexpr const char *ANDROID_LOG_TAG = "MS_LITE";

int EnvToInt(const char *env) {
int StrToInt(const char *env) {
if (env == nullptr) return 2;
if (strcmp(env, "0") == 0) return 0;
if (strcmp(env, "1") == 0) return 1;
@@ -37,7 +37,7 @@ int EnvToInt(const char *env) {

bool IsPrint(int level) {
static const char *env = std::getenv("GLOG_v");
static int ms_level = EnvToInt(env);
static int ms_level = StrToInt(env);
if (level < 0) {
level = 2;
}


+ 6
- 6
mindspore/lite/src/common/utils.cc View File

@@ -48,11 +48,11 @@ uint64_t GetTimeUs() {
return 0;
}
// USECS_IN_SEC *NSECS_IN_USEC;
auto retval = static_cast<uint64_t>((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC));
return retval;
auto ret_val = static_cast<uint64_t>((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC));
return ret_val;
}

std::string Remove(const std::string &from, const std::string &subStr, Mode mode) {
std::string RemoveSubStr(const std::string &from, const std::string &subStr, RemoveSubStrMode mode) {
std::string result = from;
if (mode == PREFIX) {
if (from.substr(0, subStr.length()) == subStr) {
@@ -90,8 +90,8 @@ std::vector<std::string> StrSplit(const std::string &str, const std::string &pat
}

std::vector<std::string> Tokenize(const std::string &src, const std::string &delimiters,
const Option<size_t> &maxTokenNum) {
if (maxTokenNum.IsSome() && maxTokenNum.Get() == 0) {
const Option<size_t> &max_token_num) {
if (max_token_num.IsSome() && max_token_num.Get() == 0) {
return {};
}

@@ -104,7 +104,7 @@ std::vector<std::string> Tokenize(const std::string &src, const std::string &del
break;
}
size_t delimiter = src.find_first_of(delimiters, nonDelimiter);
if (delimiter == std::string::npos || (maxTokenNum.IsSome() && tokens.size() == maxTokenNum.Get() - 1)) {
if (delimiter == std::string::npos || (max_token_num.IsSome() && tokens.size() == max_token_num.Get() - 1)) {
tokens.push_back(src.substr(nonDelimiter));
break;
}


+ 2
- 2
mindspore/lite/src/common/utils.h View File

@@ -148,10 +148,10 @@ std::vector<std::string> StrSplit(const std::string &str, const std::string &pat
std::vector<std::string> Tokenize(const std::string &src, const std::string &delimiters,
const Option<size_t> &maxTokenNum = Option<size_t>(None()));

enum Mode { PREFIX, SUFFIX, ANY };
enum RemoveSubStrMode { PREFIX, SUFFIX, ANY };

// remove redundant charactor
std::string Remove(const std::string &from, const std::string &subStr, Mode mode = ANY);
std::string RemoveSubStr(const std::string &from, const std::string &subStr, RemoveSubStrMode mode = ANY);

template <typename T>
inline Option<T> GenericParseValue(const std::string &value) {


+ 1
- 1
mindspore/lite/src/kernel_registry.cc View File

@@ -95,7 +95,7 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c
creator_arrays_[index] = creator;
}

bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &newCreators) { return false; }
bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &new_creators) { return false; }

const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; }



+ 7
- 5
mindspore/lite/src/lite_kernel.cc View File

@@ -26,7 +26,9 @@ using mindspore::lite::RET_OK;
void *LiteKernel::workspace_ = nullptr;

void LiteKernel::AllocWorkspace(size_t size) {
if (size == 0) return;
if (size == 0) {
return;
}
workspace_ = malloc(size);
if (workspace_ == nullptr) {
MS_LOG(ERROR) << "fail to alloc " << size;
@@ -74,10 +76,10 @@ int LiteKernel::FreeWorkTensor() const {

int LiteKernel::PreProcess() {
if (!InferShapeDone()) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (ret != 0) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(false);
MS_LOG(ERROR) << "InferShape fail!";
return ret;
}
@@ -279,8 +281,8 @@ int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *ke
void LiteKernelUtil::InitIOKernels(std::vector<kernel::LiteKernel *> &kernels) {
for (auto *kernel : kernels) {
// clean io kernels
kernel->SetInKernel({});
kernel->SetOutKernel({});
kernel->set_in_kernel({});
kernel->set_out_kernel({});
// find io kernels
for (auto *search_kernel : kernels) {
if (search_kernel == kernel) {


+ 8
- 8
mindspore/lite/src/lite_kernel.h View File

@@ -109,9 +109,9 @@ class LiteKernel {

virtual bool IsEval() const { return !this->train_mode_; }

virtual void SetTrainable(bool trainable = true) { this->trainable_ = trainable; }
virtual void set_trainable(bool trainable = true) { this->trainable_ = trainable; }

virtual bool IsTrainable() const { return this->trainable_; }
virtual bool is_trainable() const { return this->trainable_; }

void set_name(const std::string &name) { this->name_ = name; }

@@ -146,9 +146,9 @@ class LiteKernel {
}
}

void SetInKernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }
void set_in_kernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }

void SetOutKernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; }
void set_out_kernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; }

const std::vector<LiteKernel *> &in_kernels() const { return this->in_kernels_; }

@@ -165,18 +165,18 @@ class LiteKernel {
void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; }

const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
void SetWorkspaceSize(size_t value) { workspace_size_ = value; }
size_t GetWorkspaceSize() { return workspace_size_; }
void set_workspace_size(size_t value) { workspace_size_ = value; }
size_t workspace_size() { return workspace_size_; }
static void AllocWorkspace(size_t size);
static void FreeWorkspace();
void *GetWorkspace() { return workspace_; }
void *workspace() { return workspace_; }

SubGraphType subgraph_type() const { return this->subgraph_type_; }

virtual std::string ToString() const;

protected:
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()); }
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->infer_flag()); }

KernelKey desc_{};
std::string name_;


+ 1
- 1
mindspore/lite/src/lite_session.cc View File

@@ -74,7 +74,7 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit
for (size_t j = 0; j < quant_clusters->size(); j++) {
clusters.push_back(quant_clusters->Get(j));
}
dst_tensor->SetQuantClusters(clusters);
dst_tensor->set_quant_clusters(clusters);
}
}



+ 2
- 2
mindspore/lite/src/model_common.h View File

@@ -32,7 +32,7 @@ namespace mindspore::lite {
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model);

template <typename T = schema::MetaGraph, typename U = schema::CNode>
bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = 0) {
bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA_CUR) {
MS_ASSERT(model != nullptr);
for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
auto *node = new (std::nothrow) Model::Node();
@@ -53,7 +53,7 @@ bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = 0) {
delete node;
return false;
}
node->primitive_->SetQuantType(static_cast<schema::QuantType>(c_node->quantType()));
node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
node->name_ = c_node->name()->c_str();
node->node_type_ = static_cast<NodeType>(c_node->nodeType());
auto count = c_node->inputIndex()->size();


+ 1
- 1
mindspore/lite/src/ops/adam.cc View File

@@ -88,7 +88,7 @@ int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tenso
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_format(inputs[0]->format());
out->set_shape({1});
}



+ 2
- 2
mindspore/lite/src/ops/addn.cc View File

@@ -83,9 +83,9 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
output->set_shape(input->shape());


+ 1
- 1
mindspore/lite/src/ops/apply_momentum.cc View File

@@ -93,7 +93,7 @@ int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<li
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_format(inputs[0]->format());
out->set_shape({1});
}



+ 2
- 2
mindspore/lite/src/ops/argmax.cc View File

@@ -71,9 +71,9 @@ int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_ERROR;
}

output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
std::vector<int> output_shape(input->shape());


+ 2
- 2
mindspore/lite/src/ops/argmin.cc View File

@@ -69,9 +69,9 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
auto input_shape_size = input->shape().size();


+ 4
- 4
mindspore/lite/src/ops/arithmetic.cc View File

@@ -41,10 +41,10 @@ int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite

auto input_shape0 = input0->shape();
auto input_shape1 = input1->shape();
auto format = input0->GetFormat();
output->SetFormat(format);
auto format = input0->format();
output->set_format(format);
output->set_data_type(input0->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
if (input_shape0.size() > 10 || input_shape1.size() > 10) {
@@ -69,7 +69,7 @@ int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite
}
in_shape1_[i] = input_shape1[i];
}
format = input0->GetFormat();
format = input0->format();
} else if (input_shape0.size() > input_shape1.size()) {
ndim_ = input_shape0.size();
auto fill_dim_num = input_shape0.size() - input_shape1.size();


+ 2
- 2
mindspore/lite/src/ops/arithmetic_self.cc View File

@@ -30,9 +30,9 @@ int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
output->set_shape(input->shape());


+ 1
- 1
mindspore/lite/src/ops/assign.cc View File

@@ -80,7 +80,7 @@ int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Ten
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_format(inputs[0]->format());
out->set_shape({1});
}
return RET_OK;


+ 1
- 1
mindspore/lite/src/ops/assign_add.cc View File

@@ -86,7 +86,7 @@ int AssignAdd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
output_shape[i] = x_shape[i];
}
out->set_shape(output_shape);
out->SetFormat(x->GetFormat());
out->set_format(x->format());
out->set_data_type(x->data_type());
return RET_OK;
}


+ 2
- 2
mindspore/lite/src/ops/audio_spectrogram.cc View File

@@ -75,8 +75,8 @@ int AudioSpectrogram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tens
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 3
- 3
mindspore/lite/src/ops/batch_to_space.cc View File

@@ -91,13 +91,13 @@ int BatchToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
}

auto input = inputs.at(0);
if (input->GetFormat() != schema::Format::Format_NHWC) {
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_format(input->format());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 1
- 1
mindspore/lite/src/ops/bias_grad.cc View File

@@ -111,7 +111,7 @@ int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> out
}
out->set_shape(inshape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
out->set_format(in0->format());

return RET_OK;
}


+ 1
- 1
mindspore/lite/src/ops/binary_cross_entropy.cc View File

@@ -102,7 +102,7 @@ Registry BinaryCrossEntropyRegistry(schema::PrimitiveType_BinaryCrossEntropy, Bi
int BinaryCrossEntropy::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];
Tensor *out = outputs_[0];
out->SetFormat(x->GetFormat());
out->set_format(x->format());
out->set_data_type(x->data_type());
int reduction = GetReduction();
if (reduction == 1 || reduction == 2) {


+ 1
- 1
mindspore/lite/src/ops/binary_cross_entropy_grad.cc View File

@@ -109,7 +109,7 @@ Registry BinaryCrossEntropyGradRegistry(schema::PrimitiveType_BinaryCrossEntropy
int BinaryCrossEntropyGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];
Tensor *out = outputs_[0];
out->SetFormat(x->GetFormat());
out->set_format(x->format());
out->set_data_type(x->data_type());
std::vector<int> x_shape = x->shape();
std::vector<int> output_shape(x_shape.size());


+ 3
- 3
mindspore/lite/src/ops/bn_grad.cc View File

@@ -106,9 +106,9 @@ int BNGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Ten
outputs[0]->set_data_type(in->data_type());
outputs[1]->set_data_type(scale->data_type());
outputs[2]->set_data_type(scale->data_type());
outputs[0]->SetFormat(in->GetFormat());
outputs[1]->SetFormat(scale->GetFormat());
outputs[2]->SetFormat(scale->GetFormat());
outputs[0]->set_format(in->format());
outputs[1]->set_format(scale->format());
outputs[2]->set_format(scale->format());
return RET_OK;
}
} // namespace lite


+ 2
- 2
mindspore/lite/src/ops/broadcast_to.cc View File

@@ -77,9 +77,9 @@ int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *>
}

auto input = inputs.at(0);
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_format(input->format());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
std::vector<int32_t> dst_shape(GetDstShape());


+ 2
- 2
mindspore/lite/src/ops/cast.cc View File

@@ -93,10 +93,10 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
MS_LOG(ERROR) << "tensor number is error.";
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_format(input->format());

output->set_data_type(static_cast<TypeId>(GetDstT()));
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/concat.cc View File

@@ -99,8 +99,8 @@ int Concat::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_PARAM_INVALID;
}
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
output->set_format(input0->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/constant_of_shape.cc View File

@@ -81,8 +81,8 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front();
out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
out_tensor->SetFormat(in_tensor->GetFormat());
if (!GetInferFlag()) {
out_tensor->set_format(in_tensor->format());
if (!infer_flag()) {
return RET_OK;
}
auto in_data = reinterpret_cast<int *>(in_tensor->data_c());


+ 5
- 5
mindspore/lite/src/ops/conv2d.cc View File

@@ -179,8 +179,8 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
}
attr->channelMultiplier = channel_mutiplier;

MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
auto input_node = inputs[kAnfPopulaterOne];
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto input_node = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(input_node != nullptr);
if (input_node->isa<Parameter>()) {
auto param_node = input_node->cast<ParameterPtr>();
@@ -192,7 +192,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
attr->channelIn = dims[kAnfPopulaterOne];
attr->channelIn = dims[kAnfPopulaterInputNumOne];
}
}
}
@@ -372,14 +372,14 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(out_tensor != nullptr);

out_tensor->SetFormat(input_tensor->GetFormat());
out_tensor->set_format(input_tensor->format());
out_tensor->set_data_type(input_tensor->data_type());
pad_l_ = GetPadLeft();
pad_u_ = GetPadUp();
pad_d_ = GetPadDown();
pad_r_ = GetPadRight();

if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
auto in_shape = input_tensor->shape();


+ 3
- 3
mindspore/lite/src/ops/conv2d_grad_filter.cc View File

@@ -138,8 +138,8 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}

if (inputs.size() >= kAnfPopulaterThree) {
auto filter_shape = inputs[kAnfPopulaterTwo];
if (inputs.size() >= kAnfPopulaterInputNumThree) {
auto filter_shape = inputs[kAnfPopulaterInputNumTwo];
MS_ASSERT(filter_shape != nullptr);
if (filter_shape->isa<ValueNode>()) {
auto valueNode = filter_shape->cast<ValueNodePtr>();
@@ -239,7 +239,7 @@ int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tenso

out->set_shape(GetFilterShape());
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
out->set_format(in0->format());

return RET_OK;
}


+ 3
- 3
mindspore/lite/src/ops/conv2d_grad_input.cc View File

@@ -140,8 +140,8 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}

if (inputs.size() >= kAnfPopulaterThree) {
auto input_shape = inputs[kAnfPopulaterTwo];
if (inputs.size() >= kAnfPopulaterInputNumThree) {
auto input_shape = inputs[kAnfPopulaterInputNumTwo];
MS_ASSERT(input_shape != nullptr);
if (input_shape->isa<ValueNode>()) {
auto valueNode = input_shape->cast<ValueNodePtr>();
@@ -239,7 +239,7 @@ int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor
MS_ASSERT(out != nullptr);
out->set_shape(GetInputShape());
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
out->set_format(in0->format());

return RET_OK;
}


+ 2
- 2
mindspore/lite/src/ops/crop.cc View File

@@ -68,9 +68,9 @@ int Crop::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;
}
outputs[0]->SetFormat(inputs[0]->GetFormat());
outputs[0]->set_format(inputs[0]->format());
outputs[0]->set_data_type(inputs[0]->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
outputs[0]->set_shape(inputs[1]->shape());


+ 2
- 2
mindspore/lite/src/ops/custom_extract_features.cc View File

@@ -50,9 +50,9 @@ int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector
MS_ASSERT(output1 != nullptr);

output0->set_data_type(kNumberTypeInt32);
output0->SetFormat(input->GetFormat());
output0->set_format(input->format());
output1->set_data_type(kNumberTypeFloat32);
output1->SetFormat(input->GetFormat());
output1->set_format(input->format());

if (input->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";


+ 1
- 1
mindspore/lite/src/ops/custom_normalize.cc View File

@@ -48,7 +48,7 @@ int CustomNormalize::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
MS_ASSERT(output != nullptr);

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
output->set_format(input->format());

if (input->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";


+ 2
- 2
mindspore/lite/src/ops/custom_predict.cc View File

@@ -69,10 +69,10 @@ int CustomPredict::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor

output0->set_shape(shape);
output0->set_data_type(kNumberTypeInt32);
output0->SetFormat(input->GetFormat());
output0->set_format(input->format());
output1->set_shape(shape);
output1->set_data_type(kNumberTypeFloat32);
output1->SetFormat(input->GetFormat());
output1->set_format(input->format());
return RET_OK;
}
} // namespace lite


+ 4
- 4
mindspore/lite/src/ops/deconv2d.cc View File

@@ -172,8 +172,8 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
}
attr->channelMultiplier = channel_mutiplier;

MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
auto input_node = inputs[kAnfPopulaterOne];
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto input_node = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(input_node != nullptr);
if (input_node->isa<Parameter>()) {
auto param_node = input_node->cast<ParameterPtr>();
@@ -306,9 +306,9 @@ int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
int32_t input_h = input->Height();


+ 2
- 2
mindspore/lite/src/ops/dedepthwise_conv2d.cc View File

@@ -135,9 +135,9 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
auto in_shape = input->shape();


+ 3
- 3
mindspore/lite/src/ops/depth_to_space.cc View File

@@ -66,13 +66,13 @@ int DepthToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
}

auto input = inputs.at(0);
if (input->GetFormat() != schema::Format::Format_NHWC) {
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
outputs[0]->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 5
- 5
mindspore/lite/src/ops/depthwise_conv2d.cc View File

@@ -127,8 +127,8 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
attr->channelMultiplier = channel_multiplier;

MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
auto inputNode = inputs[kAnfPopulaterOne];
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
auto inputNode = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
@@ -139,7 +139,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_ASSERT(abstractTensor != nullptr);
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
attr->channelIn = dims[kAnfPopulaterOne];
attr->channelIn = dims[kAnfPopulaterInputNumOne];
}
}
}
@@ -211,14 +211,14 @@ int DepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
pad_l_ = GetPadLeft();
pad_u_ = GetPadUp();
pad_d_ = GetPadDown();
pad_r_ = GetPadRight();

if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
auto in_shape = input->shape();


+ 5
- 5
mindspore/lite/src/ops/detection_post_process.cc View File

@@ -181,15 +181,15 @@ int DetectionPostProcess::InferShape(std::vector<lite::Tensor *> inputs_, std::v
auto num_det = outputs_.at(3);
MS_ASSERT(num_det != nullptr);

detected_boxes->SetFormat(boxes->GetFormat());
detected_boxes->set_format(boxes->format());
detected_boxes->set_data_type(kNumberTypeFloat32);
detected_classes->SetFormat(boxes->GetFormat());
detected_classes->set_format(boxes->format());
detected_classes->set_data_type(kNumberTypeFloat32);
detected_scores->SetFormat(boxes->GetFormat());
detected_scores->set_format(boxes->format());
detected_scores->set_data_type(kNumberTypeFloat32);
num_det->SetFormat(boxes->GetFormat());
num_det->set_format(boxes->format());
num_det->set_data_type(kNumberTypeFloat32);
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
const auto max_detections = GetMaxDetections();


+ 3
- 3
mindspore/lite/src/ops/dropout.cc View File

@@ -83,19 +83,19 @@ int Dropout::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
MS_ASSERT(input != nullptr);
auto output0 = outputs_.front();
MS_ASSERT(output0 != nullptr);
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
output0->set_shape(input->shape());
output0->set_data_type(input->data_type());
output0->SetFormat(input->GetFormat());
output0->set_format(input->format());

if (outputs_.size() > 1) {
auto output1 = outputs_[1];
MS_ASSERT(output1 != nullptr);
output1->set_shape(input->shape());
output1->set_data_type(input->data_type());
output1->SetFormat(input->GetFormat());
output1->set_format(input->format());
}

return RET_OK;


+ 2
- 2
mindspore/lite/src/ops/dropout_grad.cc View File

@@ -86,12 +86,12 @@ int DropoutGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
output->set_format(input->format());

return RET_OK;
}


+ 2
- 2
mindspore/lite/src/ops/embedding_lookup.cc View File

@@ -67,9 +67,9 @@ int EmbeddingLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
MS_ASSERT(ids != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(params_->GetFormat());
output->set_format(params_->format());
output->set_data_type(params_->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/equal.cc View File

@@ -42,7 +42,7 @@ int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/expand_dims.cc View File

@@ -101,8 +101,8 @@ int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
MS_LOG(ERROR) << "output size is invalid";
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
int dim = GetDim();


+ 2
- 2
mindspore/lite/src/ops/fft_imag.cc View File

@@ -41,8 +41,8 @@ int FftImag::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(TypeId::kNumberTypeFloat32);
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 2
- 2
mindspore/lite/src/ops/fft_real.cc View File

@@ -41,8 +41,8 @@ int FftReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(TypeId::kNumberTypeFloat32);
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 2
- 2
mindspore/lite/src/ops/fill.cc View File

@@ -69,8 +69,8 @@ int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_INPUT_TENSOR_ERROR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/flatten.cc View File

@@ -38,8 +38,8 @@ int Flatten::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
}

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/flatten_grad.cc View File

@@ -37,8 +37,8 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
}

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/full_connection.cc View File

@@ -69,7 +69,7 @@ int FullConnection::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
@@ -114,7 +114,7 @@ int FullConnection::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<
}
output->set_shape(out_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
output->set_format(input0->format());

return RET_OK;
}


+ 2
- 2
mindspore/lite/src/ops/fused_batchnorm.cc View File

@@ -91,11 +91,11 @@ int FusedBatchNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<
if (outputs_.size() <= i) break;
outputs_.at(i)->set_shape(inputs_.at(i)->shape());
outputs_.at(i)->set_data_type(inputs_.at(i)->data_type());
outputs_.at(i)->SetFormat(inputs_.at(i)->GetFormat());
outputs_.at(i)->set_format(inputs_.at(i)->format());
}
if (outputs_.size() > 5) {
outputs_.at(5)->set_data_type(inputs_.at(0)->data_type());
outputs_.at(5)->SetFormat(inputs_.at(0)->GetFormat());
outputs_.at(5)->set_format(inputs_.at(0)->format());
outputs_.at(5)->set_shape({1});
}
return 0;


+ 2
- 2
mindspore/lite/src/ops/gather.cc View File

@@ -112,8 +112,8 @@ int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/gather_nd.cc View File

@@ -68,8 +68,8 @@ int GatherNd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
MS_ASSERT(output != nullptr);

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto in_shape = input->shape();


+ 1
- 1
mindspore/lite/src/ops/greater.cc View File

@@ -43,7 +43,7 @@ int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/greater_equal.cc View File

@@ -45,7 +45,7 @@ int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/group_conv2d_grad_input.cc View File

@@ -169,7 +169,7 @@ int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<T
out->set_shape(GetInputShape());

out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
out->set_format(in0->format());

return RET_OK;
}


+ 2
- 2
mindspore/lite/src/ops/hashtable_lookup.cc View File

@@ -54,10 +54,10 @@ int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
hits_shape.push_back(input->DimensionSize(0));

output->set_data_type(values->data_type());
output->SetFormat(input->GetFormat());
output->set_format(input->format());
hits->set_shape(hits_shape);
hits->set_data_type(kNumberTypeUInt8);
hits->SetFormat(input->GetFormat());
hits->set_format(input->format());

if (input->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";


+ 2
- 2
mindspore/lite/src/ops/layer_norm.cc View File

@@ -78,7 +78,7 @@ int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
MS_ASSERT(input != nullptr);
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());

if (GetElementwiseAffine() && inputs_.size() != kMultiNum) {
@@ -102,7 +102,7 @@ int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
return RET_PARAM_INVALID;
}
}
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/less.cc View File

@@ -45,7 +45,7 @@ int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/less_equal.cc View File

@@ -44,7 +44,7 @@ int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/lsh_projection.cc View File

@@ -70,7 +70,7 @@ int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor

auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeInt32);
out_tensor->SetFormat(schema::Format::Format_NHWC);
out_tensor->set_format(schema::Format::Format_NHWC);

std::vector<int> out_shape;
switch (GetLshType()) {


+ 2
- 2
mindspore/lite/src/ops/lstm.cc View File

@@ -65,9 +65,9 @@ int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
MS_ASSERT(output != nullptr);
for (int i = 0; i < kLstmOutputNum; i++) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
outputs_[i]->set_format(input->format());
}
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/matmul.cc View File

@@ -99,8 +99,8 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
MS_ASSERT(output != nullptr);

output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
output->set_format(input0->format());
if (!infer_flag()) {
return RET_OK;
}



+ 1
- 1
mindspore/lite/src/ops/maximum_grad.cc View File

@@ -92,7 +92,7 @@ int MaximumGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
MS_ASSERT(x2 != nullptr);
MS_ASSERT(dx1 != nullptr);
MS_ASSERT(dx2 != nullptr);
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/mean.cc View File

@@ -75,8 +75,8 @@ int Mean::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
if (this->primitive_ == nullptr) {


+ 2
- 2
mindspore/lite/src/ops/mfcc.cc View File

@@ -57,8 +57,8 @@ int Mfcc::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 2
- 2
mindspore/lite/src/ops/nchw2nhwc.cc View File

@@ -45,9 +45,9 @@ int Nchw2Nhwc::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(schema::Format::Format_NHWC);
output->set_format(schema::Format::Format_NHWC);
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
std::vector<int> nchw_shape = input->shape();


+ 2
- 2
mindspore/lite/src/ops/nhwc2nchw.cc View File

@@ -46,9 +46,9 @@ int Nhwc2Nchw::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(schema::Format::Format_NCHW);
output->set_format(schema::Format::Format_NCHW);
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
std::vector<int> nhwc_shape = input->shape();


+ 1
- 1
mindspore/lite/src/ops/non_max_suppression.cc View File

@@ -62,7 +62,7 @@ int NonMaxSuppression::InferShape(std::vector<Tensor *> inputs_, std::vector<Ten
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(kNumberTypeInt32);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
return RET_INFER_INVALID;
}


+ 1
- 1
mindspore/lite/src/ops/not_equal.cc View File

@@ -45,7 +45,7 @@ int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/one_hot.cc View File

@@ -112,8 +112,8 @@ int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outpu
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
if (!GetInferFlag()) {
output->set_format(on_value->format());
if (!infer_flag()) {
return RET_OK;
}
const auto input_shape = input->shape();


+ 1
- 1
mindspore/lite/src/ops/oneslike.cc View File

@@ -78,7 +78,7 @@ int OnesLike::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
std::vector<int> output_shape(x_shape.size());
output_shape.assign(x_shape.begin(), x_shape.end());
out->set_shape(output_shape);
out->SetFormat(x->GetFormat());
out->set_format(x->format());
out->set_data_type(x->data_type());
return RET_OK;
}


+ 2
- 2
mindspore/lite/src/ops/pad.cc View File

@@ -80,9 +80,9 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/pooling.cc View File

@@ -179,8 +179,8 @@ int Pooling::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(schema::Format::Format_NHWC);
if (!GetInferFlag()) {
output->set_format(schema::Format::Format_NHWC);
if (!infer_flag()) {
return RET_OK;
}
int input_h = input->shape().at(1);


+ 1
- 1
mindspore/lite/src/ops/pooling_grad.cc View File

@@ -203,7 +203,7 @@ int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
grad_output->set_shape(output_shape);
grad_output->set_data_type(input->data_type());
// todo: temp fix
grad_output->SetFormat(input->GetFormat());
grad_output->set_format(input->format());
return RET_OK;
}
} // namespace lite


+ 2
- 2
mindspore/lite/src/ops/power.cc View File

@@ -113,8 +113,8 @@ int Power::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> output
auto output_tensor = outputs[0];
MS_ASSERT(output_tensor != nullptr);
output_tensor->set_data_type(x_tensor->data_type());
output_tensor->SetFormat(x_tensor->GetFormat());
if (!GetInferFlag()) {
output_tensor->set_format(x_tensor->format());
if (!infer_flag()) {
return RET_OK;
}
if (exp_tensor != nullptr) {


+ 19
- 18
mindspore/lite/src/ops/primitive_c.cc View File

@@ -342,24 +342,25 @@ void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<in
}
}

schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; }
schema::PrimitiveT *PrimitiveC::primitiveT() const { return this->primitive_; }

void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; }

void PrimitiveC::SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
void PrimitiveC::set_input_quant_params(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
this->input_quant_param_ = input_quant_param;
}

void PrimitiveC::SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
MS_ASSERT(index < this->input_quant_param_.size());
this->input_quant_param_[index] = input_quant_param;
}

void PrimitiveC::SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
void PrimitiveC::set_output_quant_params(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
this->output_quant_param_ = output_quant_param;
}

void PrimitiveC::SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
void PrimitiveC::set_output_quant_param(const size_t &index,
const std::vector<schema::QuantParamT> &output_quant_param) {
MS_ASSERT(index < this->output_quant_param_.size());
this->output_quant_param_[index] = output_quant_param;
}
@@ -396,16 +397,16 @@ void PrimitiveC::ClearInputOutputQuantParam() {
void PrimitiveC::AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->input_quant_param_.emplace_back(quant_param);
}
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::GetInputQuantParams() const { return input_quant_param_; }
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::input_quant_params() const { return input_quant_param_; }

void PrimitiveC::AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->output_quant_param_.emplace_back(quant_param);
}
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::GetOutputQuantParams() const { return output_quant_param_; }
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::output_quant_params() const { return output_quant_param_; }

void PrimitiveC::SetQuantType(const schema::QuantType &quant_type) { this->quant_type_ = quant_type; }
void PrimitiveC::set_quant_type(const schema::QuantType &quant_type) { this->quant_type_ = quant_type; }

schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
schema::QuantType PrimitiveC::quant_type() const { return quant_type_; }

std::shared_ptr<PrimitiveC> GetReturnPrim() {
auto return_primitiveT = new (std::nothrow) schema::PrimitiveT;
@@ -463,7 +464,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect
MS_LOG(ERROR) << "make_shared PrimitiveC failed";
return nullptr;
}
primc->SetQuantType(quantType);
primc->set_quant_type(quantType);
auto ret = primc->UnPackAttr(prim, inputs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "UnPackAttr failed";
@@ -956,8 +957,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
}

#else
void PrimitiveC::SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; }
schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
void PrimitiveC::set_quant_type(schema::QuantType quant_type) { this->quant_type_ = quant_type; }
schema::QuantType PrimitiveC::quant_type() const { return quant_type_; }
#endif

int PrimitiveC::Type() const {
@@ -970,18 +971,18 @@ int PrimitiveC::Type() const {
return this->primitive_->value_type();
#endif
}
bool PrimitiveC::GetInferFlag() const { return this->infer_flag_; }
bool PrimitiveC::infer_flag() const { return this->infer_flag_; }

void PrimitiveC::SetInferFlag(bool flag) { this->infer_flag_ = flag; }
void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; }

int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
auto input = inputs_.front();
int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
auto input = inputs.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
auto output = outputs.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return 0;
}



+ 19
- 19
mindspore/lite/src/ops/primitive_c.h View File

@@ -44,9 +44,9 @@ const std::set<int> kSupportDataType = {kNumberTypeBool, kNumberTypeUInt8, kN

#ifdef PRIMITIVE_WRITEABLE
using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>;
constexpr int kAnfPopulaterOne = 1;
constexpr int kAnfPopulaterTwo = 2;
constexpr int kAnfPopulaterThree = 3;
constexpr int kAnfPopulaterInputNumOne = 1;
constexpr int kAnfPopulaterInputNumTwo = 2;
constexpr int kAnfPopulaterInputNumThree = 3;
static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
{"ReLU6", schema::ActivationType_RELU6},
{"Sigmoid", schema::ActivationType_SIGMOID},
@@ -75,7 +75,7 @@ class PrimitiveC : public mindspore::Primitive {

int Type() const;

schema::PrimitiveT *GetPrimitiveT() const;
schema::PrimitiveT *primitiveT() const;

void ClearPrimitiveT();

@@ -90,13 +90,13 @@ class PrimitiveC : public mindspore::Primitive {
}
}

void SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param);
void set_input_quant_params(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param);

void SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param);
void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param);

void SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param);
void set_output_quant_params(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param);

void SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param);
void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param);

bool IsInputQuantParamsInited();

@@ -106,21 +106,21 @@ class PrimitiveC : public mindspore::Primitive {

void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param);

std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const;
std::vector<std::vector<schema::QuantParamT>> input_quant_params() const;

void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param);

std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const;
std::vector<std::vector<schema::QuantParamT>> output_quant_params() const;

void SetQuantType(const schema::QuantType &quant_type);
void set_quant_type(const schema::QuantType &quant_type);

schema::QuantType GetQuantType() const;
schema::QuantType quant_type() const;

virtual int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_);
virtual int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);

bool GetInferFlag() const;
bool infer_flag() const;

void SetInferFlag(bool flag);
void set_infer_flag(bool flag);

static PrimitiveC *Create(mindspore::schema::Primitive *primitive) { return Create(primitive->UnPack()); }

@@ -162,16 +162,16 @@ class PrimitiveC {

static PrimitiveC *Create(const schema::Primitive *primitive);

bool GetInferFlag() const;
bool infer_flag() const;

void SetInferFlag(bool flag);
void set_infer_flag(bool flag);

virtual int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);

int Type() const;

void SetQuantType(schema::QuantType quant_type);
schema::QuantType GetQuantType() const;
void set_quant_type(schema::QuantType quant_type);
schema::QuantType quant_type() const;

template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) {


+ 2
- 2
mindspore/lite/src/ops/prior_box.cc View File

@@ -138,8 +138,8 @@ int PriorBox::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
auto output = outputs_.at(0);
MS_ASSERT(output != nullptr);
output->set_data_type(kNumberTypeFloat32);
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
std::vector<float> different_aspect_ratios{1.0f};


+ 2
- 2
mindspore/lite/src/ops/quant_dtype_cast.cc View File

@@ -61,8 +61,8 @@ int QuantDTypeCast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor
MS_ASSERT(output != nullptr);
MS_ASSERT(input->data_type() == this->GetSrcT());
output->set_data_type(static_cast<TypeId>(GetDstT()));
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
output->set_shape(input->shape());


+ 2
- 2
mindspore/lite/src/ops/range.cc View File

@@ -65,8 +65,8 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_ASSERT(output != nullptr);

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/rank.cc View File

@@ -43,8 +43,8 @@ int Rank::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
std::vector<int> in_shape(1, 1);


+ 4
- 4
mindspore/lite/src/ops/reduce.cc View File

@@ -70,8 +70,8 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}

attr->keepDims = GetValue<bool>(prim.GetAttr("keep_dims"));
if (inputs.size() == kAnfPopulaterTwo) {
auto inputNode = inputs[kAnfPopulaterOne];
if (inputs.size() == kAnfPopulaterInputNumTwo) {
auto inputNode = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
@@ -150,8 +150,8 @@ int Reduce::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
if (this->primitive_ == nullptr) {


+ 4
- 4
mindspore/lite/src/ops/reshape.cc View File

@@ -47,8 +47,8 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ReshapeT();
MS_ASSERT(inputs.size() == kAnfPopulaterThree - 1);
auto inputNode = inputs[kAnfPopulaterTwo - 1];
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumThree - 1);
auto inputNode = inputs[kAnfPopulaterInputNumTwo - 1];
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
@@ -171,8 +171,8 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 4
- 4
mindspore/lite/src/ops/resize.cc View File

@@ -129,8 +129,8 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}

@@ -151,7 +151,7 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
MS_LOG(INFO) << "Resize op size can't cast int.";
return RET_INFER_INVALID;
}
switch (shape_tensor->GetFormat()) {
switch (shape_tensor->format()) {
case schema::Format_NCHW:
output_shape.push_back(data[2] * input->Height());
output_shape.push_back(data[3] * input->Width());
@@ -170,7 +170,7 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
MS_LOG(INFO) << "Resize op size can't cast float.";
return RET_INFER_INVALID;
}
switch (shape_tensor->GetFormat()) {
switch (shape_tensor->format()) {
case schema::Format_NCHW:
output_shape.push_back(data[2] * input->Height());
output_shape.push_back(data[3] * input->Width());


+ 3
- 3
mindspore/lite/src/ops/return.cc View File

@@ -70,8 +70,8 @@ int Return::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
if (this->primitive_ == nullptr) {
@@ -79,7 +79,7 @@ int Return::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
}
output->set_data_type(input->data_type());
output->set_shape(input->shape());
output->SetFormat(input->GetFormat());
output->set_format(input->format());
return RET_OK;
}
} // namespace lite


+ 2
- 2
mindspore/lite/src/ops/reverse_sequence.cc View File

@@ -64,8 +64,8 @@ int ReverseSequence::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor
MS_ASSERT(output != nullptr);

output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
output->set_shape(input->shape());


+ 2
- 2
mindspore/lite/src/ops/rfft.cc View File

@@ -52,8 +52,8 @@ int Rfft::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(TypeId::kNumberTypeComplex64);
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 2
- 2
mindspore/lite/src/ops/roi_pooling.cc View File

@@ -77,8 +77,8 @@ int ROIPooling::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}



+ 2
- 2
mindspore/lite/src/ops/scatter_nd.cc View File

@@ -56,8 +56,8 @@ int ScatterND::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
}
auto output = outputs_.front();
output->set_data_type(update->data_type());
output->SetFormat(update->GetFormat());
if (!GetInferFlag()) {
output->set_format(update->format());
if (!infer_flag()) {
return RET_OK;
}
auto shape_data = reinterpret_cast<int *>(shape->MutableData());


+ 1
- 1
mindspore/lite/src/ops/sgd.cc View File

@@ -95,7 +95,7 @@ int Sgd::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_format(inputs[0]->format());
out->set_shape({1});
}



+ 2
- 2
mindspore/lite/src/ops/shape.cc View File

@@ -42,8 +42,8 @@ int Shape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeInt32);
out_tensor->SetFormat(schema::Format::Format_NHWC);
if (!GetInferFlag()) {
out_tensor->set_format(schema::Format::Format_NHWC);
if (!infer_flag()) {
return RET_OK;
}
std::vector<int> out_shape;


+ 1
- 1
mindspore/lite/src/ops/skip_gram.cc View File

@@ -73,7 +73,7 @@ int SkipGram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
auto input = inputs_.front();
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
output->SetFormat(input->GetFormat());
output->set_format(input->format());
output->set_data_type(input->data_type());

if (input->data_c() == nullptr) {


+ 5
- 5
mindspore/lite/src/ops/slice.cc View File

@@ -59,8 +59,8 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (inputs.size() >= kAnfPopulaterThree) {
auto beginNode = inputs[kAnfPopulaterOne];
if (inputs.size() >= kAnfPopulaterInputNumThree) {
auto beginNode = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(beginNode != nullptr);
if (beginNode->isa<ValueNode>()) {
auto valueNode = beginNode->cast<ValueNodePtr>();
@@ -77,7 +77,7 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
}
}
}
auto sizeNode = inputs[kAnfPopulaterTwo];
auto sizeNode = inputs[kAnfPopulaterInputNumTwo];
MS_ASSERT(sizeNode != nullptr);
if (sizeNode->isa<ValueNode>()) {
auto valueNode = sizeNode->cast<ValueNodePtr>();
@@ -173,8 +173,8 @@ int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tens
}
auto input = inputs.at(0);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
outputs[0]->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 2
- 2
mindspore/lite/src/ops/softmax.cc View File

@@ -84,8 +84,8 @@ int SoftMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
output->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
if (input->shape().size() > 5) {


+ 2
- 2
mindspore/lite/src/ops/softmax_cross_entropy.cc View File

@@ -103,14 +103,14 @@ int SoftmaxCrossEntropy::InferShape(std::vector<Tensor *> inputs, std::vector<Te
outshape.push_back(1);
out->set_shape(outshape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
out->set_format(in0->format());

if (1 < outputs.size()) {
auto *grads = outputs.at(1);
MS_ASSERT(grads != nullptr);
grads->set_shape(in0->shape());
grads->set_data_type(in0->data_type());
grads->SetFormat(in0->GetFormat());
grads->set_format(in0->format());
}
return RET_OK;
}


+ 3
- 3
mindspore/lite/src/ops/space_to_batch.cc View File

@@ -90,13 +90,13 @@ int SpaceToBatch::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
}

auto input = inputs.at(0);
if (input->GetFormat() != schema::Format::Format_NHWC) {
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
return 1;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
outputs[0]->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 3
- 3
mindspore/lite/src/ops/space_to_batch_nd.cc View File

@@ -92,13 +92,13 @@ int SpaceToBatchND::InferShape(std::vector<lite::Tensor *> inputs, std::vector<l
}

auto input = inputs.at(0);
if (input->GetFormat() != schema::Format::Format_NHWC) {
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!";
return RET_ERROR;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
outputs[0]->set_format(input->format());
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


+ 3
- 3
mindspore/lite/src/ops/space_to_depth.cc View File

@@ -67,13 +67,13 @@ int SpaceToDepth::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
}

auto input = inputs.at(0);
if (input->GetFormat() != schema::Format::Format_NHWC) {
if (input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "space_to_depth only support NHWC now!";
return 1;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_format(input->format());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
if (!infer_flag()) {
return RET_OK;
}
auto input_shape = input->shape();


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save