diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index 286077299d..3fb5214a4e 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -170,6 +170,7 @@ union PrimitiveType {
     AddFold,
     SquaredDifference,
     Flatten,
+    FlattenGrad,
     TupleGetItem,
     Div,
     Where,
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 2e13afa463..02496c582a 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -134,7 +134,8 @@ table Minimum {
 table Flatten {
 }
-
+table FlattenGrad {
+}
 table Concat {
     axis: int;
     n: int;
diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc
index 0efed3a6ab..a3cb09c40a 100644
--- a/mindspore/lite/src/ops/activation_grad.cc
+++ b/mindspore/lite/src/ops/activation_grad.cc
@@ -46,8 +46,6 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
     } else if (prim.name() == "ReLU6") {
       attr->type = schema::ActivationType_RELU6;
     }
-    auto alpha = GetValue<float>(prim.GetAttr("alpha"));
-    attr->alpha = alpha;
     this->primitive_->value.value = attr.release();
     if (this->primitive_->value.value == nullptr) {
       MS_LOG(ERROR) << "new primitiveT value failed";
diff --git a/mindspore/lite/src/ops/apply_momentum.cc b/mindspore/lite/src/ops/apply_momentum.cc
index b50716b1eb..31c8aa10ca 100644
--- a/mindspore/lite/src/ops/apply_momentum.cc
+++ b/mindspore/lite/src/ops/apply_momentum.cc
@@ -19,7 +19,27 @@
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
-
+int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_ApplyMomentum;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_ApplyMomentum) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  auto attr = std::make_unique<schema::ApplyMomentumT>();
+  this->primitive_->value.value = attr.release();
+  if (this->primitive_->value.value == nullptr) {
+    MS_LOG(ERROR) << "new primitiveT value failed";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 #else
 int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/apply_momentum.h b/mindspore/lite/src/ops/apply_momentum.h
index 77ecf588d9..035366a9d0 100644
--- a/mindspore/lite/src/ops/apply_momentum.h
+++ b/mindspore/lite/src/ops/apply_momentum.h
@@ -20,6 +20,7 @@
 #include <vector>
 #include <set>
 #include <cmath>
+#include <memory>
 #include "ir/dtype/type_id.h"
 #include "src/ops/primitive_c.h"
 
@@ -31,6 +32,7 @@ class ApplyMomentum : public PrimitiveC {
   MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
   ApplyMomentum() = default;
   explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   ApplyMomentum() = default;
 
diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc
index c3c4ac899b..62231b80dd 100644
--- a/mindspore/lite/src/ops/bias_grad.cc
+++ b/mindspore/lite/src/ops/bias_grad.cc
@@ -41,7 +41,6 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
     MS_LOG(ERROR) << "new primitiveT value failed";
     return RET_ERROR;
   }
-  attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
   this->primitive_->value.value = attr;
   if (this->primitive_->value.value == nullptr) {
     MS_LOG(ERROR) << "primitive value is nullptr";
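Note: every UnPackAttr added in this patch follows one skeleton: lazily allocate the schema::PrimitiveT, stamp the union tag on first use, reject a mismatched tag, then attach the attribute payload. A minimal sketch of that skeleton for a hypothetical attribute-less op Foo (Foo, FooT, and PrimitiveType_Foo are placeholders, not entries in this schema):

int Foo::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    // First unpack: allocate the container and stamp the union tag.
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Foo;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Foo) {
    return RET_ERROR;  // primitive already owned by a different op type
  }
  if (this->primitive_->value.value == nullptr) {
    // Attach an (empty) attribute table; ops with attributes fill it in here.
    this->primitive_->value.value = new (std::nothrow) schema::FooT();
    if (this->primitive_->value.value == nullptr) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}

ApplyMomentum above is the one variant: it releases a std::make_unique<schema::ApplyMomentumT>() into value.value unconditionally instead of guarding on value.value == nullptr.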
diff --git a/mindspore/lite/src/ops/bn_grad.cc b/mindspore/lite/src/ops/bn_grad.cc
index 5fa2694e43..18015c523c 100644
--- a/mindspore/lite/src/ops/bn_grad.cc
+++ b/mindspore/lite/src/ops/bn_grad.cc
@@ -24,7 +24,35 @@
 float BNGrad::GetMomentum() const { return this->primitive_->value.AsBNGrad()->momentum; }
 void BNGrad::SetEps(float eps) { this->primitive_->value.AsBNGrad()->eps = eps; }
 void BNGrad::SetMomentum(float momentum) { this->primitive_->value.AsBNGrad()->momentum = momentum; }
-
+int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_BNGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_BNGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::BNGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->eps = GetValue<float>(prim.GetAttr("eps"));
+    attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
diff --git a/mindspore/lite/src/ops/bn_grad.h b/mindspore/lite/src/ops/bn_grad.h
index 0d09639bfe..fd9baf618e 100644
--- a/mindspore/lite/src/ops/bn_grad.h
+++ b/mindspore/lite/src/ops/bn_grad.h
@@ -33,6 +33,7 @@ class BNGrad : public PrimitiveC {
   explicit BNGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetEps(float eps);
   void SetMomentum(float momentum);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   BNGrad() = default;
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc
index 37d05fe3f3..60e5003145 100644
--- a/mindspore/lite/src/ops/conv2d_grad_filter.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc
@@ -116,8 +116,6 @@ void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive,
     channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
   }
   attr->channelMultiplier = channel_mutiplier;
-
-  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
   primitive->value.value = attr.release();
 }
 
@@ -168,8 +166,6 @@ void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim,
   } else {
     attr->activationType = schema::ActivationType_NO_ACTIVATION;
   }
-
-  primitive->value.type = schema::PrimitiveType_Conv2D;
   primitive->value.value = attr.release();
 }
 int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc
index 85a5156e97..4632eacecd 100644
--- a/mindspore/lite/src/ops/conv2d_grad_input.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_input.cc
@@ -114,8 +114,6 @@ void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive,
     channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
   }
   attr->channelMultiplier = channel_mutiplier;
-
-  primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
   primitive->value.value = attr.release();
 }
 
@@ -166,8 +164,6 @@ void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim,
   } else {
     attr->activationType = schema::ActivationType_NO_ACTIVATION;
   }
-
-  primitive->value.type = schema::PrimitiveType_Conv2D;
   primitive->value.value = attr.release();
 }
 int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
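Note on the two deletions in each Populater helper: they were copy-paste leftovers from the forward Conv2D path. UnPackAttr has already stamped the union tag (PrimitiveType_Conv2DGradFilter / PrimitiveType_Conv2DGradInput), and re-stamping it as Conv2D or DepthwiseConv2D left the tag disagreeing with the payload. A hedged illustration of why that matters (not code from this patch; the As*() calls are the flatbuffers-generated union accessors, which check the tag before casting):

schema::PrimitiveT prim;
prim.value.type = schema::PrimitiveType_Conv2DGradFilter;  // stamped by UnPackAttr
prim.value.value = new schema::Conv2DGradFilterT();        // matching payload
prim.value.type = schema::PrimitiveType_Conv2D;            // what the removed line did
// Now AsConv2DGradFilter() returns nullptr (tag check fails), while AsConv2D()
// would reinterpret the Conv2DGradFilterT payload as a Conv2DT.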
diff --git a/mindspore/lite/src/ops/depend.cc b/mindspore/lite/src/ops/depend.cc
new file mode 100644
index 0000000000..5176313f63
--- /dev/null
+++ b/mindspore/lite/src/ops/depend.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/depend.h"
+#include <memory>
+#include <vector>
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int Depend::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Depend;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Depend) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::DependT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "attr is nullptr";
+      return RET_ERROR;
+    }
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#endif
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/depend.h b/mindspore/lite/src/ops/depend.h
new file mode 100644
index 0000000000..01ee755e0e
--- /dev/null
+++ b/mindspore/lite/src/ops/depend.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_
+#define LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_
+
+#include <vector>
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class Depend : public PrimitiveC {
+ public:
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(Depend, PrimitiveC);
+  Depend() = default;
+  explicit Depend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#else
+  Depend() = default;
+#endif
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_
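Note: Depend carries no attributes; it only threads a control-dependency edge through the graph, so UnPackAttr just materializes an empty schema::DependT to keep the union well-formed. A sketch of exercising it through the new registration (assumes PRIMITIVE_WRITEABLE; prim is an imported Primitive named "Depend" and inputs are its AnfNode operands):

auto depend = std::make_shared<mindspore::lite::Depend>();
if (depend->UnPackAttr(prim, inputs) != mindspore::lite::RET_OK) {
  MS_LOG(ERROR) << "unpack Depend failed";
}
// The wrapped PrimitiveT now has type PrimitiveType_Depend and an empty DependT payload.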
diff --git a/mindspore/lite/src/ops/flatten_grad.cc b/mindspore/lite/src/ops/flatten_grad.cc
new file mode 100644
index 0000000000..a4025efad0
--- /dev/null
+++ b/mindspore/lite/src/ops/flatten_grad.cc
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/flatten_grad.h"
+#include <memory>
+
+namespace mindspore {
+namespace lite {
+int FlattenGrad::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
+  MS_ASSERT(this->primitive_ != nullptr);
+  if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
+    MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
+    return RET_INPUT_TENSOR_ERROR;
+  }
+  auto input = inputs_.front();
+  auto output = outputs_.front();
+  if (input == nullptr || output == nullptr) {
+    MS_LOG(ERROR) << "FlattenGrad input or output is null!";
+    return RET_ERROR;
+  }
+
+  output->set_data_type(input->data_type());
+  output->SetFormat(input->GetFormat());
+  if (!GetInferFlag()) {
+    return RET_OK;
+  }
+
+  auto input_shape = input->shape();
+  std::vector<int> output_shape(2);
+  output_shape[0] = input_shape[0];
+  output_shape[1] = 1;
+  for (size_t i = 1; i < input_shape.size(); i++) {
+    output_shape[1] *= input_shape[i];
+  }
+  output->set_shape(output_shape);
+  return RET_OK;
+}
+#ifdef PRIMITIVE_WRITEABLE
+int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_FlattenGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_FlattenGrad) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::FlattenGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+int FlattenGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto val_offset = schema::CreateFlattenGrad(*fbb);
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FlattenGrad, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+#endif
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/flatten_grad.h b/mindspore/lite/src/ops/flatten_grad.h
new file mode 100644
index 0000000000..54b61526ba
--- /dev/null
+++ b/mindspore/lite/src/ops/flatten_grad.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_FLATTEN_GRAD_H_
+#define LITE_MINDSPORE_LITE_C_OPS_FLATTEN_GRAD_H_
+
+#include <vector>
+#include <set>
+#include <cmath>
+#include "ir/dtype/type_id.h"
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class FlattenGrad : public PrimitiveC {
+ public:
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(FlattenGrad, PrimitiveC);
+  FlattenGrad() = default;
+  explicit FlattenGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#else
+  FlattenGrad() = default;
+
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
+#endif
+  int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_FLATTEN_GRAD_H_
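Note: FlattenGrad::InferShape reuses Flatten's shape rule: keep dimension 0 and collapse the rest, so a [2, 3, 4, 5] input infers a [2, 60] output. The rule in isolation, as a stand-alone sketch (helper name is illustrative, not part of the patch):

// Shape rule implemented by FlattenGrad::InferShape above.
std::vector<int> FlattenedShape(const std::vector<int> &in) {
  std::vector<int> out(2);
  out[0] = in[0];  // batch dimension is preserved
  out[1] = 1;
  for (size_t i = 1; i < in.size(); i++) {
    out[1] *= in[i];  // remaining dimensions collapse into one
  }
  return out;  // {2, 3, 4, 5} -> {2, 60}
}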
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index fa52d02c68..8e73840e8b 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -136,7 +136,10 @@
 #include "src/ops/power_grad.h"
 #include "src/ops/softmax_cross_entropy.h"
 #include "src/ops/bn_grad.h"
+#include "src/ops/bn_grad_input.h"
 #include "src/ops/arithmetic_grad.h"
+#include "src/ops/depend.h"
+#include "src/ops/flatten_grad.h"
 #endif
 
@@ -397,6 +400,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &prim,
     return NewPrimitiveC<BiasGrad>(prim, inputs, quantType);
   } else if (op_type == "PowerGrad") {
     return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
+  } else if (op_type == "SoftmaxCrossEntropyWithLogits") {
+    return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
+  } else if (op_type == "Depend") {
+    return NewPrimitiveC<Depend>(prim, inputs, quantType);
+  } else if (op_type == "FlattenGrad") {
+    return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType);
 #endif
   } else {
     MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
@@ -638,6 +647,12 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT *primitive) {
     case schema::PrimitiveType_PowerGrad:
       return new PowerGrad(primitive);
     case schema::PrimitiveType_BNGradInput:
       return new BNGradInput(primitive);
+    case schema::PrimitiveType_SoftmaxCrossEntropy:
+      return new SoftmaxCrossEntropy(primitive);
+    case schema::PrimitiveType_Depend:
+      return new Depend(primitive);
+    case schema::PrimitiveType_FlattenGrad:
+      return new FlattenGrad(primitive);
 #endif
     default:
diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.cc b/mindspore/lite/src/ops/softmax_cross_entropy.cc
index 41e6c1e22f..c764fe62b2 100644
--- a/mindspore/lite/src/ops/softmax_cross_entropy.cc
+++ b/mindspore/lite/src/ops/softmax_cross_entropy.cc
@@ -24,7 +24,33 @@
 std::vector<int> SoftmaxCrossEntropy::GetAxis() const { return this->primitive_->value.AsSoftmaxCrossEntropy()->axis; }
 void SoftmaxCrossEntropy::SetAxis(const std::vector<int> &axis) {
   this->primitive_->value.AsSoftmaxCrossEntropy()->axis = axis;
 }
-
+int SoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_SoftmaxCrossEntropy;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_SoftmaxCrossEntropy) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::SoftmaxCrossEntropyT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 #else
 
 std::vector<int> SoftmaxCrossEntropy::GetAxis() const {
diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.h b/mindspore/lite/src/ops/softmax_cross_entropy.h
index b81a435abe..30cf6cfef3 100644
--- a/mindspore/lite/src/ops/softmax_cross_entropy.h
+++ b/mindspore/lite/src/ops/softmax_cross_entropy.h
@@ -33,7 +33,7 @@ class SoftmaxCrossEntropy : public PrimitiveC {
   SoftmaxCrossEntropy() = default;
   explicit SoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
   void SetAxis(const std::vector<int> &axis);
-
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   SoftmaxCrossEntropy() = default;
 
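Note: each new training op has to be registered twice: by name in UnPackFromPrimitive (importing from an ir Primitive, hence the "SoftmaxCrossEntropyWithLogits" spelling) and by enum in UnPackFromSchemaPrimitiveT (rehydrating a serialized model). A table-driven sketch of the string side, purely hypothetical and not part of this patch (assumes NewPrimitiveC<T> keeps the free-function signature used in the calls above; needs <map>, <memory>, <string>):

using Factory = std::shared_ptr<PrimitiveC> (*)(const Primitive &, const std::vector<AnfNodePtr> &,
                                                const schema::QuantType &);
static const std::map<std::string, Factory> kTrainOpFactories = {
  {"SoftmaxCrossEntropyWithLogits", NewPrimitiveC<SoftmaxCrossEntropy>},
  {"Depend", NewPrimitiveC<Depend>},
  {"FlattenGrad", NewPrimitiveC<FlattenGrad>},
};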
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index aef14f865c..7504f69411 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -323,6 +323,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
     node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
     output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
     meta_graphT->allTensors.emplace_back(std::move(paramTensor));
+  } else if (value->isa<mindspore::Int32Imm>()) {
+    auto valueAbstract = valueNode->abstract();
+    auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract);
+    auto typePtr = abstractScalar->GetTypeTrack();
+    paramTensor->dataType = typePtr->type_id();
+    paramTensor->dims = {1};
+    paramTensor->nodeType = schema::NodeType_ValueNode;
+    auto data = value->cast<mindspore::Int32ImmPtr>();
+    paramTensor->data.emplace_back(data->value());
+    node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
+    output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
+    meta_graphT->allTensors.emplace_back(std::move(paramTensor));
   } else if (value->isa<mindspore::ValueSequeue>()) {
     MS_LOG(DEBUG) << "Value type is ValueSequence.";
     return RET_OK;
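One caution on the new Int32Imm branch: paramTensor->data is a byte vector, so the single emplace_back(data->value()) stores only the low byte of the scalar even though dataType records a 32-bit type. If the full value is intended, a byte-accurate copy would look like this (sketch; assumes data is std::vector<uint8_t>):

int32_t scalar = data->value();
auto *bytes = reinterpret_cast<uint8_t *>(&scalar);
paramTensor->data.assign(bytes, bytes + sizeof(scalar));  // copy all four bytes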