
!7432 Adam + Sparse softmax and bug fix

Merge pull request !7432 from yonibaehr/export
tags/v1.1.0
mindspore-ci-bot, 5 years ago
parent commit 4843f6aba0
71 changed files with 1558 additions and 161 deletions
  1. +8 -0  mindspore/lite/nnacl/fp32/activation.c
  2. +4 -3  mindspore/lite/nnacl/fp32/activation.h
  3. +4 -4  mindspore/lite/nnacl/fp32/one_hot.c
  4. +126 -36  mindspore/lite/nnacl/fp32_grad/gemm.c
  5. +5 -1  mindspore/lite/nnacl/fp32_grad/optimizer.h
  6. +3 -1  mindspore/lite/schema/model.fbs
  7. +31 -2  mindspore/lite/schema/ops.fbs
  8. +4 -0  mindspore/lite/src/ops/activation.cc
  9. +5 -1  mindspore/lite/src/ops/activation_grad.cc
  10. +91 -0  mindspore/lite/src/ops/adam.cc
  11. +47 -0  mindspore/lite/src/ops/adam.h
  12. +20 -2  mindspore/lite/src/ops/addn.cc
  13. +1 -4  mindspore/lite/src/ops/apply_momentum.cc
  14. +0 -1  mindspore/lite/src/ops/apply_momentum.h
  15. +82 -0  mindspore/lite/src/ops/assign.cc
  16. +43 -0  mindspore/lite/src/ops/assign.h
  17. +2 -2  mindspore/lite/src/ops/bn_grad.cc
  18. +4 -1  mindspore/lite/src/ops/conv2d_grad_input.cc
  19. +24 -0  mindspore/lite/src/ops/exp.cc
  20. +1 -0  mindspore/lite/src/ops/exp.h
  21. +172 -0  mindspore/lite/src/ops/group_conv2d_grad_input.cc
  22. +79 -0  mindspore/lite/src/ops/group_conv2d_grad_input.h
  23. +25 -1  mindspore/lite/src/ops/neg.cc
  24. +1 -0  mindspore/lite/src/ops/neg.h
  25. +31 -0  mindspore/lite/src/ops/one_hot.cc
  26. +1 -1  mindspore/lite/src/ops/one_hot.h
  27. +36 -3  mindspore/lite/src/ops/primitive_c.cc
  28. +6 -4  mindspore/lite/src/ops/primitive_c.h
  29. +8 -1  mindspore/lite/src/ops/real_div.cc
  30. +1 -4  mindspore/lite/src/ops/real_div.h
  31. +63 -1  mindspore/lite/src/ops/slice.cc
  32. +4 -3  mindspore/lite/src/ops/slice.h
  33. +29 -0  mindspore/lite/src/ops/squeeze.cc
  34. +4 -3  mindspore/lite/src/ops/squeeze.h
  35. +34 -9  mindspore/lite/src/ops/tile.cc
  36. +2 -0  mindspore/lite/src/populate_parameter.cc
  37. +3 -0  mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc
  38. +23 -3  mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc
  39. +2 -0  mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
  40. +2 -0  mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
  41. +3 -2  mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc
  42. +1 -2  mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h
  43. +118 -0  mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc
  44. +44 -0  mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h
  45. +1 -8  mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc
  46. +91 -0  mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc
  47. +39 -0  mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h
  48. +5 -1  mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc
  49. +2 -0  mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc
  50. +0 -1  mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
  51. +3 -0  mindspore/lite/src/train/train_model.cc
  52. +64 -1  mindspore/lite/src/train/train_populate_parameter.cc
  53. +34 -10  mindspore/lite/src/train/train_session.cc
  54. +5 -1  mindspore/lite/src/train/train_session.h
  55. +6 -6  mindspore/lite/test/common/common_test.h
  56. +26 -14  mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
  57. +4 -3  mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc
  58. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms
  59. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effnetb0_fwd_fuse.ms
  60. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface1.ms
  61. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1.hwc_normalized_f32
  62. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_conf.f32
  63. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_landms.f32
  64. BIN  mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_loc.f32
  65. +20 -1  mindspore/lite/tools/anf_exporter/anf_exporter.cc
  66. +33 -4  mindspore/lite/tools/anf_importer/import_from_protobuf.cc
  67. +2 -0  mindspore/lite/tools/common/node_util.cc
  68. +1 -0  mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
  69. +3 -1  mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
  70. +10 -0  mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
  71. +12 -15  mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc

+8 -0  mindspore/lite/nnacl/fp32/activation.c

@@ -117,6 +117,14 @@ int HSwish(const float *src, int length, float *dst) {
return NNACL_OK;
}

int HSigmoid(const float *src, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6);
dst[i] = relu6 / 6;
}
return NNACL_OK;
}

int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) {
if (max_val <= min_val) {
return NNACL_ERR;
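
For reference, the new HSigmoid kernel computes the standard hard-sigmoid approximation relu6(x + 3) / 6. A minimal standalone sketch of the same per-element math (hypothetical driver, not part of this patch):

#include <stdio.h>

/* hard-sigmoid: clamp(x + 3, 0, 6) / 6, mirroring HSigmoid() above */
static float hsigmoid(float x) {
  float relu6 = x + 3.0f;
  if (relu6 < 0.0f) relu6 = 0.0f;
  if (relu6 > 6.0f) relu6 = 6.0f;
  return relu6 / 6.0f;
}

int main(void) {
  const float src[4] = {-4.0f, -1.0f, 0.0f, 4.0f};
  for (int i = 0; i < 4; ++i) {
    printf("hsigmoid(%g) = %g\n", src[i], hsigmoid(src[i])); /* 0, 0.3333, 0.5, 1 */
  }
  return 0;
}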


+4 -3  mindspore/lite/nnacl/fp32/activation.h

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_ACTIVATION_H_
#define MINDSPORE_LITE_NNACL_ACTIVATION_H_
#ifndef MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_
#define MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_

#include <math.h>
#include "nnacl/op_base.h"
@@ -36,9 +36,10 @@ int Fp32Relu6(const float *src, int length, float *dst);
int LRelu(const float *src, int length, float *dst, float alpha);
int Sigmoid(const float *src, int length, float *dst);
int Tanh(const float *src, int length, float *dst);
int HSigmoid(const float *src, int length, float *dst);
int HSwish(const float *src, int length, float *dst);
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_ACTIVATION_H_
#endif // MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_

+4 -4  mindspore/lite/nnacl/fp32/one_hot.c

@@ -31,14 +31,14 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par
int i, j, k;
for (i = tid; i < outer_size; i += thread_num) {
float *output_ptr = output + i * depth * inner_size;
for (k = 0; k < depth; k++) {
for (j = 0; j < inner_size; j++) {
for (k = 0; k < inner_size; k++) {
int index = indices[i * inner_size + k];
for (j = 0; j < depth; j++) {
*output_ptr = off_value;
int index = indices[i * inner_size + j];
if (index >= depth) {
return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
}
if (index == k) {
if (index == j) {
*output_ptr = on_value;
}
output_ptr++;
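
The fix above reorders the loops: the index is now fetched once per inner position and the depth axis is walked innermost, so each index expands into a contiguous run of depth outputs. A self-contained illustration of the corrected expansion (hypothetical values, not part of this patch):

#include <stdio.h>

int main(void) {
  /* inner_size = 3 positions, depth = 4 classes, mirroring the fixed loop order */
  const int indices[3] = {2, 0, 3};
  const int depth = 4;
  const float on_value = 1.0f, off_value = 0.0f;
  float output[3 * 4];
  float *output_ptr = output;
  for (int k = 0; k < 3; ++k) {       /* inner positions (outer loop) */
    int index = indices[k];
    for (int j = 0; j < depth; ++j) { /* depth axis (inner loop) */
      *output_ptr++ = (index == j) ? on_value : off_value;
    }
  }
  for (int k = 0; k < 3; ++k) {
    for (int j = 0; j < depth; ++j) printf("%.0f ", output[k * depth + j]);
    printf("\n"); /* prints: 0 0 1 0 / 1 0 0 0 / 0 0 0 1 */
  }
  return 0;
}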


+126 -36  mindspore/lite/nnacl/fp32_grad/gemm.c

@@ -15,27 +15,52 @@
*/

#include "nnacl/fp32_grad/gemm.h"
#include <string.h>

static void gemm_not_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
const int block_size = 4;
int block_mod = N % block_size;
int block_c4 = N - block_mod;

static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c,
int ldc) {
int i, j, k;
for (i = 0; i < M; ++i) {
for (k = 0; k < K; ++k) {
float a = alpha * mat_a[i * lda + k];
for (j = 0; j < N; ++j) {
mat_c[i * ldc + j] += a * mat_B[k * ldb + j];
for (j = 0; j < block_c4; j += block_size) {
float *b = &mat_b[k * ldb + j];
float *c = &mat_c[i * ldc + j];
c[0] += a * b[0];
c[1] += a * b[1];
c[2] += a * b[2];
c[3] += a * b[3];
}
for (; j < N; ++j) {
mat_c[i * ldc + j] += a * mat_b[k * ldb + j];
}
}
}
}

static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c,
int ldc) {
static void gemm_not_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
const int block_size = 4;
int block_mod = K % block_size;
int block_c4 = K - block_mod;

int i, j, k;
for (i = 0; i < M; ++i) {
for (j = 0; j < N; ++j) {
float sum = 0;
for (k = 0; k < K; ++k) {
for (k = 0; k < block_c4; k += block_size) {
float *a = &mat_a[i * lda + k];
float *b = &mat_b[j * ldb + k];
sum += alpha * a[0] * b[0];
sum += alpha * a[1] * b[1];
sum += alpha * a[2] * b[2];
sum += alpha * a[3] * b[3];
}
for (; k < K; ++k) {
sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k];
}
mat_c[i * ldc + j] += sum;
@@ -43,23 +68,85 @@ static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, flo
}
}

static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c,
int ldc) {
static void gemm_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
const int block_size = 4;
int block_mod = N % block_size;
int block_c4 = N - block_mod;

int i, j, k;
for (i = 0; i < M; ++i) {
for (k = 0; k < K; ++k) {
float a = alpha * mat_a[k * lda + i];
for (j = 0; j < N; ++j) {
for (j = 0; j < block_c4; j += block_size) {
float *b = &mat_b[k * ldb + j];
float *c = &mat_c[i * ldc + j];
c[0] += a * b[0];
c[1] += a * b[1];
c[2] += a * b[2];
c[3] += a * b[3];
}
for (; j < N; ++j) {
mat_c[i * ldc + j] += a * mat_b[k * ldb + j];
}
}
}
}

static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c,
int ldc) {
static void gemm_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
int i, j, k;
for (i = 0; i < M; ++i) {
const int block_size = 4;
int k_block_mod = K % block_size;
int k_block_c4 = K - k_block_mod;

int m_block_mod = M % block_size;
int m_block_c4 = M - m_block_mod;

for (i = 0; i < m_block_c4; i += block_size) {
for (j = 0; j < N; ++j) {
float sum0 = 0;
float sum1 = 0;
float sum2 = 0;
float sum3 = 0;

for (k = 0; k < k_block_c4; k += block_size) {
float *b = &mat_b[j * ldb + k];
sum0 += alpha * mat_a[i + k * lda] * b[0];
sum0 += alpha * mat_a[i + (k + 1) * lda] * b[1];
sum0 += alpha * mat_a[i + (k + 2) * lda] * b[2];
sum0 += alpha * mat_a[i + (k + 3) * lda] * b[3];

sum1 += alpha * mat_a[i + 1 + k * lda] * b[0];
sum1 += alpha * mat_a[i + 1 + (k + 1) * lda] * b[1];
sum1 += alpha * mat_a[i + 1 + (k + 2) * lda] * b[2];
sum1 += alpha * mat_a[i + 1 + (k + 3) * lda] * b[3];

sum2 += alpha * mat_a[i + 2 + k * lda] * b[0];
sum2 += alpha * mat_a[i + 2 + (k + 1) * lda] * b[1];
sum2 += alpha * mat_a[i + 2 + (k + 2) * lda] * b[2];
sum2 += alpha * mat_a[i + 2 + (k + 3) * lda] * b[3];

sum3 += alpha * mat_a[i + 3 + k * lda] * b[0];
sum3 += alpha * mat_a[i + 3 + (k + 1) * lda] * b[1];
sum3 += alpha * mat_a[i + 3 + (k + 2) * lda] * b[2];
sum3 += alpha * mat_a[i + 3 + (k + 3) * lda] * b[3];
}
for (; k < K; ++k) {
float *b = &mat_b[j * ldb + k];
sum0 += alpha * mat_a[i + (k * lda)] * b[0];
sum1 += alpha * mat_a[i + 1 + (k * lda)] * b[0];
sum2 += alpha * mat_a[i + 2 + (k * lda)] * b[0];
sum3 += alpha * mat_a[i + 3 + (k * lda)] * b[0];
}
mat_c[i * ldc + j] += sum0;
mat_c[(i + 1) * ldc + j] += sum1;
mat_c[(i + 2) * ldc + j] += sum2;
mat_c[(i + 3) * ldc + j] += sum3;
}
}
// remainder: rows that do not fill a full block of 4
for (; i < M; ++i) {
for (j = 0; j < N; ++j) {
float sum = 0;
for (k = 0; k < K; ++k) {
@@ -74,34 +161,37 @@ static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, flo
// M - number of rows of matrix a
// N - number of cols of matrix b
// K - number of cols of matrix a

// lda - leading dimension (row stride) of matrix a
// ldb - leading dimension (row stride) of matrix b
// ldc - leading dimension (row stride) of matrix c
void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b,
int ldb, float beta, float *mat_c, int ldc) {
if (beta >= 0.f && beta <= 0.f) {
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
mat_c[i * ldc + j] = 0;
}
}
memset(mat_c, 0, M * N * sizeof(float));
} else if (beta < 1.f || beta > 1.f) {
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
mat_c[i * ldc + j] *= beta;
}
const int block_size = 4;
const int size = M * N;
int block_mod = size % block_size;
int block_c4 = size - block_mod;
int i;
for (i = 0; i < block_c4; i += block_size) {
float *c = &mat_c[i];
c[0] *= beta;
c[1] *= beta;
c[2] *= beta;
c[3] *= beta;
}
}

int t;

for (t = 0; t < M; ++t) {
if (!transpose_a && !transpose_b) {
gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc);
} else if (transpose_a && !transpose_b) {
gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc);
} else if (!transpose_a && transpose_b) {
gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc);
} else {
gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc);
for (; i < size; ++i) {
mat_c[i] *= beta;
}
}
if (transpose_a && transpose_b) {
gemm_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
} else if (!transpose_a && !transpose_b) {
gemm_not_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
} else if (!transpose_a && transpose_b) {
gemm_not_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
} else {
gemm_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
}
}
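
The rewrite dispatches once over the whole matrix instead of once per output row, handles the beta scaling of mat_c up front, and unrolls the hot loops in blocks of four. A minimal caller sketch (assuming the signature above, row-major data, and ld* equal to the row width):

#include <stdio.h>
#include "nnacl/fp32_grad/gemm.h"

int main(void) {
  float a[4] = {1, 2, 3, 4}; /* 2x2, row-major */
  float b[4] = {5, 6, 7, 8}; /* 2x2, row-major */
  float c[4] = {0, 0, 0, 0}; /* 2x2 result */
  /* c = 1.0 * a * b + 0.0 * c */
  gemm(0, 0, 2, 2, 2, 1.0f, a, 2, b, 2, 0.0f, c, 2);
  printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); /* 19 22 43 50 */
  return 0;
}

Note that the beta == 0 fast path now clears mat_c with a single memset over M * N floats, which implicitly assumes ldc == N.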

+5 -1  mindspore/lite/nnacl/fp32_grad/optimizer.h

@@ -21,7 +21,6 @@

typedef struct ApplyMomentumParameter {
OpParameter op_parameter_;
bool use_locking_;
bool use_nesterov_;
float grad_scale_;
} ApplyMomentumParameter;
@@ -33,4 +32,9 @@ typedef struct SgdParameter {
float weight_decay_;
} SgdParameter;

typedef struct AdamParameter {
OpParameter op_parameter_;
bool use_nesterov_;
} AdamParameter;

#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_
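
The new AdamParameter carries only the use_nesterov_ flag; the optimizer state and hyper-parameters reach the kernel as input tensors. For orientation, a per-element sketch of the standard Adam step such a kernel performs (illustrative names; lr is assumed already bias-corrected with the beta1_power/beta2_power inputs):

#include <cmath>

void AdamStep(float *var, float *m, float *v, const float *grad, int n,
              float lr, float beta1, float beta2, float eps, bool use_nesterov) {
  for (int i = 0; i < n; ++i) {
    m[i] += (1.0f - beta1) * (grad[i] - m[i]);            // 1st moment
    v[i] += (1.0f - beta2) * (grad[i] * grad[i] - v[i]);  // 2nd moment
    float update = use_nesterov ? beta1 * m[i] + (1.0f - beta1) * grad[i] : m[i];
    var[i] -= lr * update / (std::sqrt(v[i]) + eps);
  }
}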

+3 -1  mindspore/lite/schema/model.fbs

@@ -182,7 +182,7 @@ union PrimitiveType {
Conv2DGradInput,
PoolingGrad,
BNGrad,
BNGradInput,
Assign,
ApplyMomentum,
BiasGrad,
SoftmaxCrossEntropy,
@@ -217,6 +217,8 @@ union PrimitiveType {
FftReal,
FftImag,
Sgd,
Adam,
GroupConv2DGradInput,
}

enum QuantType: int {


+31 -2  mindspore/lite/schema/ops.fbs

@@ -224,7 +224,29 @@ table Conv2DGradInput {
dilateH: int;
hasBias: bool = false;
activationType: ActivationType = 0;
}table FusedBatchNorm {
}

table GroupConv2DGradInput {
format: Format = 0;
group: int;
channelIn: int;
channelOut: int;
kernelW: int;
kernelH: int;
strideW: int;
strideH: int;
padMode: PadMode;
padUp: int;
padDown: int;
padLeft: int;
padRight: int;
dilateW: int;
dilateH: int;
hasBias: bool = false;
activationType: ActivationType = 0;
}

table FusedBatchNorm {
epsilon: float = 0.00001; // eg. epsilon=0.001
momentum: float = 0.9;
spatial: int = 1;
@@ -901,7 +923,6 @@ table TupleGetItem {

table ApplyMomentum {
gradientScale: float;
useLocking: bool;
useNesterov: bool;
}

@@ -911,6 +932,14 @@ table Sgd {
useNesterov: bool;
}

table Adam {
useNesterov: bool;
}

table Assign {
}


table Where{
condition: [bool];
}


+4 -0  mindspore/lite/src/ops/activation.cc

@@ -50,6 +50,10 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
} else if (prim.name() == "HSwish") {
attr->type = schema::ActivationType_HSWISH;
} else if (prim.name() == "HSigmoid") {
attr->type = schema::ActivationType_HSIGMOID;
}
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {


+5 -1  mindspore/lite/src/ops/activation_grad.cc

@@ -43,8 +43,12 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP
attr->type = schema::ActivationType_RELU;
} else if (prim.name() == "SigmoidGrad") {
attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "Relu6Grad") {
} else if (prim.name() == "ReLU6Grad") {
attr->type = schema::ActivationType_RELU6;
} else if (prim.name() == "HSigmoidGrad") {
attr->type = schema::ActivationType_HSIGMOID;
} else if (prim.name() == "HSwishGrad") {
attr->type = schema::ActivationType_HSWISH;
}
attr->alpha = 0; // alpha;
this->primitive_->value.value = attr.release();


+91 -0  mindspore/lite/src/ops/adam.cc

@@ -0,0 +1,91 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/adam.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Adam::GetUseNesterov() const { return this->primitive_->value.AsAdam()->useNesterov; }
int Adam::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Adam;
}
if (this->primitive_->value.type != schema::PrimitiveType_Adam) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = std::make_unique<schema::AdamT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));

this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
bool Adam::GetUseNesterov() const { return this->primitive_->value_as_Adam()->useNesterov(); }
int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Adam();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Adam return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAdam(*fbb, attr->useNesterov());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adam, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif

int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (10 != inputs.size()) {
MS_LOG(ERROR) << "Adam should have at least 8 input tensors";
return RET_ERROR;
}

if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[2]->ElementsNum() ||
inputs[0]->ElementsNum() != inputs[9]->ElementsNum() || inputs[3]->ElementsNum() != 1 ||
inputs[4]->ElementsNum() != 1 || inputs[5]->ElementsNum() != 1 || inputs[6]->ElementsNum() != 1 ||
inputs[7]->ElementsNum() != 1 || inputs[8]->ElementsNum() != 1) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
}
if (!outputs.empty()) {
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_shape({1});
}

return RET_OK;
}
} // namespace lite
} // namespace mindspore
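
The shape checks above imply a TF-style ApplyAdam input layout; the ordering below is inferred from those constraints (inputs 0-2 and 9 share an element count, inputs 3-8 are scalars), not stated explicitly by this diff:

  inputs[0]  weight         (same element count as inputs[1], inputs[2], inputs[9])
  inputs[1]  m, 1st moment
  inputs[2]  v, 2nd moment
  inputs[3]  beta1_power    (scalar)
  inputs[4]  beta2_power    (scalar)
  inputs[5]  learning_rate  (scalar)
  inputs[6]  beta1          (scalar)
  inputs[7]  beta2          (scalar)
  inputs[8]  epsilon        (scalar)
  inputs[9]  gradient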

+47 -0  mindspore/lite/src/ops/adam.h

@@ -0,0 +1,47 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_OPS_ADAM_H_
#define MINDSPORE_LITE_SRC_OPS_ADAM_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Adam : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Adam, PrimitiveC);
Adam() = default;
explicit Adam(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Adam() = default;

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool GetUseNesterov() const;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_OPS_ADAM_H_

+20 -2  mindspore/lite/src/ops/addn.cc

@@ -82,8 +82,11 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());

// make sure all inputs have the same rank, and that each dimension is either equal across inputs or 1 (broadcasting)
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
if (inputs.at(i)->shape().size() != inputs.at(0)->shape().size()) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
@@ -93,7 +96,22 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
}
}

output->set_shape(input->shape());
for (size_t d = 0; d < input->shape().size(); ++d) {
int max_dim = input->shape().at(d);
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape().at(d) > max_dim) {
max_dim = inputs.at(i)->shape().at(d);
}
}
for (size_t i = 1; i < inputs.size(); ++i) {
if ((inputs.at(i)->shape().at(d) != max_dim) && (inputs.at(i)->shape().at(d) != 1)) {
MS_LOG(ERROR) << "AddN input shapes are not broadcast-compatible!";
return RET_INPUT_TENSOR_ERROR;
}
}
output->shape()[d] = max_dim; // set the biggest dimension in the output tensor
}

return RET_OK;
}
} // namespace lite
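
The new InferShape allows per-dimension broadcasting: all inputs must share a rank, and in every dimension each input must be either the per-dimension maximum or 1. A compact restatement of that rule (illustrative only, not this file's code):

#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<int> AddNOutShape(const std::vector<std::vector<int>> &shapes) {
  std::vector<int> out = shapes[0];
  for (size_t d = 0; d < out.size(); ++d) {
    int max_dim = out[d];
    for (const auto &s : shapes) max_dim = std::max(max_dim, s[d]);
    out[d] = max_dim;  // inputs sized 1 in this dim broadcast up to max_dim
  }
  return out;
}

int main() {
  auto out = AddNOutShape({{2, 1, 3}, {2, 4, 3}});
  printf("%d %d %d\n", out[0], out[1], out[2]);  // 2 4 3
  return 0;
}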


+1 -4  mindspore/lite/src/ops/apply_momentum.cc

@@ -18,7 +18,6 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; }
bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; }
bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; }

int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
@@ -41,7 +40,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
return RET_ERROR;
}
attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale"));
attr->useLocking = GetValue<bool>(prim.GetAttr("use_locking"));
attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));

this->primitive_->value.value = attr.release();
@@ -54,7 +52,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
}
#else
float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); }
bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); }
bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); }

int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@@ -65,7 +62,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov());
auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useNesterov());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;


+0 -1  mindspore/lite/src/ops/apply_momentum.h

@@ -40,7 +40,6 @@ class ApplyMomentum : public PrimitiveC {
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
float GetGradientScale() const;
bool GetUseLocking() const;
bool GetUseNesterov() const;
};
} // namespace lite


+82 -0  mindspore/lite/src/ops/assign.cc

@@ -0,0 +1,82 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/assign.h"
#include <memory>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Assign::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Assign;
}
if (this->primitive_->value.type != schema::PrimitiveType_Assign) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::AssignT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Assign();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Assign return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAssign(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assign, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif

int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (2 != inputs.size()) {
MS_LOG(ERROR) << "Assign should have at least 5 input tensors";
return RET_ERROR;
}

if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum()) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
}

if (!outputs.empty()) {
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_shape({1});
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+43 -0  mindspore/lite/src/ops/assign.h

@@ -0,0 +1,43 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_OPS_ASSIGN_H_
#define MINDSPORE_LITE_SRC_OPS_ASSIGN_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Assign : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Assign, PrimitiveC);
Assign() = default;
explicit Assign(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Assign() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_OPS_ASSIGN_H_

+2 -2  mindspore/lite/src/ops/bn_grad.cc

@@ -45,8 +45,8 @@ int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}
attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
// FusedBatchNormGrad does not get this attribute
if (prim.GetAttr("eps") != nullptr) {
attr->eps = GetValue<float>(prim.GetAttr("eps"));
if (prim.GetAttr("epsilon") != nullptr) {
attr->eps = GetValue<float>(prim.GetAttr("epsilon"));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {


+4 -1  mindspore/lite/src/ops/conv2d_grad_input.cc

@@ -15,7 +15,7 @@
*/

#include "src/ops/conv2d_grad_input.h"
#include "src/ops/group_conv2d_grad_input.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -86,6 +86,9 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
return RET_ERROR;
}
attr->group = GetValue<int>(prim.GetAttr("group"));
if (attr->group > 1) {
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
}
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;


+24 -0  mindspore/lite/src/ops/exp.cc

@@ -26,6 +26,30 @@ void Exp::SetShift(float shift) { this->primitive_->value.AsExp()->shift = shift
float Exp::GetBase() const { return this->primitive_->value.AsExp()->base; }
float Exp::GetScale() const { return this->primitive_->value.AsExp()->scale; }
float Exp::GetShift() const { return this->primitive_->value.AsExp()->shift; }

int Exp::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Exp;
}
if (this->primitive_->value.type != schema::PrimitiveType_Exp) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::ExpT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}

#else

int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {


+1 -0  mindspore/lite/src/ops/exp.h

@@ -33,6 +33,7 @@ class Exp : public PrimitiveC {
void SetBase(float base);
void SetShift(float shift);
void SetScale(float scale);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Exp() = default;



+172 -0  mindspore/lite/src/ops/group_conv2d_grad_input.cc

@@ -0,0 +1,172 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/group_conv2d_grad_input.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value.AsGroupConv2DGradInput()->format; }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value.AsGroupConv2DGradInput()->group; }
int GroupConv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelIn; }
int GroupConv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelOut; }
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelW; }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelH; }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideW; }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideH; }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value.AsGroupConv2DGradInput()->padMode; }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value.AsGroupConv2DGradInput()->padUp; }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value.AsGroupConv2DGradInput()->padDown; }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsGroupConv2DGradInput()->padLeft; }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value.AsGroupConv2DGradInput()->padRight; }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateW; }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateH; }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value.AsGroupConv2DGradInput()->hasBias; }
int GroupConv2DGradInput::GetActivationType() const {
return this->primitive_->value.AsGroupConv2DGradInput()->activationType;
}

void GroupConv2DGradInput::SetFormat(int format) {
this->primitive_->value.AsGroupConv2DGradInput()->format = (schema::Format)format;
}
void GroupConv2DGradInput::SetGroup(int group) { this->primitive_->value.AsGroupConv2DGradInput()->group = group; }
void GroupConv2DGradInput::SetChannelIn(int channel_in) {
this->primitive_->value.AsGroupConv2DGradInput()->channelIn = channel_in;
}
void GroupConv2DGradInput::SetChannelOut(int channel_out) {
this->primitive_->value.AsGroupConv2DGradInput()->channelOut = channel_out;
}
void GroupConv2DGradInput::SetKernelW(int kernel_w) {
this->primitive_->value.AsGroupConv2DGradInput()->kernelW = kernel_w;
}
void GroupConv2DGradInput::SetKernelH(int kernel_h) {
this->primitive_->value.AsGroupConv2DGradInput()->kernelH = kernel_h;
}
void GroupConv2DGradInput::SetStrideW(int stride_w) {
this->primitive_->value.AsGroupConv2DGradInput()->strideW = stride_w;
}
void GroupConv2DGradInput::SetStrideH(int stride_h) {
this->primitive_->value.AsGroupConv2DGradInput()->strideH = stride_h;
}
void GroupConv2DGradInput::SetPadMode(int pad_mode) {
this->primitive_->value.AsGroupConv2DGradInput()->padMode = (schema::PadMode)pad_mode;
}
void GroupConv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsGroupConv2DGradInput()->padUp = pad_up; }
void GroupConv2DGradInput::SetPadDown(int pad_down) {
this->primitive_->value.AsGroupConv2DGradInput()->padDown = pad_down;
}
void GroupConv2DGradInput::SetPadLeft(int pad_left) {
this->primitive_->value.AsGroupConv2DGradInput()->padLeft = pad_left;
}
void GroupConv2DGradInput::SetPadRight(int pad_right) {
this->primitive_->value.AsGroupConv2DGradInput()->padRight = pad_right;
}
void GroupConv2DGradInput::SetDilateW(int dilate_w) {
this->primitive_->value.AsGroupConv2DGradInput()->dilateW = dilate_w;
}
void GroupConv2DGradInput::SetDilateH(int dilate_h) {
this->primitive_->value.AsGroupConv2DGradInput()->dilateH = dilate_h;
}
void GroupConv2DGradInput::SetHasBias(bool has_bias) {
this->primitive_->value.AsGroupConv2DGradInput()->hasBias = has_bias;
}
void GroupConv2DGradInput::SetActivationType(int activation_type) {
this->primitive_->value.AsGroupConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
#else
int GroupConv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_GroupConv2DGradInput();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_GroupConv2DGradInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateGroupConv2DGradInput(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GroupConv2DGradInput, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value_as_GroupConv2DGradInput()->format(); }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value_as_GroupConv2DGradInput()->group(); }
int GroupConv2DGradInput::GetChannelIn() const {
return this->primitive_->value_as_GroupConv2DGradInput()->channelIn();
}
int GroupConv2DGradInput::GetChannelOut() const {
return this->primitive_->value_as_GroupConv2DGradInput()->channelOut();
}
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelW(); }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelH(); }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideW(); }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideH(); }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value_as_GroupConv2DGradInput()->padMode(); }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value_as_GroupConv2DGradInput()->padUp(); }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value_as_GroupConv2DGradInput()->padDown(); }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_GroupConv2DGradInput()->padLeft(); }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value_as_GroupConv2DGradInput()->padRight(); }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateW(); }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateH(); }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value_as_GroupConv2DGradInput()->hasBias(); }
int GroupConv2DGradInput::GetActivationType() const {
return this->primitive_->value_as_GroupConv2DGradInput()->activationType();
}

#endif

int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
if (3 != inputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs";
return RET_ERROR;
}
if (1 != outputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad input should have one output";
return RET_ERROR;
}

auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(in != nullptr);

std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->MutableData());
int new_size = in->ElementsNum();
if (in0->GetFormat() == in->GetFormat()) {
for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]);
} else {
if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) {
output_shape.push_back(out_shape[0]);
output_shape.push_back(out_shape[2]);
output_shape.push_back(out_shape[3]);
output_shape.push_back(out_shape[1]);
} else {
MS_LOG(ERROR) << "Shape covnert is not supported";
return RET_ERROR;
}
}

auto *out = outputs.at(0);
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());

return RET_OK;
}
} // namespace lite
} // namespace mindspore
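
InferShape reads the target shape from the third input tensor and, when the tensor formats differ, reorders an NCHW shape vector into NHWC. The permutation used above, on a concrete example (illustrative values):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> nchw = {1, 16, 28, 28};                       // N, C, H, W
  std::vector<int> nhwc = {nchw[0], nchw[2], nchw[3], nchw[1]};  // N, H, W, C
  printf("%d %d %d %d\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]);   // 1 28 28 16
  return 0;
}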

+79 -0  mindspore/lite/src/ops/group_conv2d_grad_input.h

@@ -0,0 +1,79 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class GroupConv2DGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GroupConv2DGradInput, PrimitiveC);
GroupConv2DGradInput() = default;
explicit GroupConv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
GroupConv2DGradInput() = default;

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetFormat() const;
int GetGroup() const;
int GetChannelIn() const;
int GetChannelOut() const;
int GetKernelW() const;
int GetKernelH() const;
int GetStrideW() const;
int GetStrideH() const;
int GetPadMode() const;
int GetPadUp() const;
int GetPadDown() const;
int GetPadLeft() const;
int GetPadRight() const;
int GetDilateW() const;
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_

+25 -1  mindspore/lite/src/ops/neg.cc

@@ -18,7 +18,31 @@

namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
#ifdef PRIMITIVE_WRITEABLE
int Neg::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Neg;
}
if (this->primitive_->value.type != schema::PrimitiveType_Neg) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::NegT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}

#else
int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(primitive != nullptr);
MS_ASSERT(fbb != nullptr);


+1 -0  mindspore/lite/src/ops/neg.h

@@ -31,6 +31,7 @@ class Neg : public ArithmeticSelf {
MS_DECLARE_PARENT(Neg, ArithmeticSelf);
Neg() = default;
explicit Neg(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Neg() = default;



+31 -0  mindspore/lite/src/ops/one_hot.cc

@@ -23,6 +23,37 @@ int OneHot::GetAxis() const { return this->primitive_->value.AsOneHot()->axis; }

void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; }

int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_OneHot;
}
if (this->primitive_->value.type != schema::PrimitiveType_OneHot) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::OneHotT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = -1;
if (prim.GetAttr("axis") != nullptr) {
attr->axis = GetValue<int>(prim.GetAttr("axis"));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else

int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); }


+1 -1  mindspore/lite/src/ops/one_hot.h

@@ -32,7 +32,7 @@ class OneHot : public PrimitiveC {
OneHot() = default;
explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
OneHot() = default;



+36 -3  mindspore/lite/src/ops/primitive_c.cc

@@ -144,6 +144,7 @@
#include "src/ops/pooling_grad.h"
#include "src/ops/conv2d_grad_filter.h"
#include "src/ops/conv2d_grad_input.h"
#include "src/ops/group_conv2d_grad_input.h"
#include "src/ops/power_grad.h"
#include "src/ops/softmax_cross_entropy.h"
#include "src/ops/bn_grad.h"
@@ -152,6 +153,8 @@
#include "src/ops/flatten_grad.h"
#include "src/ops/log_grad.h"
#include "src/ops/sgd.h"
#include "src/ops/adam.h"
#include "src/ops/assign.h"
#endif

namespace mindspore {
@@ -367,7 +370,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect
std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType) {
const auto &op_type = prim.name();
if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") {
if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") {
return NewPrimitiveC<Activation>(prim, inputs, quantType);
} else if (op_type == "AddN") {
return NewPrimitiveC<AddN>(prim, inputs, quantType);
@@ -413,6 +416,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Reduce>(prim, inputs, quantType);
} else if (op_type == "Reshape") {
return NewPrimitiveC<Reshape>(prim, inputs, quantType);
} else if (op_type == "Slice") {
return NewPrimitiveC<Slice>(prim, inputs, quantType);
} else if (op_type == "Squeeze") {
return NewPrimitiveC<Squeeze>(prim, inputs, quantType);
} else if (op_type == "TensorAdd") {
return NewPrimitiveC<Add>(prim, inputs, quantType);
} else if (op_type == "Transpose") {
@@ -421,6 +428,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Elu>(prim, inputs, quantType);
} else if (op_type == "Log") {
return NewPrimitiveC<Log>(prim, inputs, quantType);
} else if (op_type == "Exp") {
return NewPrimitiveC<Exp>(prim, inputs, quantType);
} else if (op_type == "Neg") {
return NewPrimitiveC<Neg>(prim, inputs, quantType);
} else if (op_type == "DeConv2D") {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
} else if (op_type == "tuple_getitem") {
@@ -435,6 +446,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Maximum>(prim, inputs, quantType);
} else if (op_type == "Split") {
return NewPrimitiveC<Split>(prim, inputs, quantType);
} else if (op_type == "OneHot") {
return NewPrimitiveC<OneHot>(prim, inputs, quantType);

#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
@@ -445,7 +458,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType);
} else if (op_type == "Depend") {
return NewPrimitiveC<Depend>(prim, inputs, quantType);
} else if ((op_type == "ReluGrad" || op_type == "Relu6Grad" || op_type == "SigmoidGrad")) {
} else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" ||
op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) {
return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
} else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) {
return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
@@ -465,6 +479,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
} else if (op_type == "SGD") {
return NewPrimitiveC<Sgd>(prim, inputs, quantType);
} else if (op_type == "Adam") {
return NewPrimitiveC<Adam>(prim, inputs, quantType);
} else if (op_type == "Assign") {
return NewPrimitiveC<Assign>(prim, inputs, quantType);
#else
} else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
@@ -686,6 +704,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new Dropout(primitive);
case schema::PrimitiveType_Neg:
return new Neg(primitive);
case schema::PrimitiveType_RealDiv:
return new RealDiv(primitive);
case schema::PrimitiveType_LshProjection:
return new LshProjection(primitive);
case schema::PrimitiveType_HashtableLookup:
@@ -710,6 +730,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new Conv2DGradFilter(primitive);
case schema::PrimitiveType_Conv2DGradInput:
return new Conv2DGradInput(primitive);
case schema::PrimitiveType_GroupConv2DGradInput:
return new GroupConv2DGradInput(primitive);
case schema::PrimitiveType_BiasGrad:
return new BiasGrad(primitive);
case schema::PrimitiveType_ApplyMomentum:
@@ -738,8 +760,11 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new LogGrad(primitive);
case schema::PrimitiveType_Sgd:
return new Sgd(primitive);
case schema::PrimitiveType_Adam:
return new Adam(primitive);
case schema::PrimitiveType_Assign:
return new Assign(primitive);
#endif

default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
break;
@@ -958,6 +983,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
return NewPrimitiveC<DetectionPostProcess>(primitive);
case schema::PrimitiveType_Dropout:
return NewPrimitiveC<Dropout>(primitive);
case schema::PrimitiveType_RealDiv:
return NewPrimitiveC<RealDiv>(primitive);
case schema::PrimitiveType_LshProjection:
return NewPrimitiveC<LshProjection>(primitive);
case schema::PrimitiveType_HashtableLookup:
@@ -982,6 +1009,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
return NewPrimitiveC<Conv2DGradFilter>(primitive);
case schema::PrimitiveType_Conv2DGradInput:
return NewPrimitiveC<Conv2DGradInput>(primitive);
case schema::PrimitiveType_GroupConv2DGradInput:
return NewPrimitiveC<GroupConv2DGradInput>(primitive);
case schema::PrimitiveType_BiasGrad:
return NewPrimitiveC<BiasGrad>(primitive);
case schema::PrimitiveType_ApplyMomentum:
@@ -1004,6 +1033,10 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
return NewPrimitiveC<LogGrad>(primitive);
case schema::PrimitiveType_Sgd:
return NewPrimitiveC<Sgd>(primitive);
case schema::PrimitiveType_Adam:
return NewPrimitiveC<Adam>(primitive);
case schema::PrimitiveType_Assign:
return NewPrimitiveC<Assign>(primitive);
#endif
default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);


+6 -4  mindspore/lite/src/ops/primitive_c.h

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
#ifndef MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_
#define MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_
#include <string>
#include <set>
#include <vector>
@@ -48,7 +48,9 @@ constexpr int kAnfPopulaterTwo = 2;
constexpr int kAnfPopulaterThree = 3;
static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
{"ReLU6", schema::ActivationType_RELU6},
{"Sigmoid", schema::ActivationType_SIGMOID}};
{"Sigmoid", schema::ActivationType_SIGMOID},
{"HSwish", schema::ActivationType_HSWISH},
{"HSigmoid", schema::ActivationType_HSIGMOID}};
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
@@ -213,4 +215,4 @@ class PrimitiveC {
#endif
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
#endif // MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_

+8 -1  mindspore/lite/src/ops/real_div.cc

@@ -44,7 +44,14 @@ int RealDiv::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
}

#else

int RealDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateRank(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RealDiv, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore

+1 -4  mindspore/lite/src/ops/real_div.h

@@ -36,10 +36,7 @@ class RealDiv : public Arithmetic {
#else
RealDiv() = default;

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;
}

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite


+63 -1  mindspore/lite/src/ops/slice.cc

@@ -35,6 +35,66 @@ void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format =
void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; }
void Slice::SetSize(const std::vector<int> &size) { this->primitive_->value.AsSlice()->size = size; }

int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Slice;
}
if (this->primitive_->value.type != schema::PrimitiveType_Slice) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SliceT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (inputs.size() >= kAnfPopulaterThree) {
auto beginNode = inputs[kAnfPopulaterOne];
MS_ASSERT(beginNode != nullptr);
if (beginNode->isa<ValueNode>()) {
auto valueNode = beginNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
auto value = valueNode->value();
MS_ASSERT(value != nullptr);
if (value->isa<ValueTuple>()) {
auto valTuplPtr = dyn_cast<ValueTuple>(value);
MS_ASSERT(valTuplPtr != nullptr);
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]);
MS_ASSERT(elem != nullptr);
attr->begin.emplace_back(elem->value());
}
}
}
auto sizeNode = inputs[kAnfPopulaterTwo];
MS_ASSERT(sizeNode != nullptr);
if (sizeNode->isa<ValueNode>()) {
auto valueNode = sizeNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
auto value = valueNode->value();
MS_ASSERT(value != nullptr);
if (value->isa<ValueTuple>()) {
auto valTuplPtr = dyn_cast<ValueTuple>(value);
MS_ASSERT(valTuplPtr != nullptr);
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]);
MS_ASSERT(elem != nullptr);
attr->size.emplace_back(elem->value());
}
}
}
}
this->primitive_->value.value = attr;
}
return RET_OK;
}

#else

int Slice::GetFormat() const { return this->primitive_->value_as_Slice()->format(); }
@@ -46,10 +106,12 @@ std::vector<int> Slice::GetSize() const {
auto fb_vector = this->primitive_->value_as_Slice()->size();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

std::vector<int> Slice::GetAxes() const {
auto fb_vector = this->primitive_->value_as_Slice()->axes();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
@@ -90,7 +152,7 @@ std::vector<int> Slice::GetPostProcessBegin() const { return this->begin; }
std::vector<int> Slice::GetPostProcessSize() const { return this->size; }
int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) {
if (inputs.size() < kSliceInputNum || outputs.size() != kSliceOutputNum) {
MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size();
return RET_PARAM_INVALID;
}
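Two notes on the Slice changes: UnPackAttr now pulls begin/size out of the value-node inputs, and InferShape accepts extra inputs instead of exactly kSliceInputNum. For reference, a sketch of how begin/size determine the output shape (SliceOutputShape is an illustrative name, and the size[i] == -1 "through the end of the dimension" convention is assumed here, not shown in this hunk):

#include <vector>

std::vector<int> SliceOutputShape(const std::vector<int> &in_shape,
                                  const std::vector<int> &begin,
                                  const std::vector<int> &size) {
  std::vector<int> out(in_shape.size());
  for (size_t i = 0; i < in_shape.size(); ++i) {
    out[i] = (size[i] == -1) ? in_shape[i] - begin[i] : size[i];
  }
  return out;
}
// e.g. in_shape {1,4,4,3}, begin {0,1,1,0}, size {1,-1,2,3} -> {1,3,2,3}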


+ 4
- 3
mindspore/lite/src/ops/slice.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_SLICE_H_
#define LITE_MINDSPORE_LITE_C_OPS_SLICE_H_
#ifndef MINDSPORE_LITE_SRC_OPS_SLICE_H_
#define MINDSPORE_LITE_SRC_OPS_SLICE_H_

#include <vector>
#include <set>
@@ -35,6 +35,7 @@ class Slice : public PrimitiveC {
void SetFormat(int format);
void SetBegin(const std::vector<int> &begin);
void SetSize(const std::vector<int> &size);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Slice() = default;

@@ -56,4 +57,4 @@ class Slice : public PrimitiveC {
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_SLICE_H_
#endif // MINDSPORE_LITE_SRC_OPS_SLICE_H_

+ 29
- 0
mindspore/lite/src/ops/squeeze.cc View File

@@ -23,6 +23,35 @@ std::vector<int> Squeeze::GetAxis() const { return this->primitive_->value.AsSqu

void Squeeze::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsSqueeze()->axis = axis; }

int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Squeeze;
}
if (this->primitive_->value.type != schema::PrimitiveType_Squeeze) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SqueezeT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}

#else

std::vector<int> Squeeze::GetAxis() const {


+ 4
- 3
mindspore/lite/src/ops/squeeze.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_
#define LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_
#ifndef MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_
#define MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_

#include <vector>
#include <set>
@@ -33,6 +33,7 @@ class Squeeze : public PrimitiveC {
Squeeze() = default;
explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

#else
Squeeze() = default;
@@ -45,4 +46,4 @@ class Squeeze : public PrimitiveC {
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_
#endif // MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_

+ 34
- 9
mindspore/lite/src/ops/tile.cc View File

@@ -24,7 +24,7 @@ std::vector<int> Tile::GetMultiples() const { return this->primitive_->value.AsT

void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive_->value.AsTile()->multiples = multiples; }

std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile()->multiples; }
std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile()->dims; }

void Tile::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsTile()->dims = dims; }

@@ -42,11 +42,32 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::TileT();
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::TileT();

if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (inputs.size() == kAnfPopulaterTwo) {
auto inputNode = inputs[kAnfPopulaterOne];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
auto value = valueNode->value();
MS_ASSERT(value != nullptr);
if (value->isa<ValueTuple>()) {
auto valTuplPtr = dyn_cast<ValueTuple>(value);
MS_ASSERT(valTuplPtr != nullptr);
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]);
MS_ASSERT(elem != nullptr);
attr->multiples.emplace_back(elem->value());
}
}
}
}
this->primitive_->value.value = attr;
}
return RET_OK;
}
@@ -103,15 +124,19 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output

MS_ASSERT(tile_prim != nullptr);
std::vector<int> out_shape;
std::vector<int> multiples;
for (size_t i = 0; i < GetMultiples().size(); ++i) {
multiples.push_back(GetMultiples()[i]);
std::vector<int> multiples = GetMultiples();
const size_t in_dims = input->shape().size();
const size_t delta_dims = in_dims - multiples.size();

size_t i = 0;
for (; i < delta_dims; ++i) {
int tmp = input->shape()[i];
out_shape.push_back(tmp);
}
for (size_t i = 0; i < input->shape().size(); ++i) {
int tmp = input->shape()[i] * multiples[i];
for (; i < in_dims; ++i) {
int tmp = input->shape()[i] * (multiples[i - delta_dims]);
out_shape.push_back(tmp);
}

output->set_shape(out_shape);
return RET_OK;
}
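The rewritten InferShape right-aligns multiples against the input shape, so an input can now be tiled with fewer multiples than it has dimensions. A sketch mirroring the loop above (TileOutShape is an illustrative name, not kernel code):

#include <vector>

std::vector<int> TileOutShape(const std::vector<int> &in_shape,
                              const std::vector<int> &multiples) {
  const size_t delta = in_shape.size() - multiples.size();
  std::vector<int> out;
  for (size_t i = 0; i < in_shape.size(); ++i) {
    // Leading dims are copied; trailing dims are scaled by the aligned multiple.
    out.push_back(i < delta ? in_shape[i] : in_shape[i] * multiples[i - delta]);
  }
  return out;
}
// e.g. {2,3,4} tiled with {5} -> {2,3,20}; previously multiples had to
// cover every input dimension.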


+ 2
- 0
mindspore/lite/src/populate_parameter.cc View File

@@ -43,6 +43,7 @@
#include "src/ops/add.h"
#include "src/ops/sub.h"
#include "src/ops/div.h"
#include "src/ops/real_div.h"
#include "src/ops/bias_add.h"
#include "src/ops/expand_dims.h"
#include "src/ops/full_connection.h"
@@ -1680,6 +1681,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_RealDiv] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_LogicalAnd] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_LogicalOr] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Equal] = PopulateArithmetic;


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc View File

@@ -24,6 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType_HSIGMOID;
using mindspore::schema::ActivationType_HSWISH;
using mindspore::schema::ActivationType_LEAKY_RELU;
using mindspore::schema::ActivationType_RELU;
@@ -57,6 +58,8 @@ int ActivationCPUKernel::DoActivation(int task_id) {
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HSWISH) {
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HSIGMOID) {
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HARD_TANH) {
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
} else {
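For reference, the two activation types wired in above are the standard "hard" approximations. An illustrative sketch (HSigmoidRef/HSwishRef are hypothetical names; the actual nnacl implementations live in fp32/activation.c and may differ in vectorization and edge handling):

#include <algorithm>

inline float HSigmoidRef(float x) {
  return std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;  // ReLU6(x + 3) / 6
}
inline float HSwishRef(float x) {
  return x * HSigmoidRef(x);  // x * ReLU6(x + 3) / 6
}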


+ 23
- 3
mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc View File

@@ -60,14 +60,34 @@ int AddNCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
elements_num_ = in_tensors_[0]->ElementsNum();
elements_num_ = out_tensors_[0]->ElementsNum();
auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto input1_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
if (static_cast<int>(elements_num_) < op_parameter_->thread_num_) {
ElementAdd(input0_data, input1_data, output_data, elements_num_);
if (in_tensors_[0]->shape() == in_tensors_[1]->shape()) {
ElementAdd(input0_data, input1_data, output_data, elements_num_);
} else {
ArithmeticParameter param;
param.in_elements_num0_ = in_tensors_[0]->ElementsNum();
param.in_elements_num1_ = in_tensors_[1]->ElementsNum();
param.out_elements_num_ = out_tensors_[0]->ElementsNum();
param.broadcasting_ = true;
ElementOptAdd(input0_data, input1_data, output_data, elements_num_, &param);
}

for (size_t i = 2; i < in_tensors_.size(); ++i) {
ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_);
if (in_tensors_[i]->shape() == out_tensors_[0]->shape()) {
ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_);
} else {
ArithmeticParameter param;
param.in_elements_num0_ = in_tensors_[i]->ElementsNum();
param.in_elements_num1_ = out_tensors_[0]->ElementsNum();
param.out_elements_num_ = out_tensors_[0]->ElementsNum();
param.broadcasting_ = true;
ElementOptAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_,
&param);
}
}
return RET_OK;
}
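The new branches cover AddN inputs whose shapes do not match the output. A minimal sketch of what the broadcast path computes in the common scalar case (assuming, as the element counts in ArithmeticParameter suggest, that ElementOptAdd handles a single-element operand; the real nnacl routine is vectorized and parameter-driven):

void ScalarBroadcastAddSketch(const float *in0, const float *in1_scalar, float *out, int num) {
  // Equivalent of adding a 1-element tensor to an N-element tensor.
  for (int i = 0; i < num; ++i) out[i] = in0[i] + in1_scalar[0];
}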


+ 2
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc View File

@@ -104,6 +104,7 @@ int ArithmeticCPUKernel::ReSize() {
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
@@ -304,6 +305,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator)


+ 2
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h View File

@@ -37,6 +37,7 @@ using mindspore::schema::PrimitiveType_Maximum;
using mindspore::schema::PrimitiveType_Minimum;
using mindspore::schema::PrimitiveType_Mul;
using mindspore::schema::PrimitiveType_NotEqual;
using mindspore::schema::PrimitiveType_RealDiv;
using mindspore::schema::PrimitiveType_SquaredDifference;
using mindspore::schema::PrimitiveType_Sub;

@@ -99,6 +100,7 @@ class ArithmeticCPUKernel : public LiteKernel {
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivRelu;


+ 3
- 2
mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc View File

@@ -48,7 +48,7 @@ int ActivationGradCPUKernel::DoActivation(int task_id) {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
int length = in_tensors_.at(0)->ElementsNum();

int stride = UP_DIV(length, thread_count_);
int stride = UP_DIV(length, 1);
int count = MSMIN(stride, length - stride * task_id);

auto error_code = RET_OK;
@@ -63,8 +63,9 @@ int ActivationGradCPUKernel::DoActivation(int task_id) {
error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count,
output_addr + stride * task_id, param_act_grad_->alpha_);
} else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) {
// Sigmoid gets the input tensors in reverse order!
error_code =
SigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
error_code =
TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id);
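The reordering above matters because of what the sigmoid gradient needs: for y = sigmoid(x), dy/dx = y * (1 - y), so the kernel must know which buffer holds the forward output y and which holds the incoming gradient. A sketch of the intended computation (SigmoidGradRef is an illustrative name; the argument roles, not the verbatim nnacl signature, are the point):

void SigmoidGradRef(const float *y, const float *dy, float *dx, int n) {
  // dx = dy * y * (1 - y), computed from the forward output y, not the input x
  for (int i = 0; i < n; ++i) dx[i] = dy[i] * y[i] * (1.0f - y[i]);
}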


+ 1
- 2
mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h View File

@@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel {
explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
: LiteKernel(param, inputs, outputs, ctx, primitive) {
param_act_grad_ = reinterpret_cast<ActivationParameter *>(param);
}
~ActivationGradCPUKernel() override = default;
@@ -38,7 +38,6 @@ class ActivationGradCPUKernel : public LiteKernel {
int DoActivation(int task_id);

private:
int thread_count_;
ActivationParameter *param_act_grad_;
};



+ 118
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc View File

@@ -0,0 +1,118 @@

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32_grad/adam.h"
#include <cmath>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Adam;

namespace mindspore::kernel {

int AdamCPUKernel::ReSize() { return RET_OK; }

int AdamCPUKernel::Execute(int task_id) {
auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto m = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
auto v = reinterpret_cast<float *>(in_tensors_[2]->MutableData());
auto beta1_power = reinterpret_cast<float *>(in_tensors_[3]->MutableData())[0];
auto beta2_power = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0];
auto learning_rate = reinterpret_cast<float *>(in_tensors_[5]->MutableData())[0];
auto beta1 = reinterpret_cast<float *>(in_tensors_[6]->MutableData())[0];
auto beta2 = reinterpret_cast<float *>(in_tensors_[7]->MutableData())[0];
auto eps = reinterpret_cast<float *>(in_tensors_[8]->MutableData())[0];
auto gradient = reinterpret_cast<float *>(in_tensors_[9]->MutableData());
size_t elem_num = in_tensors_[0]->ElementsNum();

if (adam_param_->use_nesterov_) { // Nadam
for (size_t i = 0; i < elem_num; ++i) {
m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1));
v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2));
auto g_hat = gradient[i] / (1 - beta1_power);
auto m_hat = m[i] / (1 - beta1_power);
auto v_hat = v[i] / (1 - beta2_power);
auto m_tag = (1.f - beta1) * g_hat + beta1 * m_hat;
weight[i] -= learning_rate * m_tag / (sqrtf(v_hat) + eps);
}
} else {
for (size_t i = 0; i < elem_num; ++i) {
m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1));
v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2));
auto m_hat = m[i] / (1 - beta1_power);
auto v_hat = v[i] / (1 - beta2_power);
weight[i] -= learning_rate * m_hat / (sqrtf(v_hat) + eps);
}
}
return RET_OK;
}

int AdamRun(void *cdata, int task_id) {
auto Adam_kernel = reinterpret_cast<AdamCPUKernel *>(cdata);
auto error_code = Adam_kernel->Execute(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Adam run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int AdamCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "AdamCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}

int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int AdamCPUKernel::Init() { return RET_OK; }

kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Adam);
auto *kernel = new (std::nothrow) AdamCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(kernel != nullptr);

auto ret = kernel->Init();
if (0 != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}

return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adam, CpuAdamFp32KernelCreator)
} // namespace mindspore::kernel
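The kernel binds its ten inputs by position (0 weight, 1 m, 2 v, 3 beta1^t, 4 beta2^t, 5 learning rate, 6 beta1, 7 beta2, 8 epsilon, 9 gradient) and applies the standard bias-corrected Adam step; use_nesterov_ selects the NAdam-style first branch. Distilled from the non-Nesterov loop above (AdamStepRef is an illustrative standalone version, not the kernel itself):

#include <cmath>
#include <cstddef>

void AdamStepRef(float *w, float *m, float *v, const float *g, size_t n,
                 float beta1, float beta2, float beta1_power, float beta2_power,
                 float lr, float eps) {
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];         // first moment
    v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];  // second moment
    const float m_hat = m[i] / (1.0f - beta1_power);     // bias correction
    const float v_hat = v[i] / (1.0f - beta2_power);
    w[i] -= lr * m_hat / (std::sqrt(v_hat) + eps);
  }
}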

+ 44
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32_grad/optimizer.h"

namespace mindspore::kernel {
class AdamCPUKernel : public LiteKernel {
public:
explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
adam_param_ = reinterpret_cast<AdamParameter *>(parameter);
}
~AdamCPUKernel() override {}
int Init() override;
int ReSize() override;
int Run() override;
int Execute(int task_id);

private:
AdamParameter *adam_param_;
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_

+ 1
- 8
mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc View File

@@ -79,14 +79,7 @@ int ApplyMomentumCPUKernel::Run() {
return RET_OK;
}

int ApplyMomentumCPUKernel::Init() {
// Only for test with uninitialized Data
size_t elem_num = in_tensors_[0]->ElementsNum();
auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0;

return RET_OK;
}
int ApplyMomentumCPUKernel::Init() { return RET_OK; }

kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,


+ 91
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc View File

@@ -0,0 +1,91 @@

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32_grad/assign.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Assign;

namespace mindspore::kernel {

int AssignCPUKernel::ReSize() { return RET_OK; }

int AssignCPUKernel::Execute(int task_id) {
auto x = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto y = reinterpret_cast<float *>(in_tensors_[1]->MutableData());
size_t size = in_tensors_[0]->Size();

memcpy(x, y, size);
return RET_OK;
}

int AssignRun(void *cdata, int task_id) {
auto Assign_kernel = reinterpret_cast<AssignCPUKernel *>(cdata);
auto error_code = Assign_kernel->Execute(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "assign run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int AssignCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "AssignCPUKernel Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}

int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}

int AssignCPUKernel::Init() { return RET_OK; }

kernel::LiteKernel *CpuAssignFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Assign);
auto *kernel = new (std::nothrow) AssignCPUKernel(opParameter, inputs, outputs, ctx, primitive);
MS_ASSERT(kernel != nullptr);

auto ret = kernel->Init();
if (0 != ret) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}

return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assign, CpuAssignFp32KernelCreator)
} // namespace mindspore::kernel

+ 39
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp32_grad/optimizer.h"

namespace mindspore::kernel {
class AssignCPUKernel : public LiteKernel {
public:
explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~AssignCPUKernel() override {}
int Init() override;
int ReSize() override;
int Run() override;
int Execute(int task_id);
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_

+ 5
- 1
mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc View File

@@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2DGradInput;
using mindspore::schema::PrimitiveType_GroupConv2DGradInput;

namespace mindspore::kernel {
int ConvolutionGradInputCPUKernel::Init() {
@@ -134,7 +135,8 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Te
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput ||
desc.type == schema::PrimitiveType_GroupConv2DGradInput);

auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
@@ -154,4 +156,6 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Te
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradInput, CpuConvGradInputFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GroupConv2DGradInput, CpuConvGradInputFp32KernelCreator)

} // namespace mindspore::kernel

+ 2
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc View File

@@ -154,4 +154,6 @@ kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector<li
}
return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSoftmaxCrossEntropyFp32KernelCreator)
} // namespace mindspore::kernel

+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc View File

@@ -178,5 +178,4 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator(
return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSparseSoftmaxCrossEntropyFp32KernelCreator)
} // namespace mindspore::kernel

+ 3
- 0
mindspore/lite/src/train/train_model.cc View File

@@ -41,6 +41,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
}
model->buf = reinterpret_cast<char *>(malloc(size));
if (model->buf == nullptr) {
delete model;
MS_LOG(ERROR) << "new inner model buf fail!";
return nullptr;
}
@@ -48,6 +49,8 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
model->buf_size_ = size;
auto meta_graph = schema::GetMetaGraph(model->buf);
if (meta_graph == nullptr) {
delete model;
free(model->buf);
MS_LOG(ERROR) << "meta_graph is nullptr!";
return nullptr;
}


+ 64
- 1
mindspore/lite/src/train/train_populate_parameter.cc View File

@@ -24,6 +24,7 @@
#include "nnacl/fp32/activation.h"
#include "src/ops/conv2d_grad_filter.h"
#include "src/ops/conv2d_grad_input.h"
#include "src/ops/group_conv2d_grad_input.h"
#include "nnacl/conv_parameter.h"
#include "src/ops/power_grad.h"
#include "nnacl/power_parameter.h"
@@ -34,6 +35,7 @@
#include "src/ops/sgd.h"
#include "src/ops/bn_grad.h"
#include "nnacl/fp32_grad/batch_norm.h"
#include "src/ops/adam.h"

namespace mindspore::kernel {

@@ -69,12 +71,29 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p
reinterpret_cast<mindspore::lite::ApplyMomentum *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));

p->grad_scale_ = apply_momentum_primitive->GetGradientScale();
p->use_locking_ = apply_momentum_primitive->GetUseLocking();
p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov();

return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter)));
if (p == nullptr) {
MS_LOG(ERROR) << "new AdamParameter failed.";
return nullptr;
}
p->op_parameter_.type_ = primitive->Type();

auto adam_primitive =
reinterpret_cast<mindspore::lite::Adam *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
p->use_nesterov_ = adam_primitive->GetUseNesterov();
return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
@@ -264,6 +283,47 @@ OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::Primit
return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateGroupConvolutionGradInputParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}

ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "new Param for conv grad filter failed.";
return nullptr;
}
param->op_parameter_.type_ = primitive->Type();

auto convg_primitive =
reinterpret_cast<mindspore::lite::GroupConv2DGradInput *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
param->kernel_h_ = convg_primitive->GetKernelH();
param->kernel_w_ = convg_primitive->GetKernelW();
param->stride_h_ = convg_primitive->GetStrideH();
param->stride_w_ = convg_primitive->GetStrideW();
param->dilation_h_ = convg_primitive->GetDilateH();
param->dilation_w_ = convg_primitive->GetDilateW();
param->pad_u_ = convg_primitive->GetPadUp();
param->pad_d_ = convg_primitive->GetPadDown();
param->pad_l_ = convg_primitive->GetPadLeft();
param->pad_r_ = convg_primitive->GetPadRight();
param->group_ = convg_primitive->GetGroup();
param->act_type_ = ActType_No;
switch (convg_primitive->GetActivationType()) {
case schema::ActivationType_RELU:
param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
param->act_type_ = ActType_Relu6;
break;
default:
break;
}

return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
@@ -327,10 +387,13 @@ void PopulateTrainParameters() {
ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_GroupConv2DGradInput, PopulateGroupConvolutionGradInputParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Sgd, PopulateSgdParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, PopulateBNGradParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Adam, PopulateAdamParameter);
ppr->AddPopulateParameterFunc(schema::PrimitiveType_Assign, DefaultPopulateParameter);
}

} // namespace mindspore::kernel

+ 34
- 10
mindspore/lite/src/train/train_session.cc View File

@@ -104,23 +104,15 @@ int TrainSession::RunGraph(const session::KernelCallBack &before, const session:
for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((static_cast<lite::Tensor *>(ms_tensor)));
if (train_mode_) return lite::LiteSession::RunGraph(before, after);

// object is expected to run only inference part of graph
// prepare a list of kernels till the loss function -- temporary solution
std::vector<kernel::LiteKernel *> inference_kernels;
for (auto kernel : this->kernels_) {
if (IsLossKernel(kernel)) break;
inference_kernels.push_back(kernel);
}

if (this->context_ == nullptr) {
MS_LOG(ERROR) << "context is null";
return lite::RET_NULL_PTR;
}
lite::Executor executor;
if (before == nullptr && after == nullptr) {
return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get());
return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get());
} else {
return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get(), before,
return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get(), before,
after);
}
}
@@ -173,6 +165,38 @@ void TrainSession::Eval() {
}
}
}
if (inference_kernels_.size() == 0) {
BuildInferenceKernelsMap();
}
}

void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) {
if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not in vector
v->push_back(kernel);
for (auto in_node : kernel->in_kernels()) {
BuildInferenceKernelsRecursive(in_node, v);
}
}
}

void TrainSession::BuildInferenceKernelsMap() {
std::vector<kernel::LiteKernel *> req_kernels;
for (auto kernel : this->kernels_) {
if (IsLossKernel(kernel)) { // For each loss in the system add backward tree
for (auto in_node : kernel->in_kernels()) {
BuildInferenceKernelsRecursive(in_node, &req_kernels);
}
}
}
inference_kernels_.clear();
for (auto kernel : this->kernels_) {
if (std::find(req_kernels.begin(), req_kernels.end(), kernel) != req_kernels.end()) {
inference_kernels_.push_back(kernel);
}
}
if (inference_kernels_.size() == 0) {
inference_kernels_ = this->kernels_;
}
}
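In effect, Eval() now caches the inference subgraph once: for every loss kernel, the recursion walks in_kernels() backwards and collects everything the loss depends on, and the final pass re-filters this->kernels_ so inference_kernels_ preserves the scheduler's original order. For a toy schedule {conv, relu, loss} where the loss consumes relu, the traversal collects {relu, conv} and inference_kernels_ becomes {conv, relu}; if no loss kernel is found, the whole kernel list is kept as a fallback.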

bool TrainSession::IsLossKernel(kernel::LiteKernel *kernel) {


+ 5
- 1
mindspore/lite/src/train/train_session.h View File

@@ -82,12 +82,16 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::

protected:
void AllocWorkSpace();
bool IsLossKernel(kernel::LiteKernel *kernel);
virtual std::vector<CreatorOp> ReplaceOps();
virtual void RestoreOps(const std::vector<CreatorOp> &restore);
bool IsLossKernel(kernel::LiteKernel *kernel);
virtual void BuildInferenceKernelsMap();
virtual void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels);

TrainModel *model_ = nullptr;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
std::vector<kernel::LiteKernel *> inference_kernels_;
};
} // namespace lite
} // namespace mindspore


+ 6
- 6
mindspore/lite/test/common/common_test.h View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TESTS_UT_COMMON_UT_COMMON_H_
#define TESTS_UT_COMMON_UT_COMMON_H_
#ifndef MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_
#define MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_

#include <cmath>
#include <fstream>
@@ -37,11 +37,11 @@ class CommonTest : public testing::Test {
void PrintData(std::string name, T *output_data, int size) {
std::cout << "The " << name << " is as follows:" << std::endl;
if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) {
for (size_t i = 0; i < std::min(size, 100); i++) {
for (int i = 0; i < std::min(size, 100); i++) {
std::cout << static_cast<int>(output_data[i]) << " ";
}
} else {
for (size_t i = 0; i < std::min(size, 100); i++) {
for (int i = 0; i < std::min(size, 100); i++) {
std::cout << output_data[i] << " ";
}
}
@@ -58,7 +58,7 @@ class CommonTest : public testing::Test {

void CompareOutputInt8(int8_t *output_data, int8_t *correct_data, int size, float err_percent) {
int bias_count = 0;
for (size_t i = 0; i < size; i++) {
for (int i = 0; i < size; i++) {
int8_t diff = abs(output_data[i] - correct_data[i]);
ASSERT_LE(diff, 1);
if (diff == 1) {
@@ -88,4 +88,4 @@ class CommonTest : public testing::Test {
}
};
} // namespace mindspore
#endif // TESTS_UT_COMMON_UT_COMMON_H_
#endif // MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_

+ 26
- 14
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc View File

@@ -250,7 +250,7 @@ TEST_F(NetworkTest, tuning_layer) {
auto label = std::make_unique<schema::TensorT>();
label->nodeType = schema::NodeType::NodeType_ValueNode;
label->format = schema::Format_NHWC;
label->dataType = TypeId::kNumberTypeInt32;
label->dataType = TypeId::kNumberTypeFloat32;
label->dims = {BATCH_SIZE * NUM_CLASSES};
label->offset = -1;
meta_graph->allTensors.emplace_back(std::move(label));
@@ -386,8 +386,10 @@ TEST_F(NetworkTest, tuning_layer) {
auto labelTensor = inputs.at(1);
ASSERT_NE(nullptr, labelTensor);
ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum());
auto labels = reinterpret_cast<int *>(labelTensor->MutableData());
for (int i = 0; i < BATCH_SIZE; i++) labels[i] = (i * 97) % NUM_CLASSES;

auto labels = reinterpret_cast<float *>(labelTensor->MutableData());
std::fill(labels, labels + labelTensor->ElementsNum(), 0.f);
for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0;

ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
@@ -576,12 +578,12 @@ TEST_F(NetworkTest, lenetnet) {
delete context;
ASSERT_EQ(res, 0);
}
#if 0
TEST_F(NetworkTest, retina_net) {
char *buf = nullptr;
size_t net_size = 0;

std::string net = "./test_data/nets/retinaface1009.ms";
std::string net = "./test_data/nets/retinaface1.ms";
ReadFile(net.c_str(), &net_size, &buf);
// auto model = lite::TrainModel::Import(buf, net_size);
auto model = lite::Model::Import(buf, net_size);
@@ -598,26 +600,36 @@ TEST_F(NetworkTest, retina_net) {
ASSERT_EQ(lite::RET_OK, ret);
// session->Eval();

std::string in = "./test_data/nets/retinaface_input.f32";
std::string in = "./test_data/nets/test1.hwc_normalized_f32";
std::cout << "----- Output 0 -----" << std::endl;
std::string out = "./test_data/nets/retinaface_out_0.f32";
std::string out = "./test_data/nets/test1_loc.f32";
int final_res = 0;
auto res = runNet(session, in, out, "448", true);
ASSERT_EQ(res, 0);
// ASSERT_EQ(res, 0);
if (res != 0) {
final_res = res;
}

std::cout << "----- Output 1 -----" << std::endl;
out = "./test_data/nets/retinaface_out_1.f32";
out = "./test_data/nets/test1_conf.f32";
res = runNet(session, in, out, "435", true);
ASSERT_EQ(res, 0);

// ASSERT_EQ(res, 0);
if (res != 0) {
final_res |= res;
}
std::cout << "----- Output 2 -----" << std::endl;
out = "./test_data/nets/retinaface_out_2.f32";
out = "./test_data/nets/test1_landms.f32";
res = runNet(session, in, out, "421", true);
ASSERT_EQ(res, 0);
if (res != 0) {
final_res |= res;
}

ASSERT_EQ(final_res, 0);

delete session;
delete context;
}
#endif
TEST_F(NetworkTest, mobileface_net) {
char *buf = nullptr;
size_t net_size = 0;


+ 4
- 3
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc View File

@@ -41,10 +41,11 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {

std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin";
auto ll_labels = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size));
auto labels = new int[6];
for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]);
auto labels = new float[6 * 4];
std::fill(labels, labels + 6 * 4, 0.f);
for (int i = 0; i < 6; i++) labels[i * 4 + ll_labels[i]] = 1.0;

std::vector<int> dim_l({6});
std::vector<int> dim_l({6, 4});
lite::Tensor l_tensor(TypeId::kNumberTypeFloat32, dim_l);
l_tensor.SetData(labels);
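The test now feeds the dense cross-entropy kernel one-hot labels: for six samples over four classes, class index ll_labels[i] turns row i of the 6x4 float buffer into a one-hot row (an index of 2 becomes {0, 0, 1, 0}), mirroring the label-format change made in network_test.cc above.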



BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms View File


BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effnetb0_fwd_fuse.ms View File


BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface1.ms View File


BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1.hwc_normalized_f32 View File


BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_conf.f32 View File


BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_landms.f32 View File


BIN
mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_loc.f32 View File


+ 20
- 1
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -274,9 +274,24 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s
auto input_cnode = utils::cast<CNodePtr>(input_anode);

if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
#ifndef SUPPORT_TRAIN
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
}
#else
bool found = false;
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
found = true;
}

if (found == false) {
auto input_index_key = input_name + "_o:" + std::to_string(0);
if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
}
}
#endif
} else {
auto inputs = input_cnode->inputs();
if (inputs.size() != 3) {
@@ -369,6 +384,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
auto typePtr = abstractTensor->element()->GetTypeTrack();
paramTensor->dataType = typePtr->type_id();
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
#ifdef SUPPORT_TRAIN
if (paramTensor->dims.size() == 0) paramTensor->dims = {1};
#endif
paramTensor->nodeType = schema::NodeType::NodeType_ValueNode;
auto data = value->cast<tensor::TensorPtr>();
paramTensor->data.resize(data->Size());
@@ -505,7 +523,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
node_id_map_[name] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D))
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
break;
#else
if (tuple->size() == 1) {


+ 33
- 4
mindspore/lite/tools/anf_importer/import_from_protobuf.cc View File

@@ -678,6 +678,13 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
} else {
#ifdef SUPPORT_TRAIN
auto ret_node = outputFuncGraph->get_return();
if (ret_node) {
ret_node->add_input(cnode_ptr);
return true;
}
#endif
const onnx::ValueInfoProto &output_node = importProto.output(0);
const onnx::TypeProto &output_typeproto = output_node.type();
int output_type = output_typeproto.tensor_type().elem_type();
@@ -687,7 +694,6 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);

inputs.clear();
auto primReturn = std::make_unique<schema::PrimitiveT>();
MS_ASSERT(primReturn != nullptr);
@@ -717,6 +723,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
}
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr;
CNodePtr last_cnode_ptr = nullptr;
int status = RET_OK;
NoSupportOp::GetInstance()->SetFmkType("MINDIR");
for (int i = 0; i < importProto.node_size(); ++i) {
@@ -734,13 +741,35 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
status = (status == RET_OK ? RET_NULL_PTR : status);
}

auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode_ptr->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
status = RET_ERROR;
continue;  // skip the train-mode handling below rather than dereference null
}

#ifdef SUPPORT_TRAIN
if (primitive_c->Type() == schema::PrimitiveType_MakeTuple) {
last_cnode_ptr = cnode_ptr;
if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
status = RET_ERROR;
}
}
#endif
}
if (status != RET_OK) {
return status;
}
if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
status = RET_ERROR;
#ifdef SUPPORT_TRAIN
if (last_cnode_ptr != cnode_ptr) {
#else
{
#endif
if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
status = RET_ERROR;
}
}
return status;
}


+ 2
- 0
mindspore/lite/tools/common/node_util.cc View File

@@ -28,12 +28,14 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
#ifdef SUPPORT_TRAIN
schema::PrimitiveType_Conv2DGradFilter,
schema::PrimitiveType_Conv2DGradInput,
schema::PrimitiveType_GroupConv2DGradInput,
schema::PrimitiveType_PoolingGrad,
schema::PrimitiveType_BiasGrad,
schema::PrimitiveType_BNGrad,
schema::PrimitiveType_ActivationGrad,
schema::PrimitiveType_ApplyMomentum,
schema::PrimitiveType_Sgd,
schema::PrimitiveType_Adam,
#endif
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DeConv2D,


+ 1
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc View File

@@ -161,6 +161,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
int idx = 0;
if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3;
if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1;
if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9;
iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";


+ 3
- 1
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc View File

@@ -183,7 +183,9 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
auto origin_begin = attr->begin;
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
auto origin_end = attr->axes;
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
if (origin_end.size() >= 4) {
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
}
auto origin_stride = attr->size;
attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
}


+ 10
- 0
mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc View File

@@ -122,6 +122,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
param_value->set_format(schema::Format::Format_CKHW);
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
#ifdef SUPPORT_TRAIN
} else if (op_type == schema::PrimitiveType_Conv2DGradInput) {
param_value->set_format(schema::Format::Format_KCHW);
} else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) {
param_value->set_format(schema::Format::Format_CKHW);
#endif
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
@@ -178,6 +184,10 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
auto conv_cnode = node->cast<CNodePtr>();
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
#ifdef SUPPORT_TRAIN
((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) &&
((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) &&
#endif
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}


+ 12
- 15
mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc View File

@@ -18,27 +18,21 @@
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_WeightQuant;

namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
void WeightFormatTransformPass::SetQuantType(QuantType type) {
this->quant_type = type;
}
void WeightFormatTransformPass::SetFmkType(FmkType type) {
this->fmk_type = type;
}
void WeightFormatTransformPass::SetDstFormat(schema::Format format) {
this->dst_format = format;
}
void WeightFormatTransformPass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatTransformPass::SetFmkType(FmkType type) { this->fmk_type = type; }
void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dst_format = format; }
lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
@@ -48,6 +42,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
}
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
#ifdef SUPPORT_TRAIN
&& type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput
#endif
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}
@@ -60,8 +57,8 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
MS_LOG(ERROR) << "weight node must param value";
return false;
}
MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32
|| weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 ||
weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
lite::STATUS status;
schema::Format weight_dst_format = schema::Format::Format_KHWC;
if (dst_format != schema::Format::Format_NUM_OF_FORMAT) {

