Browse Source

sync code to support train

tags/v1.1.0
yangjie159 5 years ago
parent
commit
4cecdd8508
21 changed files with 581 additions and 5 deletions
  1. +60
    -0
      mindspore/lite/nnacl/fp32_grad/binary_cross_entropy.c
  2. +36
    -0
      mindspore/lite/nnacl/fp32_grad/binary_cross_entropy.h
  3. +42
    -0
      mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.c
  4. +36
    -0
      mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.h
  5. +11
    -0
      mindspore/lite/src/ops/assign_add.cc
  6. +9
    -0
      mindspore/lite/src/ops/binary_cross_entropy.cc
  7. +9
    -0
      mindspore/lite/src/ops/binary_cross_entropy_grad.cc
  8. +1
    -1
      mindspore/lite/src/ops/cast.cc
  9. +10
    -4
      mindspore/lite/src/ops/gather.cc
  10. +9
    -0
      mindspore/lite/src/ops/oneslike.cc
  11. +41
    -0
      mindspore/lite/src/ops/populate/activation_grad_populate.cc
  12. +36
    -0
      mindspore/lite/src/ops/populate/adam_populate.cc
  13. +36
    -0
      mindspore/lite/src/ops/populate/assign_add_populate.cc
  14. +36
    -0
      mindspore/lite/src/ops/populate/assign_populate.cc
  15. +37
    -0
      mindspore/lite/src/ops/populate/bias_grad_populate.cc
  16. +42
    -0
      mindspore/lite/src/ops/populate/binary_cross_entropy_grad_populate.cc
  17. +42
    -0
      mindspore/lite/src/ops/populate/binary_cross_entropy_populate.cc
  18. +36
    -0
      mindspore/lite/src/ops/populate/oneslike_populate.cc
  19. +37
    -0
      mindspore/lite/src/ops/populate/unsorted_segment_sum_populate.cc
  20. +6
    -0
      mindspore/lite/src/ops/slice.cc
  21. +9
    -0
      mindspore/lite/src/ops/unsorted_segment_sum.cc

+ 60
- 0
mindspore/lite/nnacl/fp32_grad/binary_cross_entropy.c View File

@@ -0,0 +1,60 @@
/*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <math.h>
#include "nnacl/fp32_grad/binary_cross_entropy.h"

static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x,
const float *input_y, const float *weight, float *loss, float *tmp_loss) {
float epsilon = 1e-12;
if (reduction == 0) {
for (int i = 0; i < input_size; i++) {
float value =
-weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
loss[i] = value;
}
} else {
for (int i = 0; i < input_size; i++) {
float value =
-weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
tmp_loss[i] = value;
}
}
}

void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, float *loss, float *tmp_loss) {
loss[0] = 0.0f;
BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss);
if (reduction != 0) {
if (input_size % 2 == 1) {
tmp_loss[0] += tmp_loss[input_size - 1];
}
for (int stride = input_size / 2; stride > 0; stride >>= 1) {
for (int i = 0; i < stride; i++) {
tmp_loss[i] += tmp_loss[i + stride];
}

if (stride > 2 && stride % 2 == 1) {
tmp_loss[0] += tmp_loss[stride - 1];
}
}
loss[0] += tmp_loss[0];
if (reduction == 1) {
loss[0] /= input_size;
}
}
}

+ 36
- 0
mindspore/lite/nnacl/fp32_grad/binary_cross_entropy.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_H_
#define MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_H_

#include "nnacl/op_base.h"

typedef struct BinaryCrossEntropyParameter {
OpParameter op_parameter_;
int reduction;
} BinaryCrossEntropyParameter;

#ifdef __cplusplus
extern "C" {
#endif

void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, float *loss, float *tmp_loss);

#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_H_

+ 42
- 0
mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.c View File

@@ -0,0 +1,42 @@
/*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, const float *dloss, float *dx) {
float epsilon = 1e-12;
if (reduction == 0) {
for (int i = 0; i < input_size; i++) {
float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i];
}
} else {
float dloss1 = dloss[0];
if (reduction == 1) {
dloss1 = dloss[0] / input_size;
}
for (int i = 0; i < input_size; i++) {
float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1;
}
}
return 0;
}

+ 36
- 0
mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_GRAD_H_
#define MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_GRAD_H_

#include "nnacl/op_base.h"

typedef struct BinaryCrossEntropyGradParameter {
OpParameter op_parameter_;
int reduction;
} BinaryCrossEntropyGradParameter;

#ifdef __cplusplus
extern "C" {
#endif

int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, const float *dloss, float *dx);

#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_GRAD_H_

+ 11
- 0
mindspore/lite/src/ops/assign_add.cc View File

@@ -15,6 +15,11 @@
*/

#include "src/ops/assign_add.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -58,7 +63,13 @@ int AssignAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
fbb->Finish(prim_offset);
return RET_OK;
}

PrimitiveC *AssignAddCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<AssignAdd>(primitive);
}
Registry AssignAddRegistry(schema::PrimitiveType_AssignAdd, AssignAddCreator);
#endif

int AssignAdd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];
Tensor *y = inputs_[1];


+ 9
- 0
mindspore/lite/src/ops/binary_cross_entropy.cc View File

@@ -17,6 +17,10 @@
#include <string>
#include "src/ops/binary_cross_entropy.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -85,6 +89,11 @@ int BinaryCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive,
}

int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value_as_BinaryCrossEntropy()->reduction(); }

PrimitiveC *BinaryCrossEntropyCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BinaryCrossEntropy>(primitive);
}
Registry BinaryCrossEntropyRegistry(schema::PrimitiveType_BinaryCrossEntropy, BinaryCrossEntropyCreator);
#endif
int BinaryCrossEntropy::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];


+ 9
- 0
mindspore/lite/src/ops/binary_cross_entropy_grad.cc View File

@@ -17,6 +17,10 @@
#include <string>
#include "src/ops/binary_cross_entropy_grad.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -92,6 +96,11 @@ int BinaryCrossEntropyGrad::UnPackToFlatBuilder(const schema::Primitive *primiti
int BinaryCrossEntropyGrad::GetReduction() const {
return this->primitive_->value_as_BinaryCrossEntropyGrad()->reduction();
}

PrimitiveC *BinaryCrossEntropyGradCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BinaryCrossEntropyGrad>(primitive);
}
Registry BinaryCrossEntropyGradRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad, BinaryCrossEntropyGradCreator);
#endif
int BinaryCrossEntropyGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];


+ 1
- 1
mindspore/lite/src/ops/cast.cc View File

@@ -89,7 +89,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
return RET_INPUT_TENSOR_ERROR;
}


+ 10
- 4
mindspore/lite/src/ops/gather.cc View File

@@ -54,8 +54,15 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
delete gather_attr;
return RET_ERROR;
}
gather_attr->axis = GetValue<int>(prim.GetAttr("axis"));
gather_attr->batchDims = GetValue<int>(prim.GetAttr("batchDims"));
if (inputs[2]->isa<ValueNode>()) {
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
int axis = GetValue<int>(axis_tensor->value());
gather_attr->axis = axis;
} else {
MS_LOG(ERROR) << "input axis is not value node.";
return RET_ERROR;
}
gather_attr->batchDims = 0;
this->primitive_->value.value = gather_attr;
}
return RET_OK;
@@ -85,8 +92,7 @@ Registry GatherRegistry(schema::PrimitiveType_Gather, GatherCreator);
int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "Gather should have two inputs";
return RET_INPUT_TENSOR_ERROR;
MS_LOG(DEBUG) << "Gather should have two inputs";
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Gather should have one outputs";


+ 9
- 0
mindspore/lite/src/ops/oneslike.cc View File

@@ -16,6 +16,10 @@

#include "src/ops/oneslike.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -59,6 +63,11 @@ int OnesLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
fbb->Finish(prim_offset);
return RET_OK;
}

PrimitiveC *OnesLikeCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<OnesLike>(primitive);
}
Registry OnesLikeRegistry(schema::PrimitiveType_OnesLike, OnesLikeCreator);
#endif
int OnesLike::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];


+ 41
- 0
mindspore/lite/src/ops/populate/activation_grad_populate.cc View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/activation_grad.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/activation_grad.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateActivationGradParameter(const mindspore::lite::PrimitiveC *primitive) {
ActivationGradParameter *act_param =
reinterpret_cast<ActivationGradParameter *>(malloc(sizeof(ActivationGradParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ActivationGradParameter));
act_param->op_parameter.type_ = primitive->Type();
auto activation =
reinterpret_cast<mindspore::lite::ActivationGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
act_param->type_ = static_cast<int>(activation->GetType());
act_param->alpha_ = activation->GetAlpha();
return reinterpret_cast<OpParameter *>(act_param);
}
Registry ActivationGradParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter);
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/src/ops/populate/adam_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/adam.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc Adam Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}

Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter);
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/src/ops/populate/assign_add_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/assign_add.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateAssignAddParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc AssignAdd Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}

Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, PopulateAssignAddParameter);
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/src/ops/populate/assign_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/assign.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateAssignParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc Assign Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}

Registry AssignParameterRegistry(schema::PrimitiveType_Assign, PopulateAssignParameter);
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/src/ops/populate/bias_grad_populate.cc View File

@@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateBiasGradParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();

return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry PopulateBiasGradParameterParameterRegistry(schema::PrimitiveType_BiasGrad, PopulateBiasGradParameter);

} // namespace lite
} // namespace mindspore

+ 42
- 0
mindspore/lite/src/ops/populate/binary_cross_entropy_grad_populate.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/binary_cross_entropy_grad.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateBinaryCrossEntropyGradParameter(const mindspore::lite::PrimitiveC *primitive) {
BinaryCrossEntropyGradParameter *bce_param =
reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
if (bce_param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
return nullptr;
}
memset(bce_param, 0, sizeof(BinaryCrossEntropyGradParameter));
bce_param->op_parameter_.type_ = primitive->Type();
auto param =
reinterpret_cast<mindspore::lite::BinaryCrossEntropyGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
bce_param->reduction = param->GetReduction();
return reinterpret_cast<OpParameter *>(bce_param);
}

Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
PopulateBinaryCrossEntropyGradParameter);
} // namespace lite
} // namespace mindspore

+ 42
- 0
mindspore/lite/src/ops/populate/binary_cross_entropy_populate.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/binary_cross_entropy.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/binary_cross_entropy.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateBinaryCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) {
BinaryCrossEntropyParameter *bce_param =
reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
if (bce_param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
return nullptr;
}
memset(bce_param, 0, sizeof(BinaryCrossEntropyParameter));
bce_param->op_parameter_.type_ = primitive->Type();
auto param =
reinterpret_cast<mindspore::lite::BinaryCrossEntropy *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
bce_param->reduction = param->GetReduction();
return reinterpret_cast<OpParameter *>(bce_param);
}

Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy,
PopulateBinaryCrossEntropyParameter);
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/src/ops/populate/oneslike_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/oneslike.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateOnesLikeParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc OnesLike Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}

Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, PopulateOnesLikeParameter);
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/src/ops/populate/unsorted_segment_sum_populate.cc View File

@@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/unsorted_segment_sum.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateUnsortedSegmentSumParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc UnsortedSegmentSum Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}

Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum,
PopulateUnsortedSegmentSumParameter);
} // namespace lite
} // namespace mindspore

+ 6
- 0
mindspore/lite/src/ops/slice.cc View File

@@ -94,6 +94,12 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
}
}
}
std::vector<int> axes;
axes.clear();
for (size_t i = 0; i < attr->begin.size(); i++) {
axes.push_back(i);
}
attr->axes = axes;
}
this->primitive_->value.value = attr;
}


+ 9
- 0
mindspore/lite/src/ops/unsorted_segment_sum.cc View File

@@ -17,6 +17,10 @@
#include <memory>
#include "src/ops/unsorted_segment_sum.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -69,6 +73,11 @@ int UnsortedSegmentSum::GetNumSegments() const {
int ret = this->primitive_->value_as_UnsortedSegmentSum()->numSegments();
return ret;
}

PrimitiveC *UnsortedSegmentSumCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<UnsortedSegmentSum>(primitive);
}
Registry UnsortedSegmentSumRegistry(schema::PrimitiveType_UnsortedSegmentSum, UnsortedSegmentSumCreator);
#endif
int UnsortedSegmentSum::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
// check inputs and outputs


Loading…
Cancel
Save