| @@ -170,6 +170,7 @@ union PrimitiveType { | |||||
| AddFold, | AddFold, | ||||
| SquaredDifference, | SquaredDifference, | ||||
| Flatten, | Flatten, | ||||
| FlattenGrad, | |||||
| TupleGetItem, | TupleGetItem, | ||||
| Div, | Div, | ||||
| Where, | Where, | ||||
| @@ -134,7 +134,8 @@ table Minimum { | |||||
| table Flatten { | table Flatten { | ||||
| } | } | ||||
| table FlattenGrad { | |||||
| } | |||||
| table Concat { | table Concat { | ||||
| axis: int; | axis: int; | ||||
| n: int; | n: int; | ||||
| @@ -46,8 +46,6 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP | |||||
| } else if (prim.name() == "ReLU6") { | } else if (prim.name() == "ReLU6") { | ||||
| attr->type = schema::ActivationType_RELU6; | attr->type = schema::ActivationType_RELU6; | ||||
| } | } | ||||
| auto alpha = GetValue<float>(prim.GetAttr("alpha")); | |||||
| attr->alpha = alpha; | |||||
| this->primitive_->value.value = attr.release(); | this->primitive_->value.value = attr.release(); | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| @@ -19,7 +19,27 @@ namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_ApplyMomentum; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_ApplyMomentum) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::ApplyMomentumT>(); | |||||
| this->primitive_->value.value = attr.release(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <memory> | |||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| @@ -31,6 +32,7 @@ class ApplyMomentum : public PrimitiveC { | |||||
| MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC); | MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC); | ||||
| ApplyMomentum() = default; | ApplyMomentum() = default; | ||||
| explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| ApplyMomentum() = default; | ApplyMomentum() = default; | ||||
| @@ -41,7 +41,6 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||||
| this->primitive_->value.value = attr; | this->primitive_->value.value = attr; | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | MS_LOG(ERROR) << "primitive value is nullptr"; | ||||
| @@ -24,7 +24,35 @@ float BNGrad::GetMomentum() const { return this->primitive_->value.AsBNGrad()->m | |||||
| void BNGrad::SetEps(float eps) { this->primitive_->value.AsBNGrad()->eps = eps; } | void BNGrad::SetEps(float eps) { this->primitive_->value.AsBNGrad()->eps = eps; } | ||||
| void BNGrad::SetMomentum(float momentum) { this->primitive_->value.AsBNGrad()->momentum = momentum; } | void BNGrad::SetMomentum(float momentum) { this->primitive_->value.AsBNGrad()->momentum = momentum; } | ||||
| int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_BNGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_BNGrad) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::BNGradInputT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -33,6 +33,7 @@ class BNGrad : public PrimitiveC { | |||||
| explicit BNGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit BNGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetEps(float eps); | void SetEps(float eps); | ||||
| void SetMomentum(float momentum); | void SetMomentum(float momentum); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| BNGrad() = default; | BNGrad() = default; | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| @@ -116,8 +116,6 @@ void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema:: | |||||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | ||||
| } | } | ||||
| attr->channelMultiplier = channel_mutiplier; | attr->channelMultiplier = channel_mutiplier; | ||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } | } | ||||
| @@ -168,8 +166,6 @@ void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim, | |||||
| } else { | } else { | ||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | attr->activationType = schema::ActivationType_NO_ACTIVATION; | ||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } | } | ||||
| int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
| @@ -114,8 +114,6 @@ void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::P | |||||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | ||||
| } | } | ||||
| attr->channelMultiplier = channel_mutiplier; | attr->channelMultiplier = channel_mutiplier; | ||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } | } | ||||
| @@ -166,8 +164,6 @@ void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim, | |||||
| } else { | } else { | ||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | attr->activationType = schema::ActivationType_NO_ACTIVATION; | ||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } | } | ||||
| int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | ||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/depend.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int Depend::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Depend; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Depend) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow)(schema::DependT); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "attr is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_ | |||||
| #define LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_ | |||||
| #include <vector> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class Depend : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(Depend, PrimitiveC); | |||||
| Depend() = default; | |||||
| explicit Depend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| Depend() = default; | |||||
| #endif | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_SRC_OPS_Depend_H_ | |||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/flatten_grad.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| int FlattenGrad::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||||
| MS_ASSERT(this->primitive_ != nullptr); | |||||
| auto input = inputs_.front(); | |||||
| auto output = outputs_.front(); | |||||
| if (input == nullptr || output == nullptr) { | |||||
| MS_LOG(ERROR) << "FlattenGrad input or output is null!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { | |||||
| MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| output->set_data_type(input->data_type()); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| if (!GetInferFlag()) { | |||||
| return RET_OK; | |||||
| } | |||||
| auto input_shape = input->shape(); | |||||
| std::vector<int> output_shape(2); | |||||
| output_shape[0] = input_shape[0]; | |||||
| output_shape[1] = 1; | |||||
| for (size_t i = 1; i < input_shape.size(); i++) { | |||||
| output_shape[1] *= input_shape[i]; | |||||
| } | |||||
| output->set_shape(output_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_FlattenGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_FlattenGrad) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::FlattenGradT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int FlattenGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto val_offset = schema::CreateFlattenGrad(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FlattenGrad, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_FlattenGrad_GRAD_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_FlattenGrad_GRAD_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "ir/dtype/type_id.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class FlattenGrad : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(FlattenGrad, PrimitiveC); | |||||
| FlattenGrad() = default; | |||||
| explicit FlattenGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| FlattenGrad() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FlattenGrad_H_ | |||||
| @@ -136,7 +136,10 @@ | |||||
| #include "src/ops/power_grad.h" | #include "src/ops/power_grad.h" | ||||
| #include "src/ops/softmax_cross_entropy.h" | #include "src/ops/softmax_cross_entropy.h" | ||||
| #include "src/ops/bn_grad.h" | #include "src/ops/bn_grad.h" | ||||
| #include "src/ops/bn_grad_input.h" | |||||
| #include "src/ops/arithmetic_grad.h" | #include "src/ops/arithmetic_grad.h" | ||||
| #include "src/ops/depend.h" | |||||
| #include "src/ops/flatten_grad.h" | |||||
| #endif | #endif | ||||
| @@ -397,6 +400,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri | |||||
| return NewPrimitiveC<BNGradInput>(prim, inputs, quantType); | return NewPrimitiveC<BNGradInput>(prim, inputs, quantType); | ||||
| } else if (op_type == "PowerGrad") { | } else if (op_type == "PowerGrad") { | ||||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | |||||
| } else if (op_type == "Depend") { | |||||
| return NewPrimitiveC<Depend>(prim, inputs, quantType); | |||||
| } else if (op_type == "FlattenGrad") { | |||||
| return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType); | |||||
| #endif | #endif | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; | MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; | ||||
| @@ -638,6 +647,12 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT | |||||
| return new PowerGrad(primitive); | return new PowerGrad(primitive); | ||||
| case schema::PrimitiveType_BNGradInput: | case schema::PrimitiveType_BNGradInput: | ||||
| return new BNGradInput(primitive); | return new BNGradInput(primitive); | ||||
| case schema::PrimitiveType_SoftmaxCrossEntroy: | |||||
| return new SoftmaxCrossEntroy(primitive); | |||||
| case schema::PrimitiveType_Depend: | |||||
| return new Depend(primitive); | |||||
| case schema::PrimitiveType_FlattenGrad: | |||||
| return new FlattenGrad(primitive); | |||||
| #endif | #endif | ||||
| default: | default: | ||||
| @@ -24,7 +24,33 @@ std::vector<int> SoftmaxCrossEntropy::GetAxis() const { return this->primitive_- | |||||
| void SoftmaxCrossEntropy::SetAxis(const std::vector<int> &axis) { | void SoftmaxCrossEntropy::SetAxis(const std::vector<int> &axis) { | ||||
| this->primitive_->value.AsSoftmaxCrossEntropy()->axis = axis; | this->primitive_->value.AsSoftmaxCrossEntropy()->axis = axis; | ||||
| } | } | ||||
| int SoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_SoftmaxCrossEntropy; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_SoftmaxCrossEntropy) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::SoftmaxCrossEntropyT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| std::vector<int> SoftmaxCrossEntropy::GetAxis() const { | std::vector<int> SoftmaxCrossEntropy::GetAxis() const { | ||||
| @@ -33,7 +33,7 @@ class SoftmaxCrossEntropy : public PrimitiveC { | |||||
| SoftmaxCrossEntropy() = default; | SoftmaxCrossEntropy() = default; | ||||
| explicit SoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit SoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetAxis(const std::vector<int> &axis); | void SetAxis(const std::vector<int> &axis); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| SoftmaxCrossEntropy() = default; | SoftmaxCrossEntropy() = default; | ||||
| @@ -323,6 +323,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | ||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | ||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | ||||
| } else if (value->isa<mindspore::BoolImm>()) { | |||||
| auto valueAbstract = valueNode->abstract(); | |||||
| auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract); | |||||
| auto typePtr = abstractScalar->GetTypeTrack(); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| paramTensor->dims = {1}; | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| auto data = value->cast<mindspore::BoolImmPtr>(); | |||||
| paramTensor->data.emplace_back(data->value()); | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<mindspore::ValueSequeue>()) { | } else if (value->isa<mindspore::ValueSequeue>()) { | ||||
| MS_LOG(DEBUG) << "Value type is ValueSequence."; | MS_LOG(DEBUG) << "Value type is ValueSequence."; | ||||
| return RET_OK; | return RET_OK; | ||||