
!13171 [lite] adjust SparseSoftmaxCrossEntropy and fix mindir bug

From: @xu_anyue
Reviewed-by: @hangangqiang,@jpc_chenjianping
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago
parent commit 29a90cbd2f
19 changed files with 109 additions and 111 deletions
  1. +0 -43  mindspore/core/ops/sparse_softmax_cross_entropy.h
  2. +9 -8  mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc
  3. +44 -0  mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h
  4. +1 -1  mindspore/lite/micro/cmake/file_list.cmake
  5. +1 -1  mindspore/lite/micro/coder/train.cc
  6. +1 -1  mindspore/lite/nnacl/fp32_grad/softmax_grad.h
  7. +4 -4  mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c
  8. +6 -11  mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h
  9. +2 -2  mindspore/lite/schema/ops.fbs
  10. +1 -0  mindspore/lite/schema/ops_types.fbs
  11. +3 -3  mindspore/lite/src/ops/ops_def.cc
  12. +2 -2  mindspore/lite/src/ops/ops_func_declare.h
  13. +4 -4  mindspore/lite/src/ops/ops_utils.cc
  14. +2 -2  mindspore/lite/src/runtime/infer_manager.cc
  15. +15 -14  mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
  16. +7 -6  mindspore/lite/src/train/train_populate_parameter.cc
  17. +3 -3  mindspore/lite/src/train/train_populate_parameter_v0.cc
  18. +1 -1  mindspore/lite/src/train/train_session.cc
  19. +3 -5  mindspore/lite/tools/anf_exporter/anf_exporter.cc
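In short, the commit renames the lite op SparseSoftmaxCrossEntropy to SparseSoftmaxCrossEntropyWithLogits across the ops layer, schema, nnacl infer code and train kernels, renames the parameter field is_grad to is_grad_, and moves the Depend/ControlDepend stripping in anf_exporter out of the train-only branch. A minimal usage sketch of the renamed ops-layer primitive follows; it assumes a build inside the MindSpore tree against the header added in this commit, and the main() wrapper is purely illustrative:

// Minimal sketch, assuming a build inside the MindSpore tree with the renamed
// header from this commit; the main() wrapper is illustrative only.
#include "ops/sparse_softmax_cross_entropy_with_logits.h"

int main() {
  mindspore::ops::SparseSoftmaxCrossEntropyWithLogits op;
  op.Init(/*is_grad=*/false);    // stores the "grad" attribute on the primitive
  return op.get_grad() ? 1 : 0;  // reads the attribute back
}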

+0 -43  mindspore/core/ops/sparse_softmax_cross_entropy.h

@@ -1,43 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_
#define MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSparseSoftmaxCrossEntropy = "SparseSoftmaxCrossEntropy";
class SparseSoftmaxCrossEntropy : public PrimitiveC {
public:
SparseSoftmaxCrossEntropy() : PrimitiveC(kNameSparseSoftmaxCrossEntropy) {}
~SparseSoftmaxCrossEntropy() = default;
MS_DECLARE_PARENT(SparseSoftmaxCrossEntropy, PrimitiveC);
void Init(const bool is_grad = false);
void set_grad(const bool is_grad);
bool get_grad() const;
};
AbstractBasePtr SparseSoftmaxCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSparseSoftmaxCrossEntropyPtr = std::shared_ptr<SparseSoftmaxCrossEntropy>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_

mindspore/core/ops/sparse_softmax_cross_entropy.cc → mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc

@@ -19,25 +19,26 @@
#include <string>
#include <vector>

#include "ops/sparse_softmax_cross_entropy.h"
#include "ops/sparse_softmax_cross_entropy_with_logits.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
void SparseSoftmaxCrossEntropy::Init(const bool grad) { this->set_grad(grad); }
void SparseSoftmaxCrossEntropyWithLogits::Init(const bool grad) { this->set_grad(grad); }

void SparseSoftmaxCrossEntropy::set_grad(const bool grad) { this->AddAttr(kGrad, MakeValue(grad)); }
void SparseSoftmaxCrossEntropyWithLogits::set_grad(const bool grad) { this->AddAttr(kGrad, MakeValue(grad)); }

bool SparseSoftmaxCrossEntropy::get_grad() const {
bool SparseSoftmaxCrossEntropyWithLogits::get_grad() const {
auto value_ptr = GetAttr(kGrad);
return GetValue<bool>(value_ptr);
}

AbstractBasePtr SparseSoftmaxCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto sparse_softmax_cross_entropy_prim = primitive->cast<PrimSparseSoftmaxCrossEntropyPtr>();
auto sparse_softmax_cross_entropy_prim = primitive->cast<PrimSparseSoftmaxCrossEntropyWithLogitsPtr>();
MS_EXCEPTION_IF_NULL(sparse_softmax_cross_entropy_prim);
auto prim_name = sparse_softmax_cross_entropy_prim->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
@@ -57,6 +58,6 @@ AbstractBasePtr SparseSoftmaxCrossEntropyInfer(const abstract::AnalysisEnginePtr
auto output_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(output_type, output_shape);
}
REGISTER_PRIMITIVE_C(kNameSparseSoftmaxCrossEntropy, SparseSoftmaxCrossEntropy);
REGISTER_PRIMITIVE_C(kNameSparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogits);
} // namespace ops
} // namespace mindspore

+44 -0  mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#define MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSparseSoftmaxCrossEntropyWithLogits = "SparseSoftmaxCrossEntropyWithLogits";
class SparseSoftmaxCrossEntropyWithLogits : public PrimitiveC {
public:
SparseSoftmaxCrossEntropyWithLogits() : PrimitiveC(kNameSparseSoftmaxCrossEntropyWithLogits) {}
~SparseSoftmaxCrossEntropyWithLogits() = default;
MS_DECLARE_PARENT(SparseSoftmaxCrossEntropyWithLogits, PrimitiveC);
void Init(const bool is_grad = false);
void set_grad(const bool is_grad);
bool get_grad() const;
};
AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSparseSoftmaxCrossEntropyWithLogitsPtr = std::shared_ptr<SparseSoftmaxCrossEntropyWithLogits>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_

+1 -1  mindspore/lite/micro/cmake/file_list.cmake

@@ -281,7 +281,7 @@ set(LITE_KERNEL_SRC
${LITE_DIR}/nnacl/infer/space_to_batch_infer.c
${LITE_DIR}/nnacl/infer/space_to_batch_nd_infer.c
${LITE_DIR}/nnacl/infer/space_to_depth_infer.c
${LITE_DIR}/nnacl/infer/sparse_softmax_cross_entropy_infer.c
${LITE_DIR}/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c
${LITE_DIR}/nnacl/infer/sparse_to_dense_infer.c
${LITE_DIR}/nnacl/infer/split_infer.c
${LITE_DIR}/nnacl/infer/squeeze_infer.c


+1 -1  mindspore/lite/micro/coder/train.cc

@@ -56,7 +56,7 @@ std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
}

int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropy,
const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
schema::PrimitiveType_BinaryCrossEntropy,
schema::PrimitiveType_SmoothL1Loss,
schema::PrimitiveType_SmoothL1LossGrad,


+1 -1  mindspore/lite/nnacl/fp32_grad/softmax_grad.h

@@ -35,7 +35,7 @@ typedef struct SoftmaxCrossEntropyParameter {
// other parameter
int32_t batch_size_;
unsigned int number_of_classes_;
bool is_grad;
bool is_grad_;
} SoftmaxCrossEntropyParameter;

void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,


mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.c → mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c

@@ -14,10 +14,10 @@
* limitations under the License.
*/

#include "nnacl/infer/sparse_softmax_cross_entropy_infer.h"
#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h"

int SparseSoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter) {
int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter) {
#ifdef Debug
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
if (check_ret != NNACL_OK) {
@@ -28,7 +28,7 @@ int SparseSoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inp
const TensorC *in0 = inputs[0];
TensorC *out = outputs[0];

SparseSoftmaxCrossEntropyParameter *param = (SparseSoftmaxCrossEntropyParameter *)parameter;
SoftmaxCrossEntropyParameter *param = (SoftmaxCrossEntropyParameter *)parameter;
if (param->is_grad_ != 0) {
SetShapeTensor(out, in0);
SetDataTypeFormat(out, in0);

mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.h → mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h

@@ -13,25 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_INFER_H
#define MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_INFER_H
#ifndef MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H
#define MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H

#include "nnacl/infer/common_infer.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/fp32_grad/softmax_grad.h"

#ifdef __cplusplus
extern "C" {
#endif

typedef struct SparseSoftmaxCrossEntropyParameter {
OpParameter op_parameter_;
bool is_grad_;
} SparseSoftmaxCrossEntropyParameter;

int SparseSoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter);
int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter);

#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_INFER_H
#endif // MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H
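
With the dedicated SparseSoftmaxCrossEntropyParameter struct removed above, the infer code and the fp32_grad kernel both cast the incoming OpParameter to the shared SoftmaxCrossEntropyParameter and branch on the renamed is_grad_ field. A standalone sketch of that branch, using a hypothetical local mirror of the struct instead of the real nnacl headers:

// Standalone sketch: SoftmaxCrossEntropyParameterSketch is a hypothetical local
// mirror of the fields declared in softmax_grad.h above, not the real nnacl type.
#include <cstdint>
#include <iostream>

struct SoftmaxCrossEntropyParameterSketch {
  int32_t batch_size_;
  unsigned int number_of_classes_;
  bool is_grad_;  // renamed from is_grad in this commit
};

int main() {
  SoftmaxCrossEntropyParameterSketch param{32, 10, false};
  if (param.is_grad_) {
    std::cout << "grad path: output keeps the logits shape\n";
  } else {
    std::cout << "forward path: output is the loss\n";
  }
  return 0;
}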

+2 -2  mindspore/lite/schema/ops.fbs

@@ -160,7 +160,7 @@ union PrimitiveType {
SpaceToBatch,
SpaceToBatchND,
SpaceToDepth,
SparseSoftmaxCrossEntropy,
SparseSoftmaxCrossEntropyWithLogits,
SparseToDense,
Split,
Sqrt,
@@ -904,7 +904,7 @@ table SpaceToDepth {
format: Format;
}

table SparseSoftmaxCrossEntropy {
table SparseSoftmaxCrossEntropyWithLogits {
grad: bool;
}



+1 -0  mindspore/lite/schema/ops_types.fbs

@@ -52,6 +52,7 @@ enum Format : int {
NC,
NC4,
NC4HW4,
NCDHW,
NUM_OF_FORMAT
}



+3 -3  mindspore/lite/src/ops/ops_def.cc

@@ -159,7 +159,7 @@ OP_TYPE(SoftmaxCrossEntropyWithLogits)
OP_TYPE(SpaceToBatch)
OP_TYPE(SpaceToBatchND)
OP_TYPE(SpaceToDepth)
OP_TYPE(SparseSoftmaxCrossEntropy)
OP_TYPE(SparseSoftmaxCrossEntropyWithLogits)
OP_TYPE(SparseToDense)
OP_TYPE(Split)
OP_TYPE(Sqrt)
@@ -903,9 +903,9 @@ OP_ATTR(block_size, long)
OP_ATTR_ENUM(format, Format)
OP_SCHEMA_DEF_END(SpaceToDepth)

OP_SCHEMA_DEF(SparseSoftmaxCrossEntropy)
OP_SCHEMA_DEF(SparseSoftmaxCrossEntropyWithLogits)
OP_ATTR(grad, bool)
OP_SCHEMA_DEF_END(SparseSoftmaxCrossEntropy)
OP_SCHEMA_DEF_END(SparseSoftmaxCrossEntropyWithLogits)

OP_SCHEMA_DEF(SparseToDense)
OP_SCHEMA_DEF_END(SparseToDense)


+2 -2  mindspore/lite/src/ops/ops_func_declare.h

@@ -149,7 +149,7 @@
#include "ops/space_to_batch.h"
#include "ops/space_to_batch_nd.h"
#include "ops/space_to_depth.h"
#include "ops/sparse_softmax_cross_entropy.h"
#include "ops/sparse_softmax_cross_entropy_with_logits.h"
#include "ops/sparse_to_dense.h"
#include "ops/split.h"
#include "ops/square.h"
@@ -405,7 +405,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(SoftmaxCrossEntropyWithLogits);
FUNC_MSOP2SCHEMAOP_DECLARE(SpaceToBatch);
FUNC_MSOP2SCHEMAOP_DECLARE(SpaceToBatchND);
FUNC_MSOP2SCHEMAOP_DECLARE(SpaceToDepth);
FUNC_MSOP2SCHEMAOP_DECLARE(SparseSoftmaxCrossEntropy);
FUNC_MSOP2SCHEMAOP_DECLARE(SparseSoftmaxCrossEntropyWithLogits);
FUNC_MSOP2SCHEMAOP_DECLARE(SparseToDense);
FUNC_MSOP2SCHEMAOP_DECLARE(Split);
FUNC_MSOP2SCHEMAOP_DECLARE(Sqrt);


+4 -4  mindspore/lite/src/ops/ops_utils.cc

@@ -612,8 +612,8 @@ schema::PrimitiveT *SpaceToDepthPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SpaceToDepth>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *SparseSoftmaxCrossEntropyPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseSoftmaxCrossEntropy>>(node);
schema::PrimitiveT *SparseSoftmaxCrossEntropyWithLogitsPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseSoftmaxCrossEntropyWithLogits>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *SparseToDensePrimitiveCreator(const AnfNodePtr &node) {
@@ -886,8 +886,8 @@ RegistryMSOps g_softmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry("SoftmaxCr
RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator);
RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator);
RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator);
RegistryMSOps g_sparseSoftmaxCrossEntropyPrimitiveCreatorRegistry("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyPrimitiveCreator);
RegistryMSOps g_sparseSoftmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry(
"SparseSoftmaxCrossEntropyWithLogits", SparseSoftmaxCrossEntropyWithLogitsPrimitiveCreator);
RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator);
RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator);
RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator);


+2 -2  mindspore/lite/src/runtime/infer_manager.cc

@@ -120,7 +120,7 @@
#include "nnacl/infer/merge_infer.h"
#include "nnacl/infer/switch_infer.h"
#include "nnacl/infer/assert_op_infer.h"
#include "nnacl/infer/sparse_softmax_cross_entropy_infer.h"
#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h"
#include "nnacl/infer/dropout_infer.h"
#include "nnacl/infer/prior_box_infer.h"

@@ -394,7 +394,7 @@ static RegistryInferShape g_MergeInferShape(mindspore::schema::PrimitiveType_Mer
static RegistryInferShape g_SwitchInferShape(mindspore::schema::PrimitiveType_Switch, SwitchInferShape);
static RegistryInferShape g_AssertOpInferShape(mindspore::schema::PrimitiveType_Assert, AssertOpInferShape);
static RegistryInferShape g_SparseSoftmaxCrossEntropyInferShape(
mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropy, SparseSoftmaxCrossEntropyInferShape);
mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogitsInferShape);
static RegistryInferShape g_DropoutInferShape(mindspore::schema::PrimitiveType_Dropout, DropoutInferShape);
static RegistryInferShape g_PriorBoxInferShape(mindspore::schema::PrimitiveType_PriorBox, PriorBoxInferShape);
static RegistryInferShape g_MinimumGradInferShape(mindspore::schema::PrimitiveType_MinimumGrad, MaximumGradInferShape);


+15 -14  mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc

@@ -25,7 +25,7 @@
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropy;
using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;

namespace mindspore::kernel {

@@ -93,7 +93,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {
std::fill(losses_, losses_ + data_size, 0.f);
std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0.f);
Softmax(ins, losses_, sum_data_, &sm_params_);
if (sce_param->is_grad) {
if (sce_param->is_grad_) {
GradPostExecute(labels, losses_, out);
} else {
ForwardPostExecute(labels, losses_, out);
@@ -101,20 +101,21 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {
return RET_OK;
}

int SparseSoftmaxCrossEntropyRun(void *cdata, int task_id) {
int SparseSoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id) {
auto sparse_kernel = reinterpret_cast<SparseSoftmaxCrossEntropyWithLogitsCPUKernel *>(cdata);
auto error_code = sparse_kernel->Execute(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyRun error task_id[" << task_id << "] error_code[" << error_code << "]";
MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogitsRun error task_id[" << task_id << "] error_code[" << error_code
<< "]";
return RET_ERROR;
}
return RET_OK;
}

int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyRun, this, 1);
int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyWithLogitsRun, this, 1);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy function error error_code[" << error_code << "]";
MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogits function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
@@ -145,13 +146,13 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
return RET_OK;
}

kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter,
const lite::InnerContext *ctx,
const kernel::KernelKey &desc) {
kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyWithLogitsFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter,
const lite::InnerContext *ctx,
const kernel::KernelKey &desc) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_SparseSoftmaxCrossEntropy);
MS_ASSERT(desc.type == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits);
auto *kernel = new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SparseSoftmaxCrossEntropyWithLogitsCPUKernel failed!";
@@ -167,6 +168,6 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator(const std::vec
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseSoftmaxCrossEntropy,
CpuSparseSoftmaxCrossEntropyFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
CpuSparseSoftmaxCrossEntropyWithLogitsFp32KernelCreator)
} // namespace mindspore::kernel

+7 -6  mindspore/lite/src/train/train_populate_parameter.cc

@@ -123,7 +123,7 @@ OpParameter *PopulateSgdParameter(const void *prim) {
return reinterpret_cast<OpParameter *>(p);
}

OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *prim) {
OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
SoftmaxCrossEntropyParameter *sce_param =
reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
if (sce_param == nullptr) {
@@ -131,9 +131,9 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *prim) {
return nullptr;
}
auto primitive = static_cast<const schema::Primitive *>(prim);
auto value = primitive->value_as_SparseSoftmaxCrossEntropy();
auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
sce_param->op_parameter_.type_ = primitive->value_type();
sce_param->is_grad = value->grad();
sce_param->is_grad_ = value->grad();
return reinterpret_cast<OpParameter *>(sce_param);
}

@@ -146,7 +146,7 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
}
auto primitive = static_cast<const schema::Primitive *>(prim);
sce_param->op_parameter_.type_ = primitive->value_type();
sce_param->is_grad = 0;
sce_param->is_grad_ = 0;
return reinterpret_cast<OpParameter *>(sce_param);
}

@@ -385,8 +385,9 @@ void PopulateTrainParameters() {
lite::SCHEMA_CUR);
lite::Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(
schema::PrimitiveType_SparseSoftmaxCrossEntropy, PopulateSparseSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
lite::SCHEMA_CUR);
lite::Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter,
lite::SCHEMA_CUR);
lite::Registry DependParameterRegistry(schema::PrimitiveType_Depend, lite::DefaultPopulateParameter,


+3 -3  mindspore/lite/src/train/train_populate_parameter_v0.cc

@@ -229,9 +229,9 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *primitive) {
}
auto sparseSoftmaxCrossEntropy_prim = prim->value_as_SparseSoftmaxCrossEntropy();

sce_param->is_grad = sparseSoftmaxCrossEntropy_prim->isGrad();
sce_param->is_grad_ = sparseSoftmaxCrossEntropy_prim->isGrad();

sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropy;
sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
return reinterpret_cast<OpParameter *>(sce_param);
}

@@ -246,7 +246,7 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *primitive) {
MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
return nullptr;
}
sce_param->is_grad = 0;
sce_param->is_grad_ = 0;
sce_param->op_parameter_.type_ = schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
return reinterpret_cast<OpParameter *>(sce_param);
}


+1 -1  mindspore/lite/src/train/train_session.cc

@@ -454,7 +454,7 @@ int TrainSession::OptimizerStep() {

bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropyWithLogits ||
kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy ||
kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits ||
kernel->Type() == schema::PrimitiveType_SmoothL1Loss ||
kernel->Type() == schema::PrimitiveType_SmoothL1LossGrad ||
kernel->Type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||


+3 -5  mindspore/lite/tools/anf_exporter/anf_exporter.cc

@@ -356,11 +356,9 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
}

RemoveIfMakeTuple(cnode);
if (train_flag) {
RemoveIfDepend(cnode);
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) {
continue;
}
RemoveIfDepend(cnode);
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) {
continue;
}

if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {

