diff --git a/mindspore/core/ops/sparse_softmax_cross_entropy.h b/mindspore/core/ops/sparse_softmax_cross_entropy.h
deleted file mode 100644
index fa4b94c9f9..0000000000
--- a/mindspore/core/ops/sparse_softmax_cross_entropy.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_
-#define MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_
-#include <vector>
-#include <memory>
-#include "ops/primitive_c.h"
-#include "abstract/abstract_value.h"
-#include "utils/check_convert_utils.h"
-
-namespace mindspore {
-namespace ops {
-constexpr auto kNameSparseSoftmaxCrossEntropy = "SparseSoftmaxCrossEntropy";
-class SparseSoftmaxCrossEntropy : public PrimitiveC {
- public:
-  SparseSoftmaxCrossEntropy() : PrimitiveC(kNameSparseSoftmaxCrossEntropy) {}
-  ~SparseSoftmaxCrossEntropy() = default;
-  MS_DECLARE_PARENT(SparseSoftmaxCrossEntropy, PrimitiveC);
-  void Init(const bool is_grad = false);
-  void set_grad(const bool is_grad);
-  bool get_grad() const;
-};
-AbstractBasePtr SparseSoftmaxCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                               const std::vector<AbstractBasePtr> &input_args);
-using PrimSparseSoftmaxCrossEntropyPtr = std::shared_ptr<SparseSoftmaxCrossEntropy>;
-}  // namespace ops
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_
diff --git a/mindspore/core/ops/sparse_softmax_cross_entropy.cc b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc
similarity index 70%
rename from mindspore/core/ops/sparse_softmax_cross_entropy.cc
rename to mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc
index 5d0a24b2c8..cd39cf856b 100644
--- a/mindspore/core/ops/sparse_softmax_cross_entropy.cc
+++ b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc
@@ -19,25 +19,26 @@
 #include <memory>
 #include <vector>
-#include "ops/sparse_softmax_cross_entropy.h"
+#include "ops/sparse_softmax_cross_entropy_with_logits.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
 
 namespace mindspore {
 namespace ops {
-void SparseSoftmaxCrossEntropy::Init(const bool grad) { this->set_grad(grad); }
+void SparseSoftmaxCrossEntropyWithLogits::Init(const bool grad) { this->set_grad(grad); }
 
-void SparseSoftmaxCrossEntropy::set_grad(const bool grad) { this->AddAttr(kGrad, MakeValue(grad)); }
+void SparseSoftmaxCrossEntropyWithLogits::set_grad(const bool grad) { this->AddAttr(kGrad, MakeValue(grad)); }
 
-bool SparseSoftmaxCrossEntropy::get_grad() const {
+bool SparseSoftmaxCrossEntropyWithLogits::get_grad() const {
   auto value_ptr = GetAttr(kGrad);
   return GetValue<bool>(value_ptr);
 }
 
-AbstractBasePtr SparseSoftmaxCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                               const std::vector<AbstractBasePtr> &input_args) {
+AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &,
+                                                         const PrimitivePtr &primitive,
+                                                         const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto sparse_softmax_cross_entropy_prim = primitive->cast<PrimSparseSoftmaxCrossEntropyPtr>();
+  auto sparse_softmax_cross_entropy_prim = primitive->cast<PrimSparseSoftmaxCrossEntropyWithLogitsPtr>();
   MS_EXCEPTION_IF_NULL(sparse_softmax_cross_entropy_prim);
   auto prim_name = sparse_softmax_cross_entropy_prim->name();
   CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
@@ -57,6 +58,6 @@ AbstractBasePtr SparseSoftmaxCrossEntropyInfer(const abstract::AnalysisEnginePtr
   auto output_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
   return std::make_shared<abstract::AbstractTensor>(output_type, output_shape);
 }
-REGISTER_PRIMITIVE_C(kNameSparseSoftmaxCrossEntropy, SparseSoftmaxCrossEntropy);
+REGISTER_PRIMITIVE_C(kNameSparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogits);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h
new file mode 100644
index 0000000000..fe0d872f9e
--- /dev/null
+++ b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
+#define MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
+#include <vector>
+#include <memory>
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameSparseSoftmaxCrossEntropyWithLogits = "SparseSoftmaxCrossEntropyWithLogits";
+class SparseSoftmaxCrossEntropyWithLogits : public PrimitiveC {
+ public:
+  SparseSoftmaxCrossEntropyWithLogits() : PrimitiveC(kNameSparseSoftmaxCrossEntropyWithLogits) {}
+  ~SparseSoftmaxCrossEntropyWithLogits() = default;
+  MS_DECLARE_PARENT(SparseSoftmaxCrossEntropyWithLogits, PrimitiveC);
+  void Init(const bool is_grad = false);
+  void set_grad(const bool is_grad);
+  bool get_grad() const;
+};
+AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &,
+                                                         const PrimitivePtr &primitive,
+                                                         const std::vector<AbstractBasePtr> &input_args);
+using PrimSparseSoftmaxCrossEntropyWithLogitsPtr = std::shared_ptr<SparseSoftmaxCrossEntropyWithLogits>;
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake
index ed24e68af6..697a174713 100644
--- a/mindspore/lite/micro/cmake/file_list.cmake
+++ b/mindspore/lite/micro/cmake/file_list.cmake
@@ -281,7 +281,7 @@ set(LITE_KERNEL_SRC
     ${LITE_DIR}/nnacl/infer/space_to_batch_infer.c
     ${LITE_DIR}/nnacl/infer/space_to_batch_nd_infer.c
     ${LITE_DIR}/nnacl/infer/space_to_depth_infer.c
-    ${LITE_DIR}/nnacl/infer/sparse_softmax_cross_entropy_infer.c
+    ${LITE_DIR}/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c
    ${LITE_DIR}/nnacl/infer/sparse_to_dense_infer.c
     ${LITE_DIR}/nnacl/infer/split_infer.c
     ${LITE_DIR}/nnacl/infer/squeeze_infer.c
diff --git a/mindspore/lite/micro/coder/train.cc b/mindspore/lite/micro/coder/train.cc
index f566e2c4a4..4fe0a06b81 100644
--- a/mindspore/lite/micro/coder/train.cc
+++ b/mindspore/lite/micro/coder/train.cc
@@ -56,7 +56,7 @@ std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
 }
 
 int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
-  const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropy,
+  const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
                                          schema::PrimitiveType_BinaryCrossEntropy,
                                          schema::PrimitiveType_SmoothL1Loss,
                                          schema::PrimitiveType_SmoothL1LossGrad,
diff --git a/mindspore/lite/nnacl/fp32_grad/softmax_grad.h b/mindspore/lite/nnacl/fp32_grad/softmax_grad.h
index 85f4717b64..d70a70730d 100644
--- a/mindspore/lite/nnacl/fp32_grad/softmax_grad.h
+++ b/mindspore/lite/nnacl/fp32_grad/softmax_grad.h
@@ -35,7 +35,7 @@ typedef struct SoftmaxCrossEntropyParameter {
   // other parameter
   int32_t batch_size_;
   unsigned int number_of_classes_;
-  bool is_grad;
+  bool is_grad_;
 } SoftmaxCrossEntropyParameter;
 
 void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,
diff --git a/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.c b/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c
similarity index 74%
rename from mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.c
rename to mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c
index 6382a5188e..fc309c9a2e 100644
--- a/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.c
+++ b/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c
@@ -14,10 +14,10 @@
  * limitations under the License.
  */
 
-#include "nnacl/infer/sparse_softmax_cross_entropy_infer.h"
+#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h"
 
-int SparseSoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
-                                        size_t outputs_size, OpParameter *parameter) {
+int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                                  size_t outputs_size, OpParameter *parameter) {
 #ifdef Debug
   int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
   if (check_ret != NNACL_OK) {
@@ -28,7 +28,7 @@ int SparseSoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inp
 
   const TensorC *in0 = inputs[0];
   TensorC *out = outputs[0];
-  SparseSoftmaxCrossEntropyParameter *param = (SparseSoftmaxCrossEntropyParameter *)parameter;
+  SoftmaxCrossEntropyParameter *param = (SoftmaxCrossEntropyParameter *)parameter;
   if (param->is_grad_ != 0) {
     SetShapeTensor(out, in0);
     SetDataTypeFormat(out, in0);
diff --git a/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.h b/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h
similarity index 55%
rename from mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.h
rename to mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h
index 56322e3533..f32cda9f45 100644
--- a/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_infer.h
+++ b/mindspore/lite/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h
@@ -13,25 +13,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_INFER_H
-#define MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_INFER_H
+#ifndef MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H
+#define MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H
 
 #include "nnacl/infer/common_infer.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl/fp32_grad/softmax_grad.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-typedef struct SparseSoftmaxCrossEntropyParameter {
-  OpParameter op_parameter_;
-  bool is_grad_;
-} SparseSoftmaxCrossEntropyParameter;
-
-int SparseSoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
-                                        size_t outputs_size, OpParameter *parameter);
+int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                                  size_t outputs_size, OpParameter *parameter);
 
 #ifdef __cplusplus
 }
 #endif
-#endif  // MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_INFER_H
+#endif  // MINDSPORE_LITE_NNACL_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index a7f6a930e7..28c8926d29 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -160,7 +160,7 @@ union PrimitiveType {
     SpaceToBatch,
     SpaceToBatchND,
     SpaceToDepth,
-    SparseSoftmaxCrossEntropy,
+    SparseSoftmaxCrossEntropyWithLogits,
     SparseToDense,
     Split,
     Sqrt,
@@ -904,7 +904,7 @@ table SpaceToDepth {
     format: Format;
 }
 
-table SparseSoftmaxCrossEntropy {
+table SparseSoftmaxCrossEntropyWithLogits {
     grad: bool;
 }
diff --git a/mindspore/lite/schema/ops_types.fbs b/mindspore/lite/schema/ops_types.fbs
index 8d642dbb28..41ff6ace21 100644
--- a/mindspore/lite/schema/ops_types.fbs
+++ b/mindspore/lite/schema/ops_types.fbs
@@ -52,6 +52,7 @@ enum Format : int {
     NC,
     NC4,
     NC4HW4,
+    NCDHW,
     NUM_OF_FORMAT
 }
diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc
index b79106711f..bfe9b33940 100644
--- a/mindspore/lite/src/ops/ops_def.cc
+++ b/mindspore/lite/src/ops/ops_def.cc
@@ -159,7 +159,7 @@ OP_TYPE(SoftmaxCrossEntropyWithLogits)
 OP_TYPE(SpaceToBatch)
 OP_TYPE(SpaceToBatchND)
 OP_TYPE(SpaceToDepth)
-OP_TYPE(SparseSoftmaxCrossEntropy)
+OP_TYPE(SparseSoftmaxCrossEntropyWithLogits)
 OP_TYPE(SparseToDense)
 OP_TYPE(Split)
 OP_TYPE(Sqrt)
@@ -903,9 +903,9 @@ OP_ATTR(block_size, long)
 OP_ATTR_ENUM(format, Format)
 OP_SCHEMA_DEF_END(SpaceToDepth)
 
-OP_SCHEMA_DEF(SparseSoftmaxCrossEntropy)
+OP_SCHEMA_DEF(SparseSoftmaxCrossEntropyWithLogits)
 OP_ATTR(grad, bool)
-OP_SCHEMA_DEF_END(SparseSoftmaxCrossEntropy)
+OP_SCHEMA_DEF_END(SparseSoftmaxCrossEntropyWithLogits)
 
 OP_SCHEMA_DEF(SparseToDense)
 OP_SCHEMA_DEF_END(SparseToDense)
diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h
index d0525d3744..b66d417f22 100644
--- a/mindspore/lite/src/ops/ops_func_declare.h
+++ b/mindspore/lite/src/ops/ops_func_declare.h
@@ -149,7 +149,7 @@
 #include "ops/space_to_batch.h"
 #include "ops/space_to_batch_nd.h"
 #include "ops/space_to_depth.h"
-#include "ops/sparse_softmax_cross_entropy.h"
+#include "ops/sparse_softmax_cross_entropy_with_logits.h"
 #include "ops/sparse_to_dense.h"
 #include "ops/split.h"
 #include "ops/square.h"
@@ -405,7 +405,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(SoftmaxCrossEntropyWithLogits);
 FUNC_MSOP2SCHEMAOP_DECLARE(SpaceToBatch);
 FUNC_MSOP2SCHEMAOP_DECLARE(SpaceToBatchND);
 FUNC_MSOP2SCHEMAOP_DECLARE(SpaceToDepth);
-FUNC_MSOP2SCHEMAOP_DECLARE(SparseSoftmaxCrossEntropy);
+FUNC_MSOP2SCHEMAOP_DECLARE(SparseSoftmaxCrossEntropyWithLogits);
 FUNC_MSOP2SCHEMAOP_DECLARE(SparseToDense);
 FUNC_MSOP2SCHEMAOP_DECLARE(Split);
 FUNC_MSOP2SCHEMAOP_DECLARE(Sqrt);
diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc
index f7bfe6f136..0cbfa7e7db 100644
--- a/mindspore/lite/src/ops/ops_utils.cc
+++ b/mindspore/lite/src/ops/ops_utils.cc
@@ -612,8 +612,8 @@ schema::PrimitiveT *SpaceToDepthPrimitiveCreator(const AnfNodePtr &node) {
   auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SpaceToDepth>>(node);
   return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
 }
-schema::PrimitiveT *SparseSoftmaxCrossEntropyPrimitiveCreator(const AnfNodePtr &node) {
-  auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseSoftmaxCrossEntropy>>(node);
+schema::PrimitiveT *SparseSoftmaxCrossEntropyWithLogitsPrimitiveCreator(const AnfNodePtr &node) {
+  auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseSoftmaxCrossEntropyWithLogits>>(node);
   return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
 }
 schema::PrimitiveT *SparseToDensePrimitiveCreator(const AnfNodePtr &node) {
@@ -886,8 +886,8 @@ RegistryMSOps g_softmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry("SoftmaxCr
 RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator);
 RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator);
 RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator);
-RegistryMSOps g_sparseSoftmaxCrossEntropyPrimitiveCreatorRegistry("SparseSoftmaxCrossEntropyWithLogits",
-                                                                  SparseSoftmaxCrossEntropyPrimitiveCreator);
+RegistryMSOps g_sparseSoftmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry(
+  "SparseSoftmaxCrossEntropyWithLogits", SparseSoftmaxCrossEntropyWithLogitsPrimitiveCreator);
 RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator);
 RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator);
 RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator);
diff --git a/mindspore/lite/src/runtime/infer_manager.cc b/mindspore/lite/src/runtime/infer_manager.cc
index 3541515dce..d29b326906 100644
--- a/mindspore/lite/src/runtime/infer_manager.cc
+++ b/mindspore/lite/src/runtime/infer_manager.cc
@@ -120,7 +120,7 @@
 #include "nnacl/infer/merge_infer.h"
 #include "nnacl/infer/switch_infer.h"
 #include "nnacl/infer/assert_op_infer.h"
-#include "nnacl/infer/sparse_softmax_cross_entropy_infer.h"
+#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h"
 #include "nnacl/infer/dropout_infer.h"
 #include "nnacl/infer/prior_box_infer.h"
 
@@ -394,7 +394,7 @@ static RegistryInferShape g_MergeInferShape(mindspore::schema::PrimitiveType_Mer
 static RegistryInferShape g_SwitchInferShape(mindspore::schema::PrimitiveType_Switch, SwitchInferShape);
 static RegistryInferShape g_AssertOpInferShape(mindspore::schema::PrimitiveType_Assert, AssertOpInferShape);
 static RegistryInferShape g_SparseSoftmaxCrossEntropyInferShape(
-  mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropy, SparseSoftmaxCrossEntropyInferShape);
+  mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogitsInferShape);
 static RegistryInferShape g_DropoutInferShape(mindspore::schema::PrimitiveType_Dropout, DropoutInferShape);
 static RegistryInferShape g_PriorBoxInferShape(mindspore::schema::PrimitiveType_PriorBox, PriorBoxInferShape);
 static RegistryInferShape g_MinimumGradInferShape(mindspore::schema::PrimitiveType_MinimumGrad, MaximumGradInferShape);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
index fff94b7749..794c5562ce 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
@@ -25,7 +25,7 @@
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropy;
+using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
 
 namespace mindspore::kernel {
 
@@ -93,7 +93,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {
   std::fill(losses_, losses_ + data_size, 0.f);
   std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0.f);
   Softmax(ins, losses_, sum_data_, &sm_params_);
-  if (sce_param->is_grad) {
+  if (sce_param->is_grad_) {
     GradPostExecute(labels, losses_, out);
   } else {
     ForwardPostExecute(labels, losses_, out);
@@ -101,20 +101,21 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) {
   return RET_OK;
 }
 
-int SparseSoftmaxCrossEntropyRun(void *cdata, int task_id) {
+int SparseSoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id) {
   auto sparse_kernel = reinterpret_cast<SparseSoftmaxCrossEntropyWithLogitsCPUKernel *>(cdata);
   auto error_code = sparse_kernel->Execute(task_id);
   if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyRun error task_id[" << task_id << "] error_code[" << error_code << "]";
+    MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogitsRun error task_id[" << task_id << "] error_code[" << error_code
+                  << "]";
     return RET_ERROR;
   }
   return RET_OK;
 }
 
 int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
-  int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyRun, this, 1);
+  int error_code = ParallelLaunch(this->context_->thread_pool_, SparseSoftmaxCrossEntropyWithLogitsRun, this, 1);
   if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy function error error_code[" << error_code << "]";
+    MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogits function error error_code[" << error_code << "]";
     return RET_ERROR;
   }
   return RET_OK;
@@ -145,13 +146,13 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
   return RET_OK;
 }
 
-kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                                  const std::vector<lite::Tensor *> &outputs,
-                                                                  OpParameter *opParameter,
-                                                                  const lite::InnerContext *ctx,
-                                                                  const kernel::KernelKey &desc) {
+kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyWithLogitsFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                                                            const std::vector<lite::Tensor *> &outputs,
+                                                                            OpParameter *opParameter,
+                                                                            const lite::InnerContext *ctx,
+                                                                            const kernel::KernelKey &desc) {
   MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_SparseSoftmaxCrossEntropy);
+  MS_ASSERT(desc.type == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits);
   auto *kernel = new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new SparseSoftmaxCrossEntropyWithLogitsCPUKernel failed!";
@@ -167,6 +168,6 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator(const std::vec
   }
   return kernel;
 }
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseSoftmaxCrossEntropy,
-           CpuSparseSoftmaxCrossEntropyFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
+           CpuSparseSoftmaxCrossEntropyWithLogitsFp32KernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc
index ebeef7d222..6dae3fef33 100644
--- a/mindspore/lite/src/train/train_populate_parameter.cc
+++ b/mindspore/lite/src/train/train_populate_parameter.cc
@@ -123,7 +123,7 @@ OpParameter *PopulateSgdParameter(const void *prim) {
   return reinterpret_cast<OpParameter *>(p);
 }
 
-OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *prim) {
+OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
   SoftmaxCrossEntropyParameter *sce_param =
     reinterpret_cast<SoftmaxCrossEntropyParameter *>(malloc(sizeof(SoftmaxCrossEntropyParameter)));
   if (sce_param == nullptr) {
@@ -131,9 +131,9 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *prim) {
     return nullptr;
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
-  auto value = primitive->value_as_SparseSoftmaxCrossEntropy();
+  auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
   sce_param->op_parameter_.type_ = primitive->value_type();
-  sce_param->is_grad = value->grad();
+  sce_param->is_grad_ = value->grad();
   return reinterpret_cast<OpParameter *>(sce_param);
 }
 
@@ -146,7 +146,7 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   sce_param->op_parameter_.type_ = primitive->value_type();
-  sce_param->is_grad = 0;
+  sce_param->is_grad_ = 0;
   return reinterpret_cast<OpParameter *>(sce_param);
 }
 
@@ -385,8 +385,9 @@ void PopulateTrainParameters() {
                                                   lite::SCHEMA_CUR);
   lite::Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits,
                                                       PopulateSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
-  lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(
-    schema::PrimitiveType_SparseSoftmaxCrossEntropy, PopulateSparseSoftmaxCrossEntropyParameter, lite::SCHEMA_CUR);
+  lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
+                                                            PopulateSparseSoftmaxCrossEntropyWithLogitsParameter,
+                                                            lite::SCHEMA_CUR);
   lite::Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter,
                                              lite::SCHEMA_CUR);
   lite::Registry DependParameterRegistry(schema::PrimitiveType_Depend, lite::DefaultPopulateParameter,
diff --git a/mindspore/lite/src/train/train_populate_parameter_v0.cc b/mindspore/lite/src/train/train_populate_parameter_v0.cc
index e9497472fa..5e67e385f2 100644
--- a/mindspore/lite/src/train/train_populate_parameter_v0.cc
+++ b/mindspore/lite/src/train/train_populate_parameter_v0.cc
@@ -229,9 +229,9 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *primitive) {
   }
 
   auto sparseSoftmaxCrossEntropy_prim = prim->value_as_SparseSoftmaxCrossEntropy();
-  sce_param->is_grad = sparseSoftmaxCrossEntropy_prim->isGrad();
+  sce_param->is_grad_ = sparseSoftmaxCrossEntropy_prim->isGrad();
 
-  sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropy;
+  sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
   return reinterpret_cast<OpParameter *>(sce_param);
 }
 
@@ -246,7 +246,7 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const void *primitive) {
     MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed.";
     return nullptr;
   }
-  sce_param->is_grad = 0;
+  sce_param->is_grad_ = 0;
   sce_param->op_parameter_.type_ = schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
   return reinterpret_cast<OpParameter *>(sce_param);
 }
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index badd1fdd34..642b586efc 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -454,7 +454,7 @@ int TrainSession::OptimizerStep() {
 
 bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const {
   return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropyWithLogits ||
-          kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy ||
+          kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits ||
           kernel->Type() == schema::PrimitiveType_SmoothL1Loss ||
           kernel->Type() == schema::PrimitiveType_SmoothL1LossGrad ||
           kernel->Type() == schema::PrimitiveType_SigmoidCrossEntropyWithLogits ||
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 13301cadef..0d9c808ec5 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -356,11 +356,9 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
-      if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) {
-        continue;
-      }
+    RemoveIfDepend(cnode);
+    if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) {
+      continue;
     }
     if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
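---
Note (not part of the patch): the rename is mechanical, but it touches the whole path from the flatbuffers
schema through parameter population down to the CPU kernel. Below is a minimal usage sketch of the renamed
core primitive, based only on the accessors declared in sparse_softmax_cross_entropy_with_logits.h above;
the surrounding function and build setup are hypothetical, not taken from this patch:

    // Hypothetical sketch, assuming the header introduced by this patch.
    #include "ops/sparse_softmax_cross_entropy_with_logits.h"

    void BuildLossPrimitive() {
      mindspore::ops::SparseSoftmaxCrossEntropyWithLogits op;
      // Init() stores the flag as the kGrad attribute via set_grad().
      op.Init(/* is_grad = */ false);
      // get_grad() reads the attribute back; the lite runtime copies it into
      // SoftmaxCrossEntropyParameter::is_grad_, which the fp32_grad kernel
      // checks to choose GradPostExecute() vs ForwardPostExecute().
      bool is_grad = op.get_grad();
      (void)is_grad;  // false here: the kernel would emit the forward loss.
    }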