| @@ -179,6 +179,7 @@ union PrimitiveType { | |||
| Conv2DGradInput, | |||
| PoolingGrad, | |||
| BNGrad, | |||
| BNGradInput, | |||
| ApplyMomentum, | |||
| BiasGrad, | |||
| SoftmaxCrossEntropy, | |||
| @@ -398,7 +398,10 @@ table BNGrad { | |||
| eps : float; | |||
| momentum: float; | |||
| } | |||
| table BNGradInput { | |||
| eps : float; | |||
| momentum: float; | |||
| } | |||
| table Scale { | |||
| axis: int; | |||
| } | |||
| @@ -25,6 +25,36 @@ void ActivationGrad::SetType(int type) { | |||
| this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type; | |||
| } | |||
| void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; } | |||
| int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_ActivationGrad; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_ActivationGrad) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| auto attr = std::make_unique<schema::ActivationGradT>(); | |||
| if (prim.name() == "ReLU") { | |||
| attr->type = schema::ActivationType_RELU; | |||
| } else if (prim.name() == "Sigmoid") { | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| } else if (prim.name() == "ReLU6") { | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } | |||
| auto alpha = GetValue<float>(prim.GetAttr("alpha")); | |||
| attr->alpha = alpha; | |||
| this->primitive_->value.value = attr.release(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/ops/primitive_c.h" | |||
| @@ -33,6 +34,7 @@ class ActivationGrad : public PrimitiveC { | |||
| explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetType(int type); | |||
| void SetAlpha(float alpha); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| ActivationGrad() = default; | |||
| @@ -22,7 +22,34 @@ namespace lite { | |||
| std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; } | |||
| void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; } | |||
| int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_BiasGrad; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_BiasGrad) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::BiasGradT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -33,7 +33,7 @@ class BiasGrad : public PrimitiveC { | |||
| BiasGrad() = default; | |||
| explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetAxis(const std::vector<int> &axis); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| BiasGrad() = default; | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/bn_grad_input.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; } | |||
| float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; } | |||
| void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; } | |||
| void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; } | |||
| int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_BNGradInput; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::BNGradInputT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_BNGradInput(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_BNGradInputInput return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); } | |||
| float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); } | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class BNGradInput : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(BNGradInput, PrimitiveC); | |||
| BNGradInput() = default; | |||
| explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetEps(float eps); | |||
| void SetMomentum(float momentum); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| BNGradInput() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| float GetEps() const; | |||
| float GetMomentum() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ | |||
| @@ -66,7 +66,133 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon | |||
| void Conv2DGradFilter::SetActivationType(int activation_type) { | |||
| this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type; | |||
| } | |||
| void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| attr->format = schema::Format_NHWC; | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| attr->padMode = schema::PadMode_SAME; | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| if (prim.GetAttr("activation_name") != nullptr) { | |||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||
| attr->activationType = kActivationTypeMap[activate_name]; | |||
| } else { | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| int channel_mutiplier = 1; | |||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||
| } | |||
| attr->channelMultiplier = channel_mutiplier; | |||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim, | |||
| schema::PrimitiveT *primitive, const int &group) { | |||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||
| attr->group = group; | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| attr->format = schema::Format_NHWC; | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| attr->padMode = schema::PadMode_SAME; | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| if (prim.GetAttr("activation_name") != nullptr) { | |||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||
| attr->activationType = kActivationTypeMap[activate_name]; | |||
| } else { | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Conv2DGradFilter; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradFilter) { | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| int group = GetValue<int>(prim.GetAttr("group")); | |||
| if (group > 1) { | |||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||
| } else { | |||
| PopulaterConv2DSingleGroup(prim, this->primitive_, group); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -20,6 +20,8 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/ops/primitive_c.h" | |||
| @@ -48,6 +50,10 @@ class Conv2DGradFilter : public PrimitiveC { | |||
| void SetDilateH(int dilate_h); | |||
| void SetHasBias(bool has_bias); | |||
| void SetActivationType(int activation_type); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||
| const std::vector<AnfNodePtr> &inputs); | |||
| void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | |||
| #else | |||
| Conv2DGradFilter() = default; | |||
| @@ -64,7 +64,133 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv | |||
| void Conv2DGradInput::SetActivationType(int activation_type) { | |||
| this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type; | |||
| } | |||
| void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| attr->format = schema::Format_NHWC; | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| attr->padMode = schema::PadMode_SAME; | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| if (prim.GetAttr("activation_name") != nullptr) { | |||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||
| attr->activationType = kActivationTypeMap[activate_name]; | |||
| } else { | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| int channel_mutiplier = 1; | |||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||
| } | |||
| attr->channelMultiplier = channel_mutiplier; | |||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim, | |||
| schema::PrimitiveT *primitive, const int &group) { | |||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||
| attr->group = group; | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| attr->format = schema::Format_NHWC; | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| attr->padMode = schema::PadMode_SAME; | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| if (prim.GetAttr("activation_name") != nullptr) { | |||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||
| attr->activationType = kActivationTypeMap[activate_name]; | |||
| } else { | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Conv2DGradInput; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Conv2DGradInput) { | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| int group = GetValue<int>(prim.GetAttr("group")); | |||
| if (group > 1) { | |||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||
| } else { | |||
| PopulaterConv2DSingleGroup(prim, this->primitive_, group); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -20,6 +20,8 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/ops/primitive_c.h" | |||
| @@ -48,6 +50,10 @@ class Conv2DGradInput : public PrimitiveC { | |||
| void SetDilateH(int dilate_h); | |||
| void SetHasBias(bool has_bias); | |||
| void SetActivationType(int activation_type); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||
| const std::vector<AnfNodePtr> &inputs); | |||
| void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | |||
| #else | |||
| Conv2DGradInput() = default; | |||
| @@ -52,7 +52,64 @@ void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPooling | |||
| void PoolingGrad::SetRoundMode(int round_mode) { | |||
| this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode; | |||
| } | |||
| int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_PoolingGrad; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_PoolingGrad) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::PoolingGradT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| attr->format = schema::Format_NHWC; | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| if (prim.instance_name() == "MaxPool") { | |||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||
| } else if (prim.instance_name() == "MeanPool") { | |||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||
| } | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("padding")); | |||
| if (pad_mode == "VALID") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "SAME") { | |||
| attr->padMode = schema::PadMode_SAME; | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize")); | |||
| attr->windowH = kernel_size[2]; | |||
| attr->windowW = kernel_size[3]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides")); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <string> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/ops/primitive_c.h" | |||
| @@ -44,6 +45,7 @@ class PoolingGrad : public PrimitiveC { | |||
| void SetPadLeft(int pad_left); | |||
| void SetPadRight(int pad_right); | |||
| void SetRoundMode(int round_mode); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| PoolingGrad() = default; | |||
| @@ -26,7 +26,36 @@ float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad() | |||
| void PowerGrad::SetPower(float power) { this->primitive_->value.AsPowerGrad()->power = power; } | |||
| void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; } | |||
| void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; } | |||
| int PowerGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_PowerGrad; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_PowerGrad) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::PowerGradT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->power = GetValue<float>(prim.GetAttr("power")); | |||
| attr->scale = GetValue<float>(prim.GetAttr("scale")); | |||
| attr->shift = GetValue<float>(prim.GetAttr("shift")); | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); } | |||
| @@ -34,6 +34,7 @@ class PowerGrad : public PrimitiveC { | |||
| void SetPower(float power); | |||
| void SetScale(float scale); | |||
| void SetShift(float shift); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| PowerGrad() = default; | |||
| @@ -383,6 +383,20 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri | |||
| return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | |||
| } else if (op_type == "BatchNormGrad") { | |||
| return NewPrimitiveC<BNGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "Conv2DGradInput") { | |||
| return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType); | |||
| } else if (op_type == "Conv2DGradFilter") { | |||
| return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | |||
| } else if (op_type == "BiasGrad") { | |||
| return NewPrimitiveC<BiasGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "ActivationGrad") { | |||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "PoolingGrad") { | |||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "BNGradInput") { | |||
| return NewPrimitiveC<BNGradInput>(prim, inputs, quantType); | |||
| } else if (op_type == "PowerGrad") { | |||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; | |||
| @@ -620,6 +634,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT | |||
| return new ArithmeticGrad(primitive); | |||
| case schema::PrimitiveType_DivGrad: | |||
| return new ArithmeticGrad(primitive); | |||
| case schema::PrimitiveType_PowerGrad: | |||
| return new PowerGrad(primitive); | |||
| case schema::PrimitiveType_BNGradInput: | |||
| return new BNGradInput(primitive); | |||
| #endif | |||
| default: | |||