| @@ -117,6 +117,14 @@ int HSwish(const float *src, int length, float *dst) { | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int HSigmoid(const float *src, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6); | |||||
| dst[i] = relu6 / 6; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) { | int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) { | ||||
| if (max_val <= min_val) { | if (max_val <= min_val) { | ||||
| return NNACL_ERR; | return NNACL_ERR; | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_NNACL_ACTIVATION_H_ | |||||
| #define MINDSPORE_LITE_NNACL_ACTIVATION_H_ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_ | |||||
| #include <math.h> | #include <math.h> | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| @@ -36,9 +36,10 @@ int Fp32Relu6(const float *src, int length, float *dst); | |||||
| int LRelu(const float *src, int length, float *dst, float alpha); | int LRelu(const float *src, int length, float *dst, float alpha); | ||||
| int Sigmoid(const float *src, int length, float *dst); | int Sigmoid(const float *src, int length, float *dst); | ||||
| int Tanh(const float *src, int length, float *dst); | int Tanh(const float *src, int length, float *dst); | ||||
| int HSigmoid(const float *src, int length, float *dst); | |||||
| int HSwish(const float *src, int length, float *dst); | int HSwish(const float *src, int length, float *dst); | ||||
| int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); | int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_NNACL_ACTIVATION_H_ | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_ | |||||
| @@ -31,14 +31,14 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par | |||||
| int i, j, k; | int i, j, k; | ||||
| for (i = tid; i < outer_size; i += thread_num) { | for (i = tid; i < outer_size; i += thread_num) { | ||||
| float *output_ptr = output + i * depth * inner_size; | float *output_ptr = output + i * depth * inner_size; | ||||
| for (k = 0; k < depth; k++) { | |||||
| for (j = 0; j < inner_size; j++) { | |||||
| for (k = 0; k < inner_size; k++) { | |||||
| int index = indices[i * inner_size + k]; | |||||
| for (j = 0; j < depth; j++) { | |||||
| *output_ptr = off_value; | *output_ptr = off_value; | ||||
| int index = indices[i * inner_size + j]; | |||||
| if (index >= depth) { | if (index >= depth) { | ||||
| return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; | return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; | ||||
| } | } | ||||
| if (index == k) { | |||||
| if (index == j) { | |||||
| *output_ptr = on_value; | *output_ptr = on_value; | ||||
| } | } | ||||
| output_ptr++; | output_ptr++; | ||||
| @@ -15,27 +15,52 @@ | |||||
| */ | */ | ||||
| #include "nnacl/fp32_grad/gemm.h" | #include "nnacl/fp32_grad/gemm.h" | ||||
| #include <string.h> | |||||
| static void gemm_not_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||||
| float *mat_c, int ldc) { | |||||
| const int block_size = 4; | |||||
| int block_mod = N % block_size; | |||||
| int block_c4 = N - block_mod; | |||||
| static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| int i, j, k; | int i, j, k; | ||||
| for (i = 0; i < M; ++i) { | for (i = 0; i < M; ++i) { | ||||
| for (k = 0; k < K; ++k) { | for (k = 0; k < K; ++k) { | ||||
| float a = alpha * mat_a[i * lda + k]; | float a = alpha * mat_a[i * lda + k]; | ||||
| for (j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] += a * mat_B[k * ldb + j]; | |||||
| for (j = 0; j < block_c4; j += block_size) { | |||||
| float *b = &mat_b[k * ldb + j]; | |||||
| float *c = &mat_c[i * ldc + j]; | |||||
| c[0] += a * b[0]; | |||||
| c[1] += a * b[1]; | |||||
| c[2] += a * b[2]; | |||||
| c[3] += a * b[3]; | |||||
| } | |||||
| for (; j < N; ++j) { | |||||
| mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| static void gemm_not_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||||
| float *mat_c, int ldc) { | |||||
| const int block_size = 4; | |||||
| int block_mod = K % block_size; | |||||
| int block_c4 = K - block_mod; | |||||
| int i, j, k; | int i, j, k; | ||||
| for (i = 0; i < M; ++i) { | for (i = 0; i < M; ++i) { | ||||
| for (j = 0; j < N; ++j) { | for (j = 0; j < N; ++j) { | ||||
| float sum = 0; | float sum = 0; | ||||
| for (k = 0; k < K; ++k) { | |||||
| for (k = 0; k < block_c4; k += block_size) { | |||||
| float *a = &mat_a[i * lda + k]; | |||||
| float *b = &mat_b[j * ldb + k]; | |||||
| sum += alpha * a[0] * b[0]; | |||||
| sum += alpha * a[1] * b[1]; | |||||
| sum += alpha * a[2] * b[2]; | |||||
| sum += alpha * a[3] * b[3]; | |||||
| } | |||||
| for (; k < K; ++k) { | |||||
| sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k]; | sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k]; | ||||
| } | } | ||||
| mat_c[i * ldc + j] += sum; | mat_c[i * ldc + j] += sum; | ||||
| @@ -43,23 +68,85 @@ static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, flo | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| static void gemm_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||||
| float *mat_c, int ldc) { | |||||
| const int block_size = 4; | |||||
| int block_mod = N % block_size; | |||||
| int block_c4 = N - block_mod; | |||||
| int i, j, k; | int i, j, k; | ||||
| for (i = 0; i < M; ++i) { | for (i = 0; i < M; ++i) { | ||||
| for (k = 0; k < K; ++k) { | for (k = 0; k < K; ++k) { | ||||
| float a = alpha * mat_a[k * lda + i]; | float a = alpha * mat_a[k * lda + i]; | ||||
| for (j = 0; j < N; ++j) { | |||||
| for (j = 0; j < block_c4; j += block_size) { | |||||
| float *b = &mat_b[k * ldb + j]; | |||||
| float *c = &mat_c[i * ldc + j]; | |||||
| c[0] += a * b[0]; | |||||
| c[1] += a * b[1]; | |||||
| c[2] += a * b[2]; | |||||
| c[3] += a * b[3]; | |||||
| } | |||||
| for (; j < N; ++j) { | |||||
| mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; | mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||||
| int ldc) { | |||||
| static void gemm_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||||
| float *mat_c, int ldc) { | |||||
| int i, j, k; | int i, j, k; | ||||
| for (i = 0; i < M; ++i) { | |||||
| const int block_size = 4; | |||||
| int k_block_mod = K % block_size; | |||||
| int k_block_c4 = K - k_block_mod; | |||||
| int m_block_mod = M % block_size; | |||||
| int m_block_c4 = M - m_block_mod; | |||||
| for (i = 0; i < m_block_c4; i += block_size) { | |||||
| for (j = 0; j < N; ++j) { | |||||
| float sum0 = 0; | |||||
| float sum1 = 0; | |||||
| float sum2 = 0; | |||||
| float sum3 = 0; | |||||
| for (k = 0; k < k_block_c4; k += block_size) { | |||||
| float *b = &mat_b[j * ldb + k]; | |||||
| sum0 += alpha * mat_a[i + k * lda] * b[0]; | |||||
| sum0 += alpha * mat_a[i + (k + 1) * lda] * b[1]; | |||||
| sum0 += alpha * mat_a[i + (k + 2) * lda] * b[2]; | |||||
| sum0 += alpha * mat_a[i + (k + 3) * lda] * b[3]; | |||||
| sum1 += alpha * mat_a[i + 1 + k * lda] * b[0]; | |||||
| sum1 += alpha * mat_a[i + 1 + (k + 1) * lda] * b[1]; | |||||
| sum1 += alpha * mat_a[i + 1 + (k + 2) * lda] * b[2]; | |||||
| sum1 += alpha * mat_a[i + 1 + (k + 3) * lda] * b[3]; | |||||
| sum2 += alpha * mat_a[i + 2 + k * lda] * b[0]; | |||||
| sum2 += alpha * mat_a[i + 2 + (k + 1) * lda] * b[1]; | |||||
| sum2 += alpha * mat_a[i + 2 + (k + 2) * lda] * b[2]; | |||||
| sum2 += alpha * mat_a[i + 2 + (k + 3) * lda] * b[3]; | |||||
| sum3 += alpha * mat_a[i + 3 + k * lda] * b[0]; | |||||
| sum3 += alpha * mat_a[i + 3 + (k + 1) * lda] * b[1]; | |||||
| sum3 += alpha * mat_a[i + 3 + (k + 2) * lda] * b[2]; | |||||
| sum3 += alpha * mat_a[i + 3 + (k + 3) * lda] * b[3]; | |||||
| } | |||||
| for (; k < K; ++k) { | |||||
| float *b = &mat_b[j * ldb + k]; | |||||
| sum0 += alpha * mat_a[i + (k * lda)] * b[0]; | |||||
| sum1 += alpha * mat_a[i + 1 + (k * lda)] * b[0]; | |||||
| sum2 += alpha * mat_a[i + 2 + (k * lda)] * b[0]; | |||||
| sum3 += alpha * mat_a[i + 3 + (k * lda)] * b[0]; | |||||
| } | |||||
| mat_c[i * ldc + j] += sum0; | |||||
| mat_c[(i + 1) * ldc + j] += sum1; | |||||
| mat_c[(i + 2) * ldc + j] += sum2; | |||||
| mat_c[(i + 3) * ldc + j] += sum3; | |||||
| } | |||||
| } | |||||
| // no more block of 4x4 | |||||
| for (; i < M; ++i) { | |||||
| for (j = 0; j < N; ++j) { | for (j = 0; j < N; ++j) { | ||||
| float sum = 0; | float sum = 0; | ||||
| for (k = 0; k < K; ++k) { | for (k = 0; k < K; ++k) { | ||||
| @@ -74,34 +161,37 @@ static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, flo | |||||
| // M - number of rows of matrix a | // M - number of rows of matrix a | ||||
| // N - number of cols of matrix b | // N - number of cols of matrix b | ||||
| // K - number of cols of matrix a | // K - number of cols of matrix a | ||||
| // lda - fast dim of matrix a | |||||
| // ldb - fast dim of matrix b | |||||
| // ldc - fast dim of matrix c | |||||
| void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | ||||
| int ldb, float beta, float *mat_c, int ldc) { | int ldb, float beta, float *mat_c, int ldc) { | ||||
| if (beta >= 0.f && beta <= 0.f) { | if (beta >= 0.f && beta <= 0.f) { | ||||
| for (int i = 0; i < M; ++i) { | |||||
| for (int j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] = 0; | |||||
| } | |||||
| } | |||||
| memset(mat_c, 0, M * N * sizeof(float)); | |||||
| } else if (beta < 1.f || beta > 1.f) { | } else if (beta < 1.f || beta > 1.f) { | ||||
| for (int i = 0; i < M; ++i) { | |||||
| for (int j = 0; j < N; ++j) { | |||||
| mat_c[i * ldc + j] *= beta; | |||||
| } | |||||
| const int block_size = 4; | |||||
| const int size = M * N; | |||||
| int block_mod = size % block_size; | |||||
| int block_c4 = size - block_mod; | |||||
| int i; | |||||
| for (i = 0; i < block_c4; i += block_size) { | |||||
| float *c = &mat_c[i]; | |||||
| c[0] *= beta; | |||||
| c[1] *= beta; | |||||
| c[2] *= beta; | |||||
| c[3] *= beta; | |||||
| } | } | ||||
| } | |||||
| int t; | |||||
| for (t = 0; t < M; ++t) { | |||||
| if (!transpose_a && !transpose_b) { | |||||
| gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } else if (transpose_a && !transpose_b) { | |||||
| gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } else if (!transpose_a && transpose_b) { | |||||
| gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| } else { | |||||
| gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||||
| for (; i < size; ++i) { | |||||
| mat_c[i] *= beta; | |||||
| } | } | ||||
| } | } | ||||
| if (transpose_a && transpose_b) { | |||||
| gemm_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||||
| } else if (!transpose_a && !transpose_b) { | |||||
| gemm_not_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||||
| } else if (!transpose_a && transpose_b) { | |||||
| gemm_not_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||||
| } else { | |||||
| gemm_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||||
| } | |||||
| } | } | ||||
| @@ -21,7 +21,6 @@ | |||||
| typedef struct ApplyMomentumParameter { | typedef struct ApplyMomentumParameter { | ||||
| OpParameter op_parameter_; | OpParameter op_parameter_; | ||||
| bool use_locking_; | |||||
| bool use_nesterov_; | bool use_nesterov_; | ||||
| float grad_scale_; | float grad_scale_; | ||||
| } ApplyMomentumParameter; | } ApplyMomentumParameter; | ||||
| @@ -33,4 +32,9 @@ typedef struct SgdParameter { | |||||
| float weight_decay_; | float weight_decay_; | ||||
| } SgdParameter; | } SgdParameter; | ||||
| typedef struct AdamParameter { | |||||
| OpParameter op_parameter_; | |||||
| bool use_nesterov_; | |||||
| } AdamParameter; | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ | #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ | ||||
| @@ -182,7 +182,7 @@ union PrimitiveType { | |||||
| Conv2DGradInput, | Conv2DGradInput, | ||||
| PoolingGrad, | PoolingGrad, | ||||
| BNGrad, | BNGrad, | ||||
| BNGradInput, | |||||
| Assign, | |||||
| ApplyMomentum, | ApplyMomentum, | ||||
| BiasGrad, | BiasGrad, | ||||
| SoftmaxCrossEntropy, | SoftmaxCrossEntropy, | ||||
| @@ -217,6 +217,8 @@ union PrimitiveType { | |||||
| FftReal, | FftReal, | ||||
| FftImag, | FftImag, | ||||
| Sgd, | Sgd, | ||||
| Adam, | |||||
| GroupConv2DGradInput, | |||||
| } | } | ||||
| enum QuantType: int { | enum QuantType: int { | ||||
| @@ -224,7 +224,29 @@ table Conv2DGradInput { | |||||
| dilateH: int; | dilateH: int; | ||||
| hasBias: bool = false; | hasBias: bool = false; | ||||
| activationType: ActivationType = 0; | activationType: ActivationType = 0; | ||||
| }table FusedBatchNorm { | |||||
| } | |||||
| table GroupConv2DGradInput { | |||||
| format: Format = 0; | |||||
| group: int; | |||||
| channelIn: int; | |||||
| channelOut: int; | |||||
| kernelW: int; | |||||
| kernelH: int; | |||||
| strideW: int; | |||||
| strideH: int; | |||||
| padMode: PadMode; | |||||
| padUp: int; | |||||
| padDown: int; | |||||
| padLeft: int; | |||||
| padRight: int; | |||||
| dilateW: int; | |||||
| dilateH: int; | |||||
| hasBias: bool = false; | |||||
| activationType: ActivationType = 0; | |||||
| } | |||||
| table FusedBatchNorm { | |||||
| epsilon: float = 0.00001; // eg. epsilon=0.001 | epsilon: float = 0.00001; // eg. epsilon=0.001 | ||||
| momentum: float = 0.9; | momentum: float = 0.9; | ||||
| spatial: int = 1; | spatial: int = 1; | ||||
| @@ -901,7 +923,6 @@ table TupleGetItem { | |||||
| table ApplyMomentum { | table ApplyMomentum { | ||||
| gradientScale: float; | gradientScale: float; | ||||
| useLocking: bool; | |||||
| useNesterov: bool; | useNesterov: bool; | ||||
| } | } | ||||
| @@ -911,6 +932,14 @@ table Sgd { | |||||
| useNesterov: bool; | useNesterov: bool; | ||||
| } | } | ||||
| table Adam { | |||||
| useNesterov: bool; | |||||
| } | |||||
| table Assign { | |||||
| } | |||||
| table Where{ | table Where{ | ||||
| condition: [bool]; | condition: [bool]; | ||||
| } | } | ||||
| @@ -50,6 +50,10 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (prim.name() == "ReLU6") { | } else if (prim.name() == "ReLU6") { | ||||
| attr->type = schema::ActivationType_RELU6; | attr->type = schema::ActivationType_RELU6; | ||||
| } else if (prim.name() == "HSwish") { | |||||
| attr->type = schema::ActivationType_HSWISH; | |||||
| } else if (prim.name() == "HSigmoid") { | |||||
| attr->type = schema::ActivationType_HSIGMOID; | |||||
| } | } | ||||
| this->primitive_->value.value = attr.release(); | this->primitive_->value.value = attr.release(); | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| @@ -43,8 +43,12 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP | |||||
| attr->type = schema::ActivationType_RELU; | attr->type = schema::ActivationType_RELU; | ||||
| } else if (prim.name() == "SigmoidGrad") { | } else if (prim.name() == "SigmoidGrad") { | ||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (prim.name() == "Relu6Grad") { | |||||
| } else if (prim.name() == "ReLU6Grad") { | |||||
| attr->type = schema::ActivationType_RELU6; | attr->type = schema::ActivationType_RELU6; | ||||
| } else if (prim.name() == "HSigmoidGrad") { | |||||
| attr->type = schema::ActivationType_HSIGMOID; | |||||
| } else if (prim.name() == "HSwishGrad") { | |||||
| attr->type = schema::ActivationType_HSWISH; | |||||
| } | } | ||||
| attr->alpha = 0; // alpha; | attr->alpha = 0; // alpha; | ||||
| this->primitive_->value.value = attr.release(); | this->primitive_->value.value = attr.release(); | ||||
| @@ -0,0 +1,91 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/adam.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| bool Adam::GetUseNesterov() const { return this->primitive_->value.AsAdam()->useNesterov; } | |||||
| int Adam::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Adam; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Adam) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = std::make_unique<schema::AdamT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov")); | |||||
| this->primitive_->value.value = attr.release(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| bool Adam::GetUseNesterov() const { return this->primitive_->value_as_Adam()->useNesterov(); } | |||||
| int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_Adam(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_Adam return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateAdam(*fbb, attr->useNesterov()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adam, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||||
| if (10 != inputs.size()) { | |||||
| MS_LOG(ERROR) << "Adam should have at least 8 input tensors"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[2]->ElementsNum() || | |||||
| inputs[0]->ElementsNum() != inputs[9]->ElementsNum() || inputs[3]->ElementsNum() != 1 || | |||||
| inputs[4]->ElementsNum() != 1 || inputs[5]->ElementsNum() != 1 || inputs[6]->ElementsNum() != 1 || | |||||
| inputs[7]->ElementsNum() != 1 || inputs[8]->ElementsNum() != 1) { | |||||
| MS_LOG(ERROR) << "error input data size!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!outputs.empty()) { | |||||
| auto *out = outputs.front(); | |||||
| MS_ASSERT(out != nullptr); | |||||
| out->set_data_type(inputs[0]->data_type()); | |||||
| out->SetFormat(inputs[0]->GetFormat()); | |||||
| out->set_shape({1}); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_ADAM_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_ADAM_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class Adam : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(Adam, PrimitiveC); | |||||
| Adam() = default; | |||||
| explicit Adam(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| Adam() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| bool GetUseNesterov() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_ADAM_H_ | |||||
| @@ -82,8 +82,11 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs | |||||
| if (!GetInferFlag()) { | if (!GetInferFlag()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| output->set_shape(input->shape()); | |||||
| // make sure all elements have the same size or 1 (broadcasting) in all dimensions | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| if (inputs.at(i)->shape() != inputs.at(0)->shape()) { | |||||
| if (inputs.at(i)->shape().size() != inputs.at(0)->shape().size()) { | |||||
| MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | ||||
| return RET_INPUT_TENSOR_ERROR; | return RET_INPUT_TENSOR_ERROR; | ||||
| } | } | ||||
| @@ -93,7 +96,22 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs | |||||
| } | } | ||||
| } | } | ||||
| output->set_shape(input->shape()); | |||||
| for (size_t d = 0; d < input->shape().size(); ++d) { | |||||
| int max_dim = input->shape().at(d); | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| if (inputs.at(i)->shape().at(d) > max_dim) { | |||||
| max_dim = inputs.at(i)->shape().at(d); | |||||
| } | |||||
| } | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| if ((inputs.at(0)->shape().at(d) != max_dim) && (inputs.at(0)->shape().at(d) != 1)) { | |||||
| MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| } | |||||
| output->shape()[d] = max_dim; // set the biggest dimension in the output tensor | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -18,7 +18,6 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; } | float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; } | ||||
| bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; } | |||||
| bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; } | bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; } | ||||
| int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
| @@ -41,7 +40,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale")); | attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale")); | ||||
| attr->useLocking = GetValue<bool>(prim.GetAttr("use_locking")); | |||||
| attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov")); | attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov")); | ||||
| this->primitive_->value.value = attr.release(); | this->primitive_->value.value = attr.release(); | ||||
| @@ -54,7 +52,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt | |||||
| } | } | ||||
| #else | #else | ||||
| float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); } | float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); } | ||||
| bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); } | |||||
| bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); } | bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); } | ||||
| int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| @@ -65,7 +62,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||||
| MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr"; | MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov()); | |||||
| auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useNesterov()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o); | auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o); | ||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -40,7 +40,6 @@ class ApplyMomentum : public PrimitiveC { | |||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| float GetGradientScale() const; | float GetGradientScale() const; | ||||
| bool GetUseLocking() const; | |||||
| bool GetUseNesterov() const; | bool GetUseNesterov() const; | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -0,0 +1,82 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/assign.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int Assign::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Assign; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Assign) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| this->primitive_->value.value = new (std::nothrow) schema::AssignT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_Assign(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_Assign return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateAssign(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assign, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||||
| if (2 != inputs.size()) { | |||||
| MS_LOG(ERROR) << "Assign should have at least 5 input tensors"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum()) { | |||||
| MS_LOG(ERROR) << "error input data size!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!outputs.empty()) { | |||||
| auto *out = outputs.front(); | |||||
| MS_ASSERT(out != nullptr); | |||||
| out->set_data_type(inputs[0]->data_type()); | |||||
| out->SetFormat(inputs[0]->GetFormat()); | |||||
| out->set_shape({1}); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// Assign op: writes the value tensor (input 1) into the variable tensor
// (input 0); the output, if any, is a placeholder scalar.
class Assign : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Assign, PrimitiveC);
  Assign() = default;
  explicit Assign(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Builds/validates the writeable PrimitiveT from an ANF primitive.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  Assign() = default;
  // Serializes this op into the flatbuffer builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // Checks input arity/sizes and shapes the optional scalar output.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ | |||||
| @@ -45,8 +45,8 @@ int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||||
| } | } | ||||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | ||||
| // FusedBatchNormGrad dows not get this attribute | // FusedBatchNormGrad dows not get this attribute | ||||
| if (prim.GetAttr("eps") != nullptr) { | |||||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||||
| if (prim.GetAttr("epsilon") != nullptr) { | |||||
| attr->eps = GetValue<float>(prim.GetAttr("epsilon")); | |||||
| } | } | ||||
| this->primitive_->value.value = attr; | this->primitive_->value.value = attr; | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "src/ops/conv2d_grad_input.h" | #include "src/ops/conv2d_grad_input.h" | ||||
| #include "src/ops/group_conv2d_grad_input.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| @@ -86,6 +86,9 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->group = GetValue<int>(prim.GetAttr("group")); | attr->group = GetValue<int>(prim.GetAttr("group")); | ||||
| if (attr->group > 1) { | |||||
| this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; | |||||
| } | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | auto format = GetValue<std::string>(prim.GetAttr("data_format")); | ||||
| if (format == "NCHW") { | if (format == "NCHW") { | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| @@ -26,6 +26,30 @@ void Exp::SetShift(float shift) { this->primitive_->value.AsExp()->shift = shift | |||||
| float Exp::GetBase() const { return this->primitive_->value.AsExp()->base; } | float Exp::GetBase() const { return this->primitive_->value.AsExp()->base; } | ||||
| float Exp::GetScale() const { return this->primitive_->value.AsExp()->scale; } | float Exp::GetScale() const { return this->primitive_->value.AsExp()->scale; } | ||||
| float Exp::GetShift() const { return this->primitive_->value.AsExp()->shift; } | float Exp::GetShift() const { return this->primitive_->value.AsExp()->shift; } | ||||
| int Exp::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Exp; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Exp) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| this->primitive_->value.value = new (std::nothrow) schema::ExpT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| @@ -33,6 +33,7 @@ class Exp : public PrimitiveC { | |||||
| void SetBase(float base); | void SetBase(float base); | ||||
| void SetShift(float shift); | void SetShift(float shift); | ||||
| void SetScale(float scale); | void SetScale(float scale); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Exp() = default; | Exp() = default; | ||||
| @@ -0,0 +1,172 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/group_conv2d_grad_input.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
// Accessors for GroupConv2DGradInput attributes stored in the writeable
// schema union (PRIMITIVE_WRITEABLE build). Getters read straight from the
// GroupConv2DGradInputT struct; setters write it, casting ints to the
// corresponding schema enums where needed.
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value.AsGroupConv2DGradInput()->format; }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value.AsGroupConv2DGradInput()->group; }
int GroupConv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelIn; }
int GroupConv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelOut; }
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelW; }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelH; }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideW; }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideH; }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value.AsGroupConv2DGradInput()->padMode; }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value.AsGroupConv2DGradInput()->padUp; }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value.AsGroupConv2DGradInput()->padDown; }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsGroupConv2DGradInput()->padLeft; }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value.AsGroupConv2DGradInput()->padRight; }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateW; }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateH; }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value.AsGroupConv2DGradInput()->hasBias; }
int GroupConv2DGradInput::GetActivationType() const {
  return this->primitive_->value.AsGroupConv2DGradInput()->activationType;
}
// Setters: plain ints are stored as-is; format/padMode/activationType are
// narrowed to their schema enum types.
void GroupConv2DGradInput::SetFormat(int format) {
  this->primitive_->value.AsGroupConv2DGradInput()->format = (schema::Format)format;
}
void GroupConv2DGradInput::SetGroup(int group) { this->primitive_->value.AsGroupConv2DGradInput()->group = group; }
void GroupConv2DGradInput::SetChannelIn(int channel_in) {
  this->primitive_->value.AsGroupConv2DGradInput()->channelIn = channel_in;
}
void GroupConv2DGradInput::SetChannelOut(int channel_out) {
  this->primitive_->value.AsGroupConv2DGradInput()->channelOut = channel_out;
}
void GroupConv2DGradInput::SetKernelW(int kernel_w) {
  this->primitive_->value.AsGroupConv2DGradInput()->kernelW = kernel_w;
}
void GroupConv2DGradInput::SetKernelH(int kernel_h) {
  this->primitive_->value.AsGroupConv2DGradInput()->kernelH = kernel_h;
}
void GroupConv2DGradInput::SetStrideW(int stride_w) {
  this->primitive_->value.AsGroupConv2DGradInput()->strideW = stride_w;
}
void GroupConv2DGradInput::SetStrideH(int stride_h) {
  this->primitive_->value.AsGroupConv2DGradInput()->strideH = stride_h;
}
void GroupConv2DGradInput::SetPadMode(int pad_mode) {
  this->primitive_->value.AsGroupConv2DGradInput()->padMode = (schema::PadMode)pad_mode;
}
void GroupConv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsGroupConv2DGradInput()->padUp = pad_up; }
void GroupConv2DGradInput::SetPadDown(int pad_down) {
  this->primitive_->value.AsGroupConv2DGradInput()->padDown = pad_down;
}
void GroupConv2DGradInput::SetPadLeft(int pad_left) {
  this->primitive_->value.AsGroupConv2DGradInput()->padLeft = pad_left;
}
void GroupConv2DGradInput::SetPadRight(int pad_right) {
  this->primitive_->value.AsGroupConv2DGradInput()->padRight = pad_right;
}
void GroupConv2DGradInput::SetDilateW(int dilate_w) {
  this->primitive_->value.AsGroupConv2DGradInput()->dilateW = dilate_w;
}
void GroupConv2DGradInput::SetDilateH(int dilate_h) {
  this->primitive_->value.AsGroupConv2DGradInput()->dilateH = dilate_h;
}
void GroupConv2DGradInput::SetHasBias(bool has_bias) {
  this->primitive_->value.AsGroupConv2DGradInput()->hasBias = has_bias;
}
void GroupConv2DGradInput::SetActivationType(int activation_type) {
  this->primitive_->value.AsGroupConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
| #else | |||||
| int GroupConv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_GroupConv2DGradInput(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_GroupConv2DGradInput return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateGroupConv2DGradInput( | |||||
| *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), | |||||
| attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), | |||||
| attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GroupConv2DGradInput, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
// Read-only accessors for the non-writeable (flatbuffer) build: each getter
// forwards to the generated flatbuffer table's field accessor.
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value_as_GroupConv2DGradInput()->format(); }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value_as_GroupConv2DGradInput()->group(); }
int GroupConv2DGradInput::GetChannelIn() const {
  return this->primitive_->value_as_GroupConv2DGradInput()->channelIn();
}
int GroupConv2DGradInput::GetChannelOut() const {
  return this->primitive_->value_as_GroupConv2DGradInput()->channelOut();
}
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelW(); }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelH(); }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideW(); }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideH(); }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value_as_GroupConv2DGradInput()->padMode(); }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value_as_GroupConv2DGradInput()->padUp(); }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value_as_GroupConv2DGradInput()->padDown(); }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_GroupConv2DGradInput()->padLeft(); }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value_as_GroupConv2DGradInput()->padRight(); }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateW(); }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateH(); }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value_as_GroupConv2DGradInput()->hasBias(); }
int GroupConv2DGradInput::GetActivationType() const {
  return this->primitive_->value_as_GroupConv2DGradInput()->activationType();
}
| #endif | |||||
| int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | |||||
| if (3 != inputs.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (1 != outputs.size()) { | |||||
| MS_LOG(ERROR) << "Conv2d Grad input should have one output"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto *in0 = inputs.at(0); | |||||
| auto *in = inputs.at(2); | |||||
| MS_ASSERT(out != nullptr); | |||||
| std::vector<int> output_shape; | |||||
| int *out_shape = reinterpret_cast<int *>(in->MutableData()); | |||||
| int new_size = in->ElementsNum(); | |||||
| if (in0->GetFormat() == in->GetFormat()) { | |||||
| for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]); | |||||
| } else { | |||||
| if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) { | |||||
| output_shape.push_back(out_shape[0]); | |||||
| output_shape.push_back(out_shape[2]); | |||||
| output_shape.push_back(out_shape[3]); | |||||
| output_shape.push_back(out_shape[1]); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Shape covnert is not supported"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| auto *out = outputs.at(0); | |||||
| MS_ASSERT(out != nullptr); | |||||
| out->set_shape(output_shape); | |||||
| out->set_data_type(in0->data_type()); | |||||
| out->SetFormat(in0->GetFormat()); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,79 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// GroupConv2DGradInput op: gradient of a grouped 2-D convolution with respect
// to its input. Attribute accessors mirror the schema fields; the writeable
// build additionally exposes setters.
class GroupConv2DGradInput : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(GroupConv2DGradInput, PrimitiveC);
  GroupConv2DGradInput() = default;
  explicit GroupConv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Attribute setters (writeable build only); ints are narrowed to the
  // matching schema enums where applicable.
  void SetFormat(int format);
  void SetGroup(int group);
  void SetChannelIn(int channel_in);
  void SetChannelOut(int channel_out);
  void SetKernelW(int kernel_w);
  void SetKernelH(int kernel_h);
  void SetStrideW(int stride_w);
  void SetStrideH(int stride_h);
  void SetPadMode(int pad_mode);
  void SetPadUp(int pad_up);
  void SetPadDown(int pad_down);
  void SetPadLeft(int pad_left);
  void SetPadRight(int pad_right);
  void SetDilateW(int dilate_w);
  void SetDilateH(int dilate_h);
  void SetHasBias(bool has_bias);
  void SetActivationType(int activation_type);
#else
  GroupConv2DGradInput() = default;
  // Serializes this op into the flatbuffer builder (non-writeable build).
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // Computes the output (dx) shape from the shape tensor in inputs_[2].
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  // Attribute getters, available in both builds.
  int GetFormat() const;
  int GetGroup() const;
  int GetChannelIn() const;
  int GetChannelOut() const;
  int GetKernelW() const;
  int GetKernelH() const;
  int GetStrideW() const;
  int GetStrideH() const;
  int GetPadMode() const;
  int GetPadUp() const;
  int GetPadDown() const;
  int GetPadLeft() const;
  int GetPadRight() const;
  int GetDilateW() const;
  int GetDilateH() const;
  bool GetHasBias() const;
  int GetActivationType() const;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ | |||||
| @@ -18,7 +18,31 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int Neg::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Neg; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Neg) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| this->primitive_->value.value = new (std::nothrow) schema::NegT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| MS_ASSERT(fbb != nullptr); | MS_ASSERT(fbb != nullptr); | ||||
| @@ -31,6 +31,7 @@ class Neg : public ArithmeticSelf { | |||||
| MS_DECLARE_PARENT(Neg, ArithmeticSelf); | MS_DECLARE_PARENT(Neg, ArithmeticSelf); | ||||
| Neg() = default; | Neg() = default; | ||||
| explicit Neg(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} | explicit Neg(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Neg() = default; | Neg() = default; | ||||
| @@ -23,6 +23,37 @@ int OneHot::GetAxis() const { return this->primitive_->value.AsOneHot()->axis; } | |||||
| void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; } | void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; } | ||||
| int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_OneHot; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_OneHot) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::OneHotT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->axis = -1; | |||||
| if (prim.GetAttr("axis") != nullptr) { | |||||
| attr->axis = GetValue<int>(prim.GetAttr("axis")); | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); } | int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); } | ||||
| @@ -32,7 +32,7 @@ class OneHot : public PrimitiveC { | |||||
| OneHot() = default; | OneHot() = default; | ||||
| explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetAxis(int axis); | void SetAxis(int axis); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| OneHot() = default; | OneHot() = default; | ||||
| @@ -144,6 +144,7 @@ | |||||
| #include "src/ops/pooling_grad.h" | #include "src/ops/pooling_grad.h" | ||||
| #include "src/ops/conv2d_grad_filter.h" | #include "src/ops/conv2d_grad_filter.h" | ||||
| #include "src/ops/conv2d_grad_input.h" | #include "src/ops/conv2d_grad_input.h" | ||||
| #include "src/ops/group_conv2d_grad_input.h" | |||||
| #include "src/ops/power_grad.h" | #include "src/ops/power_grad.h" | ||||
| #include "src/ops/softmax_cross_entropy.h" | #include "src/ops/softmax_cross_entropy.h" | ||||
| #include "src/ops/bn_grad.h" | #include "src/ops/bn_grad.h" | ||||
| @@ -152,6 +153,8 @@ | |||||
| #include "src/ops/flatten_grad.h" | #include "src/ops/flatten_grad.h" | ||||
| #include "src/ops/log_grad.h" | #include "src/ops/log_grad.h" | ||||
| #include "src/ops/sgd.h" | #include "src/ops/sgd.h" | ||||
| #include "src/ops/adam.h" | |||||
| #include "src/ops/assign.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -367,7 +370,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect | |||||
| std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | ||||
| const schema::QuantType &quantType) { | const schema::QuantType &quantType) { | ||||
| const auto &op_type = prim.name(); | const auto &op_type = prim.name(); | ||||
| if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") { | |||||
| if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") { | |||||
| return NewPrimitiveC<Activation>(prim, inputs, quantType); | return NewPrimitiveC<Activation>(prim, inputs, quantType); | ||||
| } else if (op_type == "AddN") { | } else if (op_type == "AddN") { | ||||
| return NewPrimitiveC<AddN>(prim, inputs, quantType); | return NewPrimitiveC<AddN>(prim, inputs, quantType); | ||||
| @@ -413,6 +416,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | return NewPrimitiveC<Reduce>(prim, inputs, quantType); | ||||
| } else if (op_type == "Reshape") { | } else if (op_type == "Reshape") { | ||||
| return NewPrimitiveC<Reshape>(prim, inputs, quantType); | return NewPrimitiveC<Reshape>(prim, inputs, quantType); | ||||
| } else if (op_type == "Slice") { | |||||
| return NewPrimitiveC<Slice>(prim, inputs, quantType); | |||||
| } else if (op_type == "Squeeze") { | |||||
| return NewPrimitiveC<Squeeze>(prim, inputs, quantType); | |||||
| } else if (op_type == "TensorAdd") { | } else if (op_type == "TensorAdd") { | ||||
| return NewPrimitiveC<Add>(prim, inputs, quantType); | return NewPrimitiveC<Add>(prim, inputs, quantType); | ||||
| } else if (op_type == "Transpose") { | } else if (op_type == "Transpose") { | ||||
| @@ -421,6 +428,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Elu>(prim, inputs, quantType); | return NewPrimitiveC<Elu>(prim, inputs, quantType); | ||||
| } else if (op_type == "Log") { | } else if (op_type == "Log") { | ||||
| return NewPrimitiveC<Log>(prim, inputs, quantType); | return NewPrimitiveC<Log>(prim, inputs, quantType); | ||||
| } else if (op_type == "Exp") { | |||||
| return NewPrimitiveC<Exp>(prim, inputs, quantType); | |||||
| } else if (op_type == "Neg") { | |||||
| return NewPrimitiveC<Neg>(prim, inputs, quantType); | |||||
| } else if (op_type == "DeConv2D") { | } else if (op_type == "DeConv2D") { | ||||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | ||||
| } else if (op_type == "tuple_getitem") { | } else if (op_type == "tuple_getitem") { | ||||
| @@ -435,6 +446,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<Maximum>(prim, inputs, quantType); | return NewPrimitiveC<Maximum>(prim, inputs, quantType); | ||||
| } else if (op_type == "Split") { | } else if (op_type == "Split") { | ||||
| return NewPrimitiveC<Split>(prim, inputs, quantType); | return NewPrimitiveC<Split>(prim, inputs, quantType); | ||||
| } else if (op_type == "OneHot") { | |||||
| return NewPrimitiveC<OneHot>(prim, inputs, quantType); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | ||||
| @@ -445,7 +458,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | ||||
| } else if (op_type == "Depend") { | } else if (op_type == "Depend") { | ||||
| return NewPrimitiveC<Depend>(prim, inputs, quantType); | return NewPrimitiveC<Depend>(prim, inputs, quantType); | ||||
| } else if ((op_type == "ReluGrad" || op_type == "Relu6Grad" || op_type == "SigmoidGrad")) { | |||||
| } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | |||||
| op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | |||||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | ||||
| } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) { | } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) { | ||||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | ||||
| @@ -465,6 +479,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "SGD") { | } else if (op_type == "SGD") { | ||||
| return NewPrimitiveC<Sgd>(prim, inputs, quantType); | return NewPrimitiveC<Sgd>(prim, inputs, quantType); | ||||
| } else if (op_type == "Adam") { | |||||
| return NewPrimitiveC<Adam>(prim, inputs, quantType); | |||||
| } else if (op_type == "Assign") { | |||||
| return NewPrimitiveC<Assign>(prim, inputs, quantType); | |||||
| #else | #else | ||||
| } else if (op_type == "Conv2DBackpropInput") { | } else if (op_type == "Conv2DBackpropInput") { | ||||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | ||||
| @@ -686,6 +704,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new Dropout(primitive); | return new Dropout(primitive); | ||||
| case schema::PrimitiveType_Neg: | case schema::PrimitiveType_Neg: | ||||
| return new Neg(primitive); | return new Neg(primitive); | ||||
| case schema::PrimitiveType_RealDiv: | |||||
| return new RealDiv(primitive); | |||||
| case schema::PrimitiveType_LshProjection: | case schema::PrimitiveType_LshProjection: | ||||
| return new LshProjection(primitive); | return new LshProjection(primitive); | ||||
| case schema::PrimitiveType_HashtableLookup: | case schema::PrimitiveType_HashtableLookup: | ||||
| @@ -710,6 +730,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new Conv2DGradFilter(primitive); | return new Conv2DGradFilter(primitive); | ||||
| case schema::PrimitiveType_Conv2DGradInput: | case schema::PrimitiveType_Conv2DGradInput: | ||||
| return new Conv2DGradInput(primitive); | return new Conv2DGradInput(primitive); | ||||
| case schema::PrimitiveType_GroupConv2DGradInput: | |||||
| return new GroupConv2DGradInput(primitive); | |||||
| case schema::PrimitiveType_BiasGrad: | case schema::PrimitiveType_BiasGrad: | ||||
| return new BiasGrad(primitive); | return new BiasGrad(primitive); | ||||
| case schema::PrimitiveType_ApplyMomentum: | case schema::PrimitiveType_ApplyMomentum: | ||||
| @@ -738,8 +760,11 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new LogGrad(primitive); | return new LogGrad(primitive); | ||||
| case schema::PrimitiveType_Sgd: | case schema::PrimitiveType_Sgd: | ||||
| return new Sgd(primitive); | return new Sgd(primitive); | ||||
| case schema::PrimitiveType_Adam: | |||||
| return new Adam(primitive); | |||||
| case schema::PrimitiveType_Assign: | |||||
| return new Assign(primitive); | |||||
| #endif | #endif | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | ||||
| break; | break; | ||||
| @@ -958,6 +983,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||||
| return NewPrimitiveC<DetectionPostProcess>(primitive); | return NewPrimitiveC<DetectionPostProcess>(primitive); | ||||
| case schema::PrimitiveType_Dropout: | case schema::PrimitiveType_Dropout: | ||||
| return NewPrimitiveC<Dropout>(primitive); | return NewPrimitiveC<Dropout>(primitive); | ||||
| case schema::PrimitiveType_RealDiv: | |||||
| return NewPrimitiveC<RealDiv>(primitive); | |||||
| case schema::PrimitiveType_LshProjection: | case schema::PrimitiveType_LshProjection: | ||||
| return NewPrimitiveC<LshProjection>(primitive); | return NewPrimitiveC<LshProjection>(primitive); | ||||
| case schema::PrimitiveType_HashtableLookup: | case schema::PrimitiveType_HashtableLookup: | ||||
| @@ -982,6 +1009,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||||
| return NewPrimitiveC<Conv2DGradFilter>(primitive); | return NewPrimitiveC<Conv2DGradFilter>(primitive); | ||||
| case schema::PrimitiveType_Conv2DGradInput: | case schema::PrimitiveType_Conv2DGradInput: | ||||
| return NewPrimitiveC<Conv2DGradInput>(primitive); | return NewPrimitiveC<Conv2DGradInput>(primitive); | ||||
| case schema::PrimitiveType_GroupConv2DGradInput: | |||||
| return NewPrimitiveC<GroupConv2DGradInput>(primitive); | |||||
| case schema::PrimitiveType_BiasGrad: | case schema::PrimitiveType_BiasGrad: | ||||
| return NewPrimitiveC<BiasGrad>(primitive); | return NewPrimitiveC<BiasGrad>(primitive); | ||||
| case schema::PrimitiveType_ApplyMomentum: | case schema::PrimitiveType_ApplyMomentum: | ||||
| @@ -1004,6 +1033,10 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||||
| return NewPrimitiveC<LogGrad>(primitive); | return NewPrimitiveC<LogGrad>(primitive); | ||||
| case schema::PrimitiveType_Sgd: | case schema::PrimitiveType_Sgd: | ||||
| return NewPrimitiveC<Sgd>(primitive); | return NewPrimitiveC<Sgd>(primitive); | ||||
| case schema::PrimitiveType_Adam: | |||||
| return NewPrimitiveC<Adam>(primitive); | |||||
| case schema::PrimitiveType_Assign: | |||||
| return NewPrimitiveC<Assign>(primitive); | |||||
| #endif | #endif | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||||
| #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <set> | #include <set> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -48,7 +48,9 @@ constexpr int kAnfPopulaterTwo = 2; | |||||
| constexpr int kAnfPopulaterThree = 3; | constexpr int kAnfPopulaterThree = 3; | ||||
| static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU}, | static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU}, | ||||
| {"ReLU6", schema::ActivationType_RELU6}, | {"ReLU6", schema::ActivationType_RELU6}, | ||||
| {"Sigmoid", schema::ActivationType_SIGMOID}}; | |||||
| {"Sigmoid", schema::ActivationType_SIGMOID}, | |||||
| {"HSwish", schema::ActivationType_HSWISH}, | |||||
| {"HSigmoid", schema::ActivationType_HSIGMOID}}; | |||||
| class PrimitiveC : public mindspore::Primitive { | class PrimitiveC : public mindspore::Primitive { | ||||
| public: | public: | ||||
| // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | ||||
| @@ -213,4 +215,4 @@ class PrimitiveC { | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ | |||||
| @@ -44,7 +44,14 @@ int RealDiv::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||||
| } | } | ||||
| #else | #else | ||||
| int RealDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto val_offset = schema::CreateRank(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RealDiv, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | #endif | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,10 +36,7 @@ class RealDiv : public Arithmetic { | |||||
| #else | #else | ||||
| RealDiv() = default; | RealDiv() = default; | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { | |||||
| return RET_ERROR; | |||||
| } | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -35,6 +35,66 @@ void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = | |||||
| void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; } | void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; } | ||||
| void Slice::SetSize(const std::vector<int> &size) { this->primitive_->value.AsSlice()->size = size; } | void Slice::SetSize(const std::vector<int> &size) { this->primitive_->value.AsSlice()->size = size; } | ||||
| int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Slice; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Slice) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::SliceT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs.size() >= kAnfPopulaterThree) { | |||||
| auto beginNode = inputs[kAnfPopulaterOne]; | |||||
| MS_ASSERT(beginNode != nullptr); | |||||
| if (beginNode->isa<ValueNode>()) { | |||||
| auto valueNode = beginNode->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(valueNode != nullptr); | |||||
| auto value = valueNode->value(); | |||||
| MS_ASSERT(value != nullptr); | |||||
| if (value->isa<ValueTuple>()) { | |||||
| auto valTuplPtr = dyn_cast<ValueTuple>(value); | |||||
| MS_ASSERT(valTuplPtr != nullptr); | |||||
| for (size_t i = 0; i < valTuplPtr->size(); i++) { | |||||
| auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]); | |||||
| MS_ASSERT(elem != nullptr); | |||||
| attr->begin.emplace_back(elem->value()); | |||||
| } | |||||
| } | |||||
| } | |||||
| auto sizeNode = inputs[kAnfPopulaterTwo]; | |||||
| MS_ASSERT(sizeNode != nullptr); | |||||
| if (sizeNode->isa<ValueNode>()) { | |||||
| auto valueNode = sizeNode->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(valueNode != nullptr); | |||||
| auto value = valueNode->value(); | |||||
| MS_ASSERT(value != nullptr); | |||||
| if (value->isa<ValueTuple>()) { | |||||
| auto valTuplPtr = dyn_cast<ValueTuple>(value); | |||||
| MS_ASSERT(valTuplPtr != nullptr); | |||||
| for (size_t i = 0; i < valTuplPtr->size(); i++) { | |||||
| auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]); | |||||
| MS_ASSERT(elem != nullptr); | |||||
| attr->size.emplace_back(elem->value()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Slice::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } | int Slice::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } | ||||
| @@ -46,10 +106,12 @@ std::vector<int> Slice::GetSize() const { | |||||
| auto fb_vector = this->primitive_->value_as_Slice()->size(); | auto fb_vector = this->primitive_->value_as_Slice()->size(); | ||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| std::vector<int> Slice::GetAxes() const { | std::vector<int> Slice::GetAxes() const { | ||||
| auto fb_vector = this->primitive_->value_as_Slice()->axes(); | auto fb_vector = this->primitive_->value_as_Slice()->axes(); | ||||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | return std::vector<int>(fb_vector->begin(), fb_vector->end()); | ||||
| } | } | ||||
| int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| MS_ASSERT(nullptr != fbb); | MS_ASSERT(nullptr != fbb); | ||||
| @@ -90,7 +152,7 @@ std::vector<int> Slice::GetPostProcessBegin() const { return this->begin; } | |||||
| std::vector<int> Slice::GetPostProcessSize() const { return this->size; } | std::vector<int> Slice::GetPostProcessSize() const { return this->size; } | ||||
| int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | ||||
| MS_ASSERT(this->primitive_ != nullptr); | MS_ASSERT(this->primitive_ != nullptr); | ||||
| if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { | |||||
| if (inputs.size() < kSliceInputNum || outputs.size() != kSliceOutputNum) { | |||||
| MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); | MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); | ||||
| return RET_PARAM_INVALID; | return RET_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_SLICE_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_SLICE_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| @@ -35,6 +35,7 @@ class Slice : public PrimitiveC { | |||||
| void SetFormat(int format); | void SetFormat(int format); | ||||
| void SetBegin(const std::vector<int> &begin); | void SetBegin(const std::vector<int> &begin); | ||||
| void SetSize(const std::vector<int> &size); | void SetSize(const std::vector<int> &size); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Slice() = default; | Slice() = default; | ||||
| @@ -56,4 +57,4 @@ class Slice : public PrimitiveC { | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_SLICE_H_ | |||||
| @@ -23,6 +23,35 @@ std::vector<int> Squeeze::GetAxis() const { return this->primitive_->value.AsSqu | |||||
| void Squeeze::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsSqueeze()->axis = axis; } | void Squeeze::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsSqueeze()->axis = axis; } | ||||
| int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Squeeze; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Squeeze) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::SqueezeT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| std::vector<int> Squeeze::GetAxis() const { | std::vector<int> Squeeze::GetAxis() const { | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| @@ -33,6 +33,7 @@ class Squeeze : public PrimitiveC { | |||||
| Squeeze() = default; | Squeeze() = default; | ||||
| explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetAxis(const std::vector<int> &axis); | void SetAxis(const std::vector<int> &axis); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Squeeze() = default; | Squeeze() = default; | ||||
| @@ -45,4 +46,4 @@ class Squeeze : public PrimitiveC { | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ | |||||
| @@ -24,7 +24,7 @@ std::vector<int> Tile::GetMultiples() const { return this->primitive_->value.AsT | |||||
| void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive_->value.AsTile()->multiples = multiples; } | void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive_->value.AsTile()->multiples = multiples; } | ||||
| std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile()->multiples; } | |||||
| std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile()->dims; } | |||||
| void Tile::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsTile()->dims = dims; } | void Tile::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsTile()->dims = dims; } | ||||
| @@ -42,11 +42,32 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| this->primitive_->value.value = new (std::nothrow) schema::TileT(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::TileT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (inputs.size() == kAnfPopulaterTwo) { | |||||
| auto inputNode = inputs[kAnfPopulaterOne]; | |||||
| MS_ASSERT(inputNode != nullptr); | |||||
| if (inputNode->isa<ValueNode>()) { | |||||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(valueNode != nullptr); | |||||
| auto value = valueNode->value(); | |||||
| MS_ASSERT(value != nullptr); | |||||
| if (value->isa<ValueTuple>()) { | |||||
| auto valTuplPtr = dyn_cast<ValueTuple>(value); | |||||
| MS_ASSERT(valTuplPtr != nullptr); | |||||
| for (size_t i = 0; i < valTuplPtr->size(); i++) { | |||||
| auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]); | |||||
| MS_ASSERT(elem != nullptr); | |||||
| attr->multiples.emplace_back(elem->value()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -103,15 +124,19 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||||
| MS_ASSERT(tile_prim != nullptr); | MS_ASSERT(tile_prim != nullptr); | ||||
| std::vector<int> out_shape; | std::vector<int> out_shape; | ||||
| std::vector<int> multiples; | |||||
| for (size_t i = 0; i < GetMultiples().size(); ++i) { | |||||
| multiples.push_back(GetMultiples()[i]); | |||||
| std::vector<int> multiples = GetMultiples(); | |||||
| const size_t in_dims = input->shape().size(); | |||||
| const size_t delta_dims = in_dims - multiples.size(); | |||||
| size_t i = 0; | |||||
| for (; i < delta_dims; ++i) { | |||||
| int tmp = input->shape()[i]; | |||||
| out_shape.push_back(tmp); | |||||
| } | } | ||||
| for (size_t i = 0; i < input->shape().size(); ++i) { | |||||
| int tmp = input->shape()[i] * multiples[i]; | |||||
| for (; i < in_dims; ++i) { | |||||
| int tmp = input->shape()[i] * (multiples[i - delta_dims]); | |||||
| out_shape.push_back(tmp); | out_shape.push_back(tmp); | ||||
| } | } | ||||
| output->set_shape(out_shape); | output->set_shape(out_shape); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -43,6 +43,7 @@ | |||||
| #include "src/ops/add.h" | #include "src/ops/add.h" | ||||
| #include "src/ops/sub.h" | #include "src/ops/sub.h" | ||||
| #include "src/ops/div.h" | #include "src/ops/div.h" | ||||
| #include "src/ops/real_div.h" | |||||
| #include "src/ops/bias_add.h" | #include "src/ops/bias_add.h" | ||||
| #include "src/ops/expand_dims.h" | #include "src/ops/expand_dims.h" | ||||
| #include "src/ops/full_connection.h" | #include "src/ops/full_connection.h" | ||||
| @@ -1680,6 +1681,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||||
| populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic; | populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic; | populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic; | populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_RealDiv] = PopulateArithmetic; | |||||
| populate_parameter_funcs_[schema::PrimitiveType_LogicalAnd] = PopulateArithmetic; | populate_parameter_funcs_[schema::PrimitiveType_LogicalAnd] = PopulateArithmetic; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_LogicalOr] = PopulateArithmetic; | populate_parameter_funcs_[schema::PrimitiveType_LogicalOr] = PopulateArithmetic; | ||||
| populate_parameter_funcs_[schema::PrimitiveType_Equal] = PopulateArithmetic; | populate_parameter_funcs_[schema::PrimitiveType_Equal] = PopulateArithmetic; | ||||
| @@ -24,6 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::ActivationType_HSIGMOID; | |||||
| using mindspore::schema::ActivationType_HSWISH; | using mindspore::schema::ActivationType_HSWISH; | ||||
| using mindspore::schema::ActivationType_LEAKY_RELU; | using mindspore::schema::ActivationType_LEAKY_RELU; | ||||
| using mindspore::schema::ActivationType_RELU; | using mindspore::schema::ActivationType_RELU; | ||||
| @@ -57,6 +58,8 @@ int ActivationCPUKernel::DoActivation(int task_id) { | |||||
| error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); | error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); | ||||
| } else if (type_ == schema::ActivationType_HSWISH) { | } else if (type_ == schema::ActivationType_HSWISH) { | ||||
| error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); | error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); | ||||
| } else if (type_ == schema::ActivationType_HSIGMOID) { | |||||
| error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_HARD_TANH) { | } else if (type_ == schema::ActivationType_HARD_TANH) { | ||||
| error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); | error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); | ||||
| } else { | } else { | ||||
| @@ -60,14 +60,34 @@ int AddNCPUKernel::Run() { | |||||
| MS_LOG(ERROR) << "Prepare fail!ret: " << ret; | MS_LOG(ERROR) << "Prepare fail!ret: " << ret; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| elements_num_ = in_tensors_[0]->ElementsNum(); | |||||
| elements_num_ = out_tensors_[0]->ElementsNum(); | |||||
| auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | ||||
| auto input1_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | auto input1_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | ||||
| auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | ||||
| if (static_cast<int>(elements_num_) < op_parameter_->thread_num_) { | if (static_cast<int>(elements_num_) < op_parameter_->thread_num_) { | ||||
| ElementAdd(input0_data, input1_data, output_data, elements_num_); | |||||
| if (in_tensors_[0]->shape() == in_tensors_[1]->shape()) { | |||||
| ElementAdd(input0_data, input1_data, output_data, elements_num_); | |||||
| } else { | |||||
| ArithmeticParameter param; | |||||
| param.in_elements_num0_ = in_tensors_[0]->ElementsNum(); | |||||
| param.in_elements_num1_ = in_tensors_[1]->ElementsNum(); | |||||
| param.out_elements_num_ = out_tensors_[0]->ElementsNum(); | |||||
| param.broadcasting_ = true; | |||||
| ElementOptAdd(input0_data, input1_data, output_data, elements_num_, ¶m); | |||||
| } | |||||
| for (size_t i = 2; i < in_tensors_.size(); ++i) { | for (size_t i = 2; i < in_tensors_.size(); ++i) { | ||||
| ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_); | |||||
| if (in_tensors_[i]->shape() == out_tensors_[0]->shape()) { | |||||
| ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_); | |||||
| } else { | |||||
| ArithmeticParameter param; | |||||
| param.in_elements_num0_ = in_tensors_[i]->ElementsNum(); | |||||
| param.in_elements_num1_ = out_tensors_[0]->ElementsNum(); | |||||
| param.out_elements_num_ = out_tensors_[0]->ElementsNum(); | |||||
| param.broadcasting_ = true; | |||||
| ElementOptAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_, | |||||
| ¶m); | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -104,6 +104,7 @@ int ArithmeticCPUKernel::ReSize() { | |||||
| } | } | ||||
| break; | break; | ||||
| case PrimitiveType_Div: | case PrimitiveType_Div: | ||||
| case PrimitiveType_RealDiv: | |||||
| switch (arithmeticParameter_->activation_type_) { | switch (arithmeticParameter_->activation_type_) { | ||||
| case schema::ActivationType_RELU: | case schema::ActivationType_RELU: | ||||
| arithmeticParameter_->broadcasting_ = false; | arithmeticParameter_->broadcasting_ = false; | ||||
| @@ -304,6 +305,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC | |||||
| REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator) | ||||
| @@ -37,6 +37,7 @@ using mindspore::schema::PrimitiveType_Maximum; | |||||
| using mindspore::schema::PrimitiveType_Minimum; | using mindspore::schema::PrimitiveType_Minimum; | ||||
| using mindspore::schema::PrimitiveType_Mul; | using mindspore::schema::PrimitiveType_Mul; | ||||
| using mindspore::schema::PrimitiveType_NotEqual; | using mindspore::schema::PrimitiveType_NotEqual; | ||||
| using mindspore::schema::PrimitiveType_RealDiv; | |||||
| using mindspore::schema::PrimitiveType_SquaredDifference; | using mindspore::schema::PrimitiveType_SquaredDifference; | ||||
| using mindspore::schema::PrimitiveType_Sub; | using mindspore::schema::PrimitiveType_Sub; | ||||
| @@ -99,6 +100,7 @@ class ArithmeticCPUKernel : public LiteKernel { | |||||
| } | } | ||||
| break; | break; | ||||
| case PrimitiveType_Div: | case PrimitiveType_Div: | ||||
| case PrimitiveType_RealDiv: | |||||
| switch (arithmeticParameter_->activation_type_) { | switch (arithmeticParameter_->activation_type_) { | ||||
| case schema::ActivationType_RELU: | case schema::ActivationType_RELU: | ||||
| arithmetic_run_ = ElementDivRelu; | arithmetic_run_ = ElementDivRelu; | ||||
| @@ -48,7 +48,7 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
| int length = in_tensors_.at(0)->ElementsNum(); | int length = in_tensors_.at(0)->ElementsNum(); | ||||
| int stride = UP_DIV(length, thread_count_); | |||||
| int stride = UP_DIV(length, 1); | |||||
| int count = MSMIN(stride, length - stride * task_id); | int count = MSMIN(stride, length - stride * task_id); | ||||
| auto error_code = RET_OK; | auto error_code = RET_OK; | ||||
| @@ -63,8 +63,9 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||||
| error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, | error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, | ||||
| output_addr + stride * task_id, param_act_grad_->alpha_); | output_addr + stride * task_id, param_act_grad_->alpha_); | ||||
| } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { | } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { | ||||
| // Sigmoid gets the input tensors in reverse order! | |||||
| error_code = | error_code = | ||||
| SigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { | } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { | ||||
| error_code = | error_code = | ||||
| TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | ||||
| @@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel { | |||||
| explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | ||||
| const mindspore::lite::PrimitiveC *primitive) | const mindspore::lite::PrimitiveC *primitive) | ||||
| : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { | |||||
| : LiteKernel(param, inputs, outputs, ctx, primitive) { | |||||
| param_act_grad_ = reinterpret_cast<ActivationParameter *>(param); | param_act_grad_ = reinterpret_cast<ActivationParameter *>(param); | ||||
| } | } | ||||
| ~ActivationGradCPUKernel() override = default; | ~ActivationGradCPUKernel() override = default; | ||||
| @@ -38,7 +38,6 @@ class ActivationGradCPUKernel : public LiteKernel { | |||||
| int DoActivation(int task_id); | int DoActivation(int task_id); | ||||
| private: | private: | ||||
| int thread_count_; | |||||
| ActivationParameter *param_act_grad_; | ActivationParameter *param_act_grad_; | ||||
| }; | }; | ||||
| @@ -0,0 +1,118 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32_grad/adam.h" | |||||
| #include <cmath> | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Adam; | |||||
| namespace mindspore::kernel { | |||||
| int AdamCPUKernel::ReSize() { return RET_OK; } | |||||
| int AdamCPUKernel::Execute(int task_id) { | |||||
| auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto m = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| auto v = reinterpret_cast<float *>(in_tensors_[2]->MutableData()); | |||||
| auto beta1_power = reinterpret_cast<float *>(in_tensors_[3]->MutableData())[0]; | |||||
| auto beta2_power = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||||
| auto learning_rate = reinterpret_cast<float *>(in_tensors_[5]->MutableData())[0]; | |||||
| auto beta1 = reinterpret_cast<float *>(in_tensors_[6]->MutableData())[0]; | |||||
| auto beta2 = reinterpret_cast<float *>(in_tensors_[7]->MutableData())[0]; | |||||
| auto eps = reinterpret_cast<float *>(in_tensors_[8]->MutableData())[0]; | |||||
| auto gradient = reinterpret_cast<float *>(in_tensors_[9]->MutableData()); | |||||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||||
| if (adam_param_->use_nesterov_) { // Nadam | |||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1)); | |||||
| v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2)); | |||||
| auto g_hat = gradient[i] / (1 - beta1_power); | |||||
| auto m_hat = m[i] / (1 - beta1_power); | |||||
| auto v_hat = v[i] / (1 - beta2_power); | |||||
| auto m_tag = (1.f - beta1) * g_hat + beta1 * m_hat; | |||||
| weight[i] -= learning_rate * m_tag / (sqrtf(v_hat) + eps); | |||||
| } | |||||
| } else { | |||||
| for (size_t i = 0; i < elem_num; ++i) { | |||||
| m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1)); | |||||
| v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2)); | |||||
| auto m_hat = m[i] / (1 - beta1_power); | |||||
| auto v_hat = v[i] / (1 - beta2_power); | |||||
| weight[i] -= learning_rate * m_hat / (sqrtf(v_hat) + eps); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AdamRun(void *cdata, int task_id) { | |||||
| auto Adam_kernel = reinterpret_cast<AdamCPUKernel *>(cdata); | |||||
| auto error_code = Adam_kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Adam run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AdamCPUKernel::Run() { | |||||
| auto prepare_ret = Prepare(); | |||||
| if (prepare_ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "AdamCPUKernel Prepare fail!ret: " << prepare_ret; | |||||
| return prepare_ret; | |||||
| } | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AdamCPUKernel::Init() { return RET_OK; } | |||||
| kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const lite::PrimitiveC *primitive) { | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Adam); | |||||
| auto *kernel = new (std::nothrow) AdamCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (0 != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adam, CpuAdamFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/fp32_grad/optimizer.h" | |||||
| namespace mindspore::kernel { | |||||
| class AdamCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| adam_param_ = reinterpret_cast<AdamParameter *>(parameter); | |||||
| } | |||||
| ~AdamCPUKernel() override {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int Execute(int task_id); | |||||
| private: | |||||
| AdamParameter *adam_param_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | |||||
| @@ -79,14 +79,7 @@ int ApplyMomentumCPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ApplyMomentumCPUKernel::Init() { | |||||
| // Only for test with uninitialized Data | |||||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0; | |||||
| return RET_OK; | |||||
| } | |||||
| int ApplyMomentumCPUKernel::Init() { return RET_OK; } | |||||
| kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, | const std::vector<lite::Tensor *> &outputs, | ||||
| @@ -0,0 +1,91 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32_grad/assign.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Assign; | |||||
| namespace mindspore::kernel { | |||||
| int AssignCPUKernel::ReSize() { return RET_OK; } | |||||
| int AssignCPUKernel::Execute(int task_id) { | |||||
| auto x = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||||
| auto y = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||||
| size_t size = in_tensors_[0]->Size(); | |||||
| memcpy(x, y, size); | |||||
| return RET_OK; | |||||
| } | |||||
| int AssignRun(void *cdata, int task_id) { | |||||
| auto Assign_kernel = reinterpret_cast<AssignCPUKernel *>(cdata); | |||||
| auto error_code = Assign_kernel->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "assign run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AssignCPUKernel::Run() { | |||||
| auto prepare_ret = Prepare(); | |||||
| if (prepare_ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "AssignCPUKernel Prepare fail!ret: " << prepare_ret; | |||||
| return prepare_ret; | |||||
| } | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AssignCPUKernel::Init() { return RET_OK; } | |||||
| kernel::LiteKernel *CpuAssignFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const lite::PrimitiveC *primitive) { | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Assign); | |||||
| auto *kernel = new (std::nothrow) AssignCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| auto ret = kernel->Init(); | |||||
| if (0 != ret) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assign, CpuAssignFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/fp32_grad/optimizer.h" | |||||
| namespace mindspore::kernel { | |||||
| class AssignCPUKernel : public LiteKernel { | |||||
| public: | |||||
| explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~AssignCPUKernel() override {} | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int Execute(int task_id); | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ | |||||
| @@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_Conv2DGradInput; | using mindspore::schema::PrimitiveType_Conv2DGradInput; | ||||
| using mindspore::schema::PrimitiveType_GroupConv2DGradInput; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ConvolutionGradInputCPUKernel::Init() { | int ConvolutionGradInputCPUKernel::Init() { | ||||
| @@ -134,7 +135,8 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Te | |||||
| const kernel::KernelKey &desc, | const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput || | |||||
| desc.type == schema::PrimitiveType_GroupConv2DGradInput); | |||||
| auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive); | auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| @@ -154,4 +156,6 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Te | |||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradInput, CpuConvGradInputFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradInput, CpuConvGradInputFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GroupConv2DGradInput, CpuConvGradInputFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -154,4 +154,6 @@ kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector<li | |||||
| } | } | ||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSoftmaxCrossEntropyFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -178,5 +178,4 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator( | |||||
| return kernel; | return kernel; | ||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSparseSoftmaxCrossEntropyFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -41,6 +41,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||||
| } | } | ||||
| model->buf = reinterpret_cast<char *>(malloc(size)); | model->buf = reinterpret_cast<char *>(malloc(size)); | ||||
| if (model->buf == nullptr) { | if (model->buf == nullptr) { | ||||
| delete model; | |||||
| MS_LOG(ERROR) << "new inner model buf fail!"; | MS_LOG(ERROR) << "new inner model buf fail!"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -48,6 +49,8 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||||
| model->buf_size_ = size; | model->buf_size_ = size; | ||||
| auto meta_graph = schema::GetMetaGraph(model->buf); | auto meta_graph = schema::GetMetaGraph(model->buf); | ||||
| if (meta_graph == nullptr) { | if (meta_graph == nullptr) { | ||||
| delete model; | |||||
| free(model->buf); | |||||
| MS_LOG(ERROR) << "meta_graph is nullptr!"; | MS_LOG(ERROR) << "meta_graph is nullptr!"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "nnacl/fp32/activation.h" | #include "nnacl/fp32/activation.h" | ||||
| #include "src/ops/conv2d_grad_filter.h" | #include "src/ops/conv2d_grad_filter.h" | ||||
| #include "src/ops/conv2d_grad_input.h" | #include "src/ops/conv2d_grad_input.h" | ||||
| #include "src/ops/group_conv2d_grad_input.h" | |||||
| #include "nnacl/conv_parameter.h" | #include "nnacl/conv_parameter.h" | ||||
| #include "src/ops/power_grad.h" | #include "src/ops/power_grad.h" | ||||
| #include "nnacl/power_parameter.h" | #include "nnacl/power_parameter.h" | ||||
| @@ -34,6 +35,7 @@ | |||||
| #include "src/ops/sgd.h" | #include "src/ops/sgd.h" | ||||
| #include "src/ops/bn_grad.h" | #include "src/ops/bn_grad.h" | ||||
| #include "nnacl/fp32_grad/batch_norm.h" | #include "nnacl/fp32_grad/batch_norm.h" | ||||
| #include "src/ops/adam.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -69,12 +71,29 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p | |||||
| reinterpret_cast<mindspore::lite::ApplyMomentum *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | reinterpret_cast<mindspore::lite::ApplyMomentum *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | ||||
| p->grad_scale_ = apply_momentum_primitive->GetGradientScale(); | p->grad_scale_ = apply_momentum_primitive->GetGradientScale(); | ||||
| p->use_locking_ = apply_momentum_primitive->GetUseLocking(); | |||||
| p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); | p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); | ||||
| return reinterpret_cast<OpParameter *>(p); | return reinterpret_cast<OpParameter *>(p); | ||||
| } | } | ||||
| OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||||
| return nullptr; | |||||
| } | |||||
| AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter))); | |||||
| if (p == nullptr) { | |||||
| MS_LOG(ERROR) << "new AdamParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| p->op_parameter_.type_ = primitive->Type(); | |||||
| auto apply_momentum_primitive = | |||||
| reinterpret_cast<mindspore::lite::Adam *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); | |||||
| return reinterpret_cast<OpParameter *>(p); | |||||
| } | |||||
| OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| if (primitive == nullptr) { | if (primitive == nullptr) { | ||||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | ||||
| @@ -264,6 +283,47 @@ OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::Primit | |||||
| return reinterpret_cast<OpParameter *>(param); | return reinterpret_cast<OpParameter *>(param); | ||||
| } | } | ||||
| OpParameter *PopulateGroupConvolutionGradInputParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||||
| return nullptr; | |||||
| } | |||||
| ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new Param for conv grad filter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| param->op_parameter_.type_ = primitive->Type(); | |||||
| auto convg_primitive = | |||||
| reinterpret_cast<mindspore::lite::GroupConv2DGradInput *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| param->kernel_h_ = convg_primitive->GetKernelH(); | |||||
| param->kernel_w_ = convg_primitive->GetKernelW(); | |||||
| param->stride_h_ = convg_primitive->GetStrideH(); | |||||
| param->stride_w_ = convg_primitive->GetStrideW(); | |||||
| param->dilation_h_ = convg_primitive->GetDilateH(); | |||||
| param->dilation_w_ = convg_primitive->GetDilateW(); | |||||
| param->pad_u_ = convg_primitive->GetPadUp(); | |||||
| param->pad_d_ = convg_primitive->GetPadDown(); | |||||
| param->pad_l_ = convg_primitive->GetPadLeft(); | |||||
| param->pad_r_ = convg_primitive->GetPadRight(); | |||||
| param->group_ = convg_primitive->GetGroup(); | |||||
| param->act_type_ = ActType_No; | |||||
| switch (convg_primitive->GetActivationType()) { | |||||
| case schema::ActivationType_RELU: | |||||
| param->act_type_ = ActType_Relu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| param->act_type_ = ActType_Relu6; | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(param); | |||||
| } | |||||
| OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| if (primitive == nullptr) { | if (primitive == nullptr) { | ||||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | ||||
| @@ -327,10 +387,13 @@ void PopulateTrainParameters() { | |||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_GroupConv2DGradInput, PopulateGroupConvolutionGradInputParameter); | |||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Sgd, PopulateSgdParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_Sgd, PopulateSgdParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | ||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Adam, PopulateAdamParameter); | |||||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Assign, DefaultPopulateParameter); | |||||
| } | } | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -104,23 +104,15 @@ int TrainSession::RunGraph(const session::KernelCallBack &before, const session: | |||||
| for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((static_cast<lite::Tensor *>(ms_tensor))); | for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((static_cast<lite::Tensor *>(ms_tensor))); | ||||
| if (train_mode_) return lite::LiteSession::RunGraph(before, after); | if (train_mode_) return lite::LiteSession::RunGraph(before, after); | ||||
| // object is expected to run only inference part of graph | |||||
| // prepare a list of kernels till the loss function -- temporary solution | |||||
| std::vector<kernel::LiteKernel *> inference_kernels; | |||||
| for (auto kernel : this->kernels_) { | |||||
| if (IsLossKernel(kernel)) break; | |||||
| inference_kernels.push_back(kernel); | |||||
| } | |||||
| if (this->context_ == nullptr) { | if (this->context_ == nullptr) { | ||||
| MS_LOG(ERROR) << "context is null"; | MS_LOG(ERROR) << "context is null"; | ||||
| return lite::RET_NULL_PTR; | return lite::RET_NULL_PTR; | ||||
| } | } | ||||
| lite::Executor executor; | lite::Executor executor; | ||||
| if (before == nullptr && after == nullptr) { | if (before == nullptr && after == nullptr) { | ||||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get()); | |||||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get()); | |||||
| } else { | } else { | ||||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get(), before, | |||||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get(), before, | |||||
| after); | after); | ||||
| } | } | ||||
| } | } | ||||
| @@ -173,6 +165,38 @@ void TrainSession::Eval() { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (inference_kernels_.size() == 0) { | |||||
| BuildInferenceKernelsMap(); | |||||
| } | |||||
| } | |||||
| void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) { | |||||
| if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not in vector | |||||
| v->push_back(kernel); | |||||
| for (auto in_node : kernel->in_kernels()) { | |||||
| BuildInferenceKernelsRecursive(in_node, v); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TrainSession::BuildInferenceKernelsMap() { | |||||
| std::vector<kernel::LiteKernel *> req_kernels; | |||||
| for (auto kernel : this->kernels_) { | |||||
| if (IsLossKernel(kernel)) { // For each loss in the system add backward tree | |||||
| for (auto in_node : kernel->in_kernels()) { | |||||
| BuildInferenceKernelsRecursive(in_node, &req_kernels); | |||||
| } | |||||
| } | |||||
| } | |||||
| inference_kernels_.clear(); | |||||
| for (auto kernel : this->kernels_) { | |||||
| if (std::find(req_kernels.begin(), req_kernels.end(), kernel) != req_kernels.end()) { | |||||
| inference_kernels_.push_back(kernel); | |||||
| } | |||||
| } | |||||
| if (inference_kernels_.size() == 0) { | |||||
| inference_kernels_ = this->kernels_; | |||||
| } | |||||
| } | } | ||||
| bool TrainSession::IsLossKernel(kernel::LiteKernel *kernel) { | bool TrainSession::IsLossKernel(kernel::LiteKernel *kernel) { | ||||
| @@ -82,12 +82,16 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||||
| protected: | protected: | ||||
| void AllocWorkSpace(); | void AllocWorkSpace(); | ||||
| bool IsLossKernel(kernel::LiteKernel *kernel); | |||||
| virtual std::vector<CreatorOp> ReplaceOps(); | virtual std::vector<CreatorOp> ReplaceOps(); | ||||
| virtual void RestoreOps(const std::vector<CreatorOp> &restore); | virtual void RestoreOps(const std::vector<CreatorOp> &restore); | ||||
| bool IsLossKernel(kernel::LiteKernel *kernel); | |||||
| virtual void BuildInferenceKernelsMap(); | |||||
| virtual void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels); | |||||
| TrainModel *model_ = nullptr; | TrainModel *model_ = nullptr; | ||||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; | std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; | ||||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | ||||
| std::vector<kernel::LiteKernel *> inference_kernels_; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef TESTS_UT_COMMON_UT_COMMON_H_ | |||||
| #define TESTS_UT_COMMON_UT_COMMON_H_ | |||||
| #ifndef MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ | |||||
| #define MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <fstream> | #include <fstream> | ||||
| @@ -37,11 +37,11 @@ class CommonTest : public testing::Test { | |||||
| void PrintData(std::string name, T *output_data, int size) { | void PrintData(std::string name, T *output_data, int size) { | ||||
| std::cout << "The " << name << " is as follows:" << std::endl; | std::cout << "The " << name << " is as follows:" << std::endl; | ||||
| if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { | if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { | ||||
| for (size_t i = 0; i < std::min(size, 100); i++) { | |||||
| for (int i = 0; i < std::min(size, 100); i++) { | |||||
| std::cout << static_cast<int>(output_data[i]) << " "; | std::cout << static_cast<int>(output_data[i]) << " "; | ||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = 0; i < std::min(size, 100); i++) { | |||||
| for (int i = 0; i < std::min(size, 100); i++) { | |||||
| std::cout << output_data[i] << " "; | std::cout << output_data[i] << " "; | ||||
| } | } | ||||
| } | } | ||||
| @@ -58,7 +58,7 @@ class CommonTest : public testing::Test { | |||||
| void CompareOutputInt8(int8_t *output_data, int8_t *correct_data, int size, float err_percent) { | void CompareOutputInt8(int8_t *output_data, int8_t *correct_data, int size, float err_percent) { | ||||
| int bias_count = 0; | int bias_count = 0; | ||||
| for (size_t i = 0; i < size; i++) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| int8_t diff = abs(output_data[i] - correct_data[i]); | int8_t diff = abs(output_data[i] - correct_data[i]); | ||||
| ASSERT_LE(diff, 1); | ASSERT_LE(diff, 1); | ||||
| if (diff == 1) { | if (diff == 1) { | ||||
| @@ -88,4 +88,4 @@ class CommonTest : public testing::Test { | |||||
| } | } | ||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // TESTS_UT_COMMON_UT_COMMON_H_ | |||||
| #endif // MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ | |||||
| @@ -250,7 +250,7 @@ TEST_F(NetworkTest, tuning_layer) { | |||||
| auto label = std::make_unique<schema::TensorT>(); | auto label = std::make_unique<schema::TensorT>(); | ||||
| label->nodeType = schema::NodeType::NodeType_ValueNode; | label->nodeType = schema::NodeType::NodeType_ValueNode; | ||||
| label->format = schema::Format_NHWC; | label->format = schema::Format_NHWC; | ||||
| label->dataType = TypeId::kNumberTypeInt32; | |||||
| label->dataType = TypeId::kNumberTypeFloat32; | |||||
| label->dims = {BATCH_SIZE * NUM_CLASSES}; | label->dims = {BATCH_SIZE * NUM_CLASSES}; | ||||
| label->offset = -1; | label->offset = -1; | ||||
| meta_graph->allTensors.emplace_back(std::move(label)); | meta_graph->allTensors.emplace_back(std::move(label)); | ||||
| @@ -386,8 +386,10 @@ TEST_F(NetworkTest, tuning_layer) { | |||||
| auto labelTensor = inputs.at(1); | auto labelTensor = inputs.at(1); | ||||
| ASSERT_NE(nullptr, labelTensor); | ASSERT_NE(nullptr, labelTensor); | ||||
| ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum()); | ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum()); | ||||
| auto labels = reinterpret_cast<int *>(labelTensor->MutableData()); | |||||
| for (int i = 0; i < BATCH_SIZE; i++) labels[i] = (i * 97) % NUM_CLASSES; | |||||
| auto labels = reinterpret_cast<float *>(labelTensor->MutableData()); | |||||
| std::fill(labels, labels + labelTensor->ElementsNum(), 0.f); | |||||
| for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0; | |||||
| ret = session->RunGraph(); | ret = session->RunGraph(); | ||||
| ASSERT_EQ(lite::RET_OK, ret); | ASSERT_EQ(lite::RET_OK, ret); | ||||
| @@ -576,12 +578,12 @@ TEST_F(NetworkTest, lenetnet) { | |||||
| delete context; | delete context; | ||||
| ASSERT_EQ(res, 0); | ASSERT_EQ(res, 0); | ||||
| } | } | ||||
| #if 0 | |||||
| TEST_F(NetworkTest, retina_net) { | TEST_F(NetworkTest, retina_net) { | ||||
| char *buf = nullptr; | char *buf = nullptr; | ||||
| size_t net_size = 0; | size_t net_size = 0; | ||||
| std::string net = "./test_data/nets/retinaface1009.ms"; | |||||
| std::string net = "./test_data/nets/retinaface1.ms"; | |||||
| ReadFile(net.c_str(), &net_size, &buf); | ReadFile(net.c_str(), &net_size, &buf); | ||||
| // auto model = lite::TrainModel::Import(buf, net_size); | // auto model = lite::TrainModel::Import(buf, net_size); | ||||
| auto model = lite::Model::Import(buf, net_size); | auto model = lite::Model::Import(buf, net_size); | ||||
| @@ -598,26 +600,36 @@ TEST_F(NetworkTest, retina_net) { | |||||
| ASSERT_EQ(lite::RET_OK, ret); | ASSERT_EQ(lite::RET_OK, ret); | ||||
| // session->Eval(); | // session->Eval(); | ||||
| std::string in = "./test_data/nets/retinaface_input.f32"; | |||||
| std::string in = "./test_data/nets/test1.hwc_normalized_f32"; | |||||
| std::cout << "----- Output 0 -----" << std::endl; | std::cout << "----- Output 0 -----" << std::endl; | ||||
| std::string out = "./test_data/nets/retinaface_out_0.f32"; | |||||
| std::string out = "./test_data/nets/test1_loc.f32"; | |||||
| int final_res = 0; | |||||
| auto res = runNet(session, in, out, "448", true); | auto res = runNet(session, in, out, "448", true); | ||||
| ASSERT_EQ(res, 0); | |||||
| // ASSERT_EQ(res, 0); | |||||
| if (res != 0) { | |||||
| final_res = res; | |||||
| } | |||||
| std::cout << "----- Output 1 -----" << std::endl; | std::cout << "----- Output 1 -----" << std::endl; | ||||
| out = "./test_data/nets/retinaface_out_1.f32"; | |||||
| out = "./test_data/nets/test1_conf.f32"; | |||||
| res = runNet(session, in, out, "435", true); | res = runNet(session, in, out, "435", true); | ||||
| ASSERT_EQ(res, 0); | |||||
| // ASSERT_EQ(res, 0); | |||||
| if (res != 0) { | |||||
| final_res |= res; | |||||
| } | |||||
| std::cout << "----- Output 2 -----" << std::endl; | std::cout << "----- Output 2 -----" << std::endl; | ||||
| out = "./test_data/nets/retinaface_out_2.f32"; | |||||
| out = "./test_data/nets/test1_landms.f32"; | |||||
| res = runNet(session, in, out, "421", true); | res = runNet(session, in, out, "421", true); | ||||
| ASSERT_EQ(res, 0); | |||||
| if (res != 0) { | |||||
| final_res |= res; | |||||
| } | |||||
| ASSERT_EQ(final_res, 0); | |||||
| delete session; | delete session; | ||||
| delete context; | delete context; | ||||
| } | } | ||||
| #endif | |||||
| TEST_F(NetworkTest, mobileface_net) { | TEST_F(NetworkTest, mobileface_net) { | ||||
| char *buf = nullptr; | char *buf = nullptr; | ||||
| size_t net_size = 0; | size_t net_size = 0; | ||||
| @@ -41,10 +41,11 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||||
| std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; | std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; | ||||
| auto ll_labels = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); | auto ll_labels = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); | ||||
| auto labels = new int[6]; | |||||
| for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]); | |||||
| auto labels = new float[6 * 4]; | |||||
| std::fill(labels, labels + 6 * 4, 0.f); | |||||
| for (int i = 0; i < 6; i++) labels[i * 4 + ll_labels[i]] = 1.0; | |||||
| std::vector<int> dim_l({6}); | |||||
| std::vector<int> dim_l({6, 4}); | |||||
| lite::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l); | lite::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l); | ||||
| l_tensor.SetData(labels); | l_tensor.SetData(labels); | ||||
| @@ -274,9 +274,24 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s | |||||
| auto input_cnode = utils::cast<CNodePtr>(input_anode); | auto input_cnode = utils::cast<CNodePtr>(input_anode); | ||||
| if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | ||||
| #ifndef SUPPORT_TRAIN | |||||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | |||||
| } | |||||
| #else | |||||
| bool found = false; | |||||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | if (node_id_map_.find(input_name) != node_id_map_.end()) { | ||||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | ||||
| found = true; | |||||
| } | |||||
| if (found == false) { | |||||
| auto input_index_key = input_name + "_o:" + std::to_string(0); | |||||
| if (node_id_map_.find(input_index_key) != node_id_map_.end()) { | |||||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]); | |||||
| } | |||||
| } | } | ||||
| #endif | |||||
| } else { | } else { | ||||
| auto inputs = input_cnode->inputs(); | auto inputs = input_cnode->inputs(); | ||||
| if (inputs.size() != 3) { | if (inputs.size() != 3) { | ||||
| @@ -369,6 +384,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | auto typePtr = abstractTensor->element()->GetTypeTrack(); | ||||
| paramTensor->dataType = typePtr->type_id(); | paramTensor->dataType = typePtr->type_id(); | ||||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | ||||
| #ifdef SUPPORT_TRAIN | |||||
| if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; | |||||
| #endif | |||||
| paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | ||||
| auto data = value->cast<tensor::TensorPtr>(); | auto data = value->cast<tensor::TensorPtr>(); | ||||
| paramTensor->data.resize(data->Size()); | paramTensor->data.resize(data->Size()); | ||||
| @@ -505,7 +523,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||||
| node_id_map_[name] = meta_graphT->allTensors.size(); | node_id_map_[name] = meta_graphT->allTensors.size(); | ||||
| meta_graphT->allTensors.emplace_back(msTensor); | meta_graphT->allTensors.emplace_back(msTensor); | ||||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | ||||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) | |||||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) | |||||
| break; | break; | ||||
| #else | #else | ||||
| if (tuple->size() == 1) { | if (tuple->size() == 1) { | ||||
| @@ -678,6 +678,13 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| outputFuncGraph->set_return(return_node); | outputFuncGraph->set_return(return_node); | ||||
| MS_LOG(INFO) << "Construct funcgraph finined, all success."; | MS_LOG(INFO) << "Construct funcgraph finined, all success."; | ||||
| } else { | } else { | ||||
| #ifdef SUPPORT_TRAIN | |||||
| auto ret_node = outputFuncGraph->get_return(); | |||||
| if (ret_node) { | |||||
| ret_node->add_input(cnode_ptr); | |||||
| return true; | |||||
| } | |||||
| #endif | |||||
| const onnx::ValueInfoProto &output_node = importProto.output(0); | const onnx::ValueInfoProto &output_node = importProto.output(0); | ||||
| const onnx::TypeProto &output_typeproto = output_node.type(); | const onnx::TypeProto &output_typeproto = output_node.type(); | ||||
| int output_type = output_typeproto.tensor_type().elem_type(); | int output_type = output_typeproto.tensor_type().elem_type(); | ||||
| @@ -687,7 +694,6 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| } | } | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | ||||
| inputs.clear(); | inputs.clear(); | ||||
| auto primReturn = std::make_unique<schema::PrimitiveT>(); | auto primReturn = std::make_unique<schema::PrimitiveT>(); | ||||
| MS_ASSERT(primReturn != nullptr); | MS_ASSERT(primReturn != nullptr); | ||||
| @@ -717,6 +723,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG | |||||
| } | } | ||||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | ||||
| CNodePtr cnode_ptr = nullptr; | CNodePtr cnode_ptr = nullptr; | ||||
| CNodePtr last_cnode_ptr = nullptr; | |||||
| int status = RET_OK; | int status = RET_OK; | ||||
| NoSupportOp::GetInstance()->SetFmkType("MINDIR"); | NoSupportOp::GetInstance()->SetFmkType("MINDIR"); | ||||
| for (int i = 0; i < importProto.node_size(); ++i) { | for (int i = 0; i < importProto.node_size(); ++i) { | ||||
| @@ -734,13 +741,35 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG | |||||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | ||||
| status = (status == RET_OK ? RET_NULL_PTR : status); | status = (status == RET_OK ? RET_NULL_PTR : status); | ||||
| } | } | ||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode_ptr->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||||
| status = RET_ERROR; | |||||
| } | |||||
| #ifdef SUPPORT_TRAIN | |||||
| if (primitive_c->Type() == schema::PrimitiveType_MakeTuple) { | |||||
| last_cnode_ptr = cnode_ptr; | |||||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||||
| status = RET_ERROR; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| return status; | return status; | ||||
| } | } | ||||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||||
| status = RET_ERROR; | |||||
| #ifdef SUPPORT_TRAIN | |||||
| if (last_cnode_ptr != cnode_ptr) { | |||||
| #else | |||||
| { | |||||
| #endif | |||||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||||
| status = RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -28,12 +28,14 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = { | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| schema::PrimitiveType_Conv2DGradFilter, | schema::PrimitiveType_Conv2DGradFilter, | ||||
| schema::PrimitiveType_Conv2DGradInput, | schema::PrimitiveType_Conv2DGradInput, | ||||
| schema::PrimitiveType_GroupConv2DGradInput, | |||||
| schema::PrimitiveType_PoolingGrad, | schema::PrimitiveType_PoolingGrad, | ||||
| schema::PrimitiveType_BiasGrad, | schema::PrimitiveType_BiasGrad, | ||||
| schema::PrimitiveType_BNGrad, | schema::PrimitiveType_BNGrad, | ||||
| schema::PrimitiveType_ActivationGrad, | schema::PrimitiveType_ActivationGrad, | ||||
| schema::PrimitiveType_ApplyMomentum, | schema::PrimitiveType_ApplyMomentum, | ||||
| schema::PrimitiveType_Sgd, | schema::PrimitiveType_Sgd, | ||||
| schema::PrimitiveType_Adam, | |||||
| #endif | #endif | ||||
| schema::PrimitiveType_Conv2D, | schema::PrimitiveType_Conv2D, | ||||
| schema::PrimitiveType_DeConv2D, | schema::PrimitiveType_DeConv2D, | ||||
| @@ -161,6 +161,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||||
| int idx = 0; | int idx = 0; | ||||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3; | if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3; | ||||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1; | if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1; | ||||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9; | |||||
| iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); | iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | ||||
| @@ -183,7 +183,9 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni | |||||
| auto origin_begin = attr->begin; | auto origin_begin = attr->begin; | ||||
| attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; | attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; | ||||
| auto origin_end = attr->axes; | auto origin_end = attr->axes; | ||||
| attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; | |||||
| if (origin_end.size() >= 4) { | |||||
| attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; | |||||
| } | |||||
| auto origin_stride = attr->size; | auto origin_stride = attr->size; | ||||
| attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; | attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; | ||||
| } | } | ||||
| @@ -122,6 +122,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||||
| param_value->set_format(schema::Format::Format_CKHW); | param_value->set_format(schema::Format::Format_CKHW); | ||||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | } else if (op_type == schema::PrimitiveType_DeConv2D) { | ||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| #ifdef SUPPORT_TRAIN | |||||
| } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { | |||||
| param_value->set_format(schema::Format::Format_KCHW); | |||||
| } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { | |||||
| param_value->set_format(schema::Format::Format_CKHW); | |||||
| #endif | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | ||||
| << ", node: " << conv_node->fullname_with_scope(); | << ", node: " << conv_node->fullname_with_scope(); | ||||
| @@ -178,6 +184,10 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||||
| auto conv_cnode = node->cast<CNodePtr>(); | auto conv_cnode = node->cast<CNodePtr>(); | ||||
| auto type = opt::GetCNodeType(node); | auto type = opt::GetCNodeType(node); | ||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | ||||
| #ifdef SUPPORT_TRAIN | |||||
| ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && | |||||
| ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && | |||||
| #endif | |||||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -18,27 +18,21 @@ | |||||
| #include "tools/optimizer/common/gllo_utils.h" | #include "tools/optimizer/common/gllo_utils.h" | ||||
| using mindspore::lite::converter::FmkType_CAFFE; | using mindspore::lite::converter::FmkType_CAFFE; | ||||
| using mindspore::lite::converter::FmkType_TFLITE; | |||||
| using mindspore::lite::converter::FmkType_ONNX; | |||||
| using mindspore::lite::converter::FmkType_MS; | using mindspore::lite::converter::FmkType_MS; | ||||
| using mindspore::schema::QuantType_WeightQuant; | |||||
| using mindspore::schema::QuantType_QUANT_NONE; | |||||
| using mindspore::lite::converter::FmkType_ONNX; | |||||
| using mindspore::lite::converter::FmkType_TFLITE; | |||||
| using mindspore::schema::QuantType_AwareTraining; | using mindspore::schema::QuantType_AwareTraining; | ||||
| using mindspore::schema::QuantType_PostTraining; | using mindspore::schema::QuantType_PostTraining; | ||||
| using mindspore::schema::QuantType_QUANT_NONE; | |||||
| using mindspore::schema::QuantType_WeightQuant; | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| namespace { | namespace { | ||||
| constexpr size_t kConvWeightIndex = 2; | constexpr size_t kConvWeightIndex = 2; | ||||
| } // namespace | } // namespace | ||||
| void WeightFormatTransformPass::SetQuantType(QuantType type) { | |||||
| this->quant_type = type; | |||||
| } | |||||
| void WeightFormatTransformPass::SetFmkType(FmkType type) { | |||||
| this->fmk_type = type; | |||||
| } | |||||
| void WeightFormatTransformPass::SetDstFormat(schema::Format format) { | |||||
| this->dst_format = format; | |||||
| } | |||||
| void WeightFormatTransformPass::SetQuantType(QuantType type) { this->quant_type = type; } | |||||
| void WeightFormatTransformPass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||||
| void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dst_format = format; } | |||||
| lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) { | lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| auto node_list = TopoSort(graph->get_return()); | auto node_list = TopoSort(graph->get_return()); | ||||
| @@ -48,6 +42,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||||
| } | } | ||||
| auto type = opt::GetCNodeType(node); | auto type = opt::GetCNodeType(node); | ||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | ||||
| #ifdef SUPPORT_TRAIN | |||||
| && type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput | |||||
| #endif | |||||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -60,8 +57,8 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||||
| MS_LOG(ERROR) << "weight node must param value"; | MS_LOG(ERROR) << "weight node must param value"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 | |||||
| || weight_value->tensor_type() == TypeId::kNumberTypeUInt8); | |||||
| MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 || | |||||
| weight_value->tensor_type() == TypeId::kNumberTypeUInt8); | |||||
| lite::STATUS status; | lite::STATUS status; | ||||
| schema::Format weight_dst_format = schema::Format::Format_KHWC; | schema::Format weight_dst_format = schema::Format::Format_KHWC; | ||||
| if (dst_format != schema::Format::Format_NUM_OF_FORMAT) { | if (dst_format != schema::Format::Format_NUM_OF_FORMAT) { | ||||