| @@ -179,6 +179,7 @@ union PrimitiveType { | |||||
| Conv2DGradInput, | Conv2DGradInput, | ||||
| PoolingGrad, | PoolingGrad, | ||||
| BNGrad, | BNGrad, | ||||
| BNGradInput, | |||||
| ApplyMomentum, | ApplyMomentum, | ||||
| BiasGrad, | BiasGrad, | ||||
| SoftmaxCrossEntropy, | SoftmaxCrossEntropy, | ||||
| @@ -398,7 +398,10 @@ table BNGrad { | |||||
| eps : float; | eps : float; | ||||
| momentum: float; | momentum: float; | ||||
| } | } | ||||
| table BNGradInput { | |||||
| eps : float; | |||||
| momentum: float; | |||||
| } | |||||
| table Scale { | table Scale { | ||||
| axis: int; | axis: int; | ||||
| } | } | ||||
| @@ -25,6 +25,36 @@ void ActivationGrad::SetType(int type) { | |||||
| this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type; | this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type; | ||||
| } | } | ||||
| void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; } | void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; } | ||||
| int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_ActivationGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_ActivationGrad) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::ActivationGradT>(); | |||||
| if (prim.name() == "ReLU") { | |||||
| attr->type = schema::ActivationType_RELU; | |||||
| } else if (prim.name() == "Sigmoid") { | |||||
| attr->type = schema::ActivationType_SIGMOID; | |||||
| } else if (prim.name() == "ReLU6") { | |||||
| attr->type = schema::ActivationType_RELU6; | |||||
| } | |||||
| auto alpha = GetValue<float>(prim.GetAttr("alpha")); | |||||
| attr->alpha = alpha; | |||||
| this->primitive_->value.value = attr.release(); | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <memory> | |||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| @@ -33,6 +34,7 @@ class ActivationGrad : public PrimitiveC { | |||||
| explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetType(int type); | void SetType(int type); | ||||
| void SetAlpha(float alpha); | void SetAlpha(float alpha); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| ActivationGrad() = default; | ActivationGrad() = default; | ||||
| @@ -22,7 +22,34 @@ namespace lite { | |||||
| std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; } | std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; } | ||||
| void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; } | void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; } | ||||
| int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_BiasGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_BiasGrad) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::BiasGradT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -33,7 +33,7 @@ class BiasGrad : public PrimitiveC { | |||||
| BiasGrad() = default; | BiasGrad() = default; | ||||
| explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetAxis(const std::vector<int> &axis); | void SetAxis(const std::vector<int> &axis); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| BiasGrad() = default; | BiasGrad() = default; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/bn_grad_input.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; } | |||||
| float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; } | |||||
| void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; } | |||||
| void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; } | |||||
| int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_BNGradInput; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::BNGradInputT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | |||||
| int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_BNGradInput(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_BNGradInputInput return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); } | |||||
| float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); } | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "ir/dtype/type_id.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class BNGradInput : public PrimitiveC { | |||||
| public: | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(BNGradInput, PrimitiveC); | |||||
| BNGradInput() = default; | |||||
| explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| void SetEps(float eps); | |||||
| void SetMomentum(float momentum); | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | |||||
| BNGradInput() = default; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| float GetEps() const; | |||||
| float GetMomentum() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||||
| @@ -66,7 +66,133 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon | |||||
| void Conv2DGradFilter::SetActivationType(int activation_type) { | void Conv2DGradFilter::SetActivationType(int activation_type) { | ||||
| this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type; | this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type; | ||||
| } | } | ||||
| void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| if (prim.GetAttr("activation_name") != nullptr) { | |||||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||||
| attr->activationType = kActivationTypeMap[activate_name]; | |||||
| } else { | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| } | |||||
| int channel_mutiplier = 1; | |||||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||||
| } | |||||
| attr->channelMultiplier = channel_mutiplier; | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim, | |||||
| schema::PrimitiveT *primitive, const int &group) { | |||||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||||
| attr->group = group; | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| if (prim.GetAttr("activation_name") != nullptr) { | |||||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||||
| attr->activationType = kActivationTypeMap[activate_name]; | |||||
| } else { | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Conv2DGradFilter; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradFilter) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int group = GetValue<int>(prim.GetAttr("group")); | |||||
| if (group > 1) { | |||||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||||
| } else { | |||||
| PopulaterConv2DSingleGroup(prim, this->primitive_, group); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -20,6 +20,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| @@ -48,6 +50,10 @@ class Conv2DGradFilter : public PrimitiveC { | |||||
| void SetDilateH(int dilate_h); | void SetDilateH(int dilate_h); | ||||
| void SetHasBias(bool has_bias); | void SetHasBias(bool has_bias); | ||||
| void SetActivationType(int activation_type); | void SetActivationType(int activation_type); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||||
| const std::vector<AnfNodePtr> &inputs); | |||||
| void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | |||||
| #else | #else | ||||
| Conv2DGradFilter() = default; | Conv2DGradFilter() = default; | ||||
| @@ -64,7 +64,133 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv | |||||
| void Conv2DGradInput::SetActivationType(int activation_type) { | void Conv2DGradInput::SetActivationType(int activation_type) { | ||||
| this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type; | this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type; | ||||
| } | } | ||||
| void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| if (prim.GetAttr("activation_name") != nullptr) { | |||||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||||
| attr->activationType = kActivationTypeMap[activate_name]; | |||||
| } else { | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| } | |||||
| int channel_mutiplier = 1; | |||||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||||
| } | |||||
| attr->channelMultiplier = channel_mutiplier; | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim, | |||||
| schema::PrimitiveT *primitive, const int &group) { | |||||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||||
| attr->group = group; | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| if (prim.GetAttr("activation_name") != nullptr) { | |||||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||||
| attr->activationType = kActivationTypeMap[activate_name]; | |||||
| } else { | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Conv2DGradInput; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradInput) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int group = GetValue<int>(prim.GetAttr("group")); | |||||
| if (group > 1) { | |||||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||||
| } else { | |||||
| PopulaterConv2DSingleGroup(prim, this->primitive_, group); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -20,6 +20,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| @@ -48,6 +50,10 @@ class Conv2DGradInput : public PrimitiveC { | |||||
| void SetDilateH(int dilate_h); | void SetDilateH(int dilate_h); | ||||
| void SetHasBias(bool has_bias); | void SetHasBias(bool has_bias); | ||||
| void SetActivationType(int activation_type); | void SetActivationType(int activation_type); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||||
| const std::vector<AnfNodePtr> &inputs); | |||||
| void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | |||||
| #else | #else | ||||
| Conv2DGradInput() = default; | Conv2DGradInput() = default; | ||||
| @@ -52,7 +52,64 @@ void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPooling | |||||
| void PoolingGrad::SetRoundMode(int round_mode) { | void PoolingGrad::SetRoundMode(int round_mode) { | ||||
| this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode; | this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode; | ||||
| } | } | ||||
| int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_PoolingGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_PoolingGrad) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::PoolingGradT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| if (prim.instance_name() == "MaxPool") { | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||||
| } else if (prim.instance_name() == "MeanPool") { | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||||
| } | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("padding")); | |||||
| if (pad_mode == "VALID") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "SAME") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize")); | |||||
| attr->windowH = kernel_size[2]; | |||||
| attr->windowW = kernel_size[3]; | |||||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); } | int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); } | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <string> | |||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| @@ -44,6 +45,7 @@ class PoolingGrad : public PrimitiveC { | |||||
| void SetPadLeft(int pad_left); | void SetPadLeft(int pad_left); | ||||
| void SetPadRight(int pad_right); | void SetPadRight(int pad_right); | ||||
| void SetRoundMode(int round_mode); | void SetRoundMode(int round_mode); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| PoolingGrad() = default; | PoolingGrad() = default; | ||||
| @@ -26,7 +26,36 @@ float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad() | |||||
| void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; } | void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; } | ||||
| void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; } | void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; } | ||||
| void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; } | void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; } | ||||
| int PowerGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_PowerGrad; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_PowerGrad) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::PowerGradT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->power = GetValue<float>(prim.GetAttr("power")); | |||||
| attr->scale = GetValue<float>(prim.GetAttr("scale")); | |||||
| attr->shift = GetValue<float>(prim.GetAttr("shift")); | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); } | float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); } | ||||
| @@ -34,6 +34,7 @@ class PowerGrad : public PrimitiveC { | |||||
| void SetPower(float power); | void SetPower(float power); | ||||
| void SetScale(float scale); | void SetScale(float scale); | ||||
| void SetShift(float shift); | void SetShift(float shift); | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| PowerGrad() = default; | PowerGrad() = default; | ||||
| @@ -383,6 +383,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri | |||||
| return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | ||||
| } else if (op_type == "BatchNormGrad") { | } else if (op_type == "BatchNormGrad") { | ||||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "Conv2DGradInput") { | |||||
| return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType); | |||||
| } else if (op_type == "Conv2DGradFilter") { | |||||
| return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | |||||
| } else if (op_type == "BiasGrad") { | |||||
| return NewPrimitiveC<BiasGrad>(prim, inputs, quantType); | |||||
| } else if (op_type == "ActivationGrad") { | |||||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | |||||
| } else if (op_type == "PoolingGrad") { | |||||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | |||||
| } else if (op_type == "BNGradInput") { | |||||
| return NewPrimitiveC<BNGradInput>(prim, inputs, quantType); | |||||
| } else if (op_type == "PowerGrad") { | |||||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | |||||
| #endif | #endif | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; | MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; | ||||
| @@ -620,6 +634,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT | |||||
| return new ArithmeticGrad(primitive); | return new ArithmeticGrad(primitive); | ||||
| case schema::PrimitiveType_DivGrad: | case schema::PrimitiveType_DivGrad: | ||||
| return new ArithmeticGrad(primitive); | return new ArithmeticGrad(primitive); | ||||
| case schema::PrimitiveType_PowerGrad: | |||||
| return new PowerGrad(primitive); | |||||
| case schema::PrimitiveType_BNGradInput: | |||||
| return new BNGradInput(primitive); | |||||
| #endif | #endif | ||||
| default: | default: | ||||