Merge pull request !7337 from liubuyu/mastertags/v1.1.0
| @@ -124,6 +124,7 @@ inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("Poolin | |||
| inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||
| inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||
| inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool"); | |||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||
| inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm"); | |||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/add.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto add_prim = primitive->cast<PrimTensorAddPtr>(); | |||
| MS_EXCEPTION_IF_NULL(add_prim); | |||
| auto op_name = add_prim->name(); | |||
| return BroadCastInferShape(op_name, input_args); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| types.emplace("y", input_args[1]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_ADD_H_ | |||
| #define MINDSPORE_CORE_C_OPS_ADD_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameTensorAdd = "TensorAdd"; | |||
| class TensorAdd : public PrimitiveC { | |||
| public: | |||
| TensorAdd() : PrimitiveC(kNameTensorAdd) { InitIOName({"x", "y"}, {"output"}); } | |||
| ~TensorAdd() = default; | |||
| MS_DECLARE_PARENT(TensorAdd, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimTensorAddPtr = std::shared_ptr<TensorAdd>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_ADD_H_ | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/avg_pool.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); } | |||
| void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("ksize", MakeValue(kernel_size)); } | |||
| void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); } | |||
| std::vector<int> AvgPool::get_strides() const { | |||
| auto value_ptr = GetAttr("strides"); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> AvgPool::get_kernel_size() const { | |||
| auto value_ptr = GetAttr("ksize"); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::string AvgPool::get_padding() const { | |||
| auto value_ptr = GetAttr("padding"); | |||
| return GetValue<std::string>(value_ptr); | |||
| } | |||
| void AvgPool::Init(const std::vector<int> &kernel_size, const std::vector<int> &stride, const std::string &padding) { | |||
| auto prim_name = this->name(); | |||
| this->AddAttr("data_format", MakeValue("NCHW")); | |||
| this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name)); | |||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("ksize", kernel_size, prim_name, false, true)); | |||
| this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true)); | |||
| } | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto pool_prim = primitive->cast<PrimAvgPoolPtr>(); | |||
| MS_EXCEPTION_IF_NULL(pool_prim); | |||
| auto op_name = pool_prim->name(); | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | |||
| auto kernel_size = pool_prim->get_kernel_size(); | |||
| auto pad_mode = pool_prim->get_padding(); | |||
| auto batch = in_shape[0]; | |||
| auto channel = in_shape[1]; | |||
| auto in_h = in_shape[2]; | |||
| auto in_w = in_shape[3]; | |||
| auto strides = pool_prim->get_strides(); | |||
| auto kernel_h = kernel_size[2]; | |||
| auto kernel_w = kernel_size[3]; | |||
| auto stride_h = strides[2]; | |||
| auto stride_w = strides[3]; | |||
| int out_h = -1; | |||
| int out_w = -1; | |||
| if (pad_mode == "valid") { | |||
| out_h = ceil((in_h - (kernel_h - 1)) / stride_h); | |||
| out_w = ceil((in_w - (kernel_w - 1)) / stride_w); | |||
| } else if (pad_mode == "same") { | |||
| out_h = ceil(in_h / stride_h); | |||
| out_w = ceil(in_w / stride_w); | |||
| } | |||
| std::vector<int> out_shape = {batch, channel, out_h, out_w}; | |||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int a) { return a <= 0; })) { | |||
| MS_LOG(EXCEPTION) << "Kernel size is not valid."; | |||
| } | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| return input_args[0]->BuildType(); | |||
| } | |||
| AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | |||
| #define MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameAvgPool = "AvgPool"; | |||
| class AvgPool : public PrimitiveC { | |||
| public: | |||
| AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); } | |||
| ~AvgPool() = default; | |||
| MS_DECLARE_PARENT(AvgPool, PrimitiveC); | |||
| void Init(const std::vector<int> &kernel_size = {1}, const std::vector<int> &stride = {1}, | |||
| const std::string &padding = "valid"); | |||
| void set_padding(const std::string &pad); | |||
| void set_kernel_size(const std::vector<int> &kernel_size); | |||
| void set_strides(const std::vector<int> &strides); | |||
| std::vector<int> get_kernel_size() const; | |||
| std::vector<int> get_strides() const; | |||
| std::string get_padding() const; | |||
| }; | |||
| AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAvgPoolPtr = std::shared_ptr<AvgPool>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_ | |||
| @@ -25,22 +25,12 @@ | |||
| namespace mindspore { | |||
| namespace { | |||
| constexpr auto kKernelSize = "kernel_size"; | |||
| constexpr auto kStride = "stride"; | |||
| constexpr auto kDilation = "dilation"; | |||
| constexpr auto kPadMode = "pad_mode"; | |||
| constexpr auto kPad = "pad"; | |||
| constexpr auto kMode = "mode"; | |||
| constexpr auto kGroup = "group"; | |||
| constexpr auto kOutputChannel = "output channel"; | |||
| constexpr auto kPadList = "pad_list"; | |||
| constexpr auto kConv2DName = "Conv2D"; | |||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto conv_prim = primitive->cast<PrimConv2dPtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||
| auto prim_name = conv_prim->name(); | |||
| CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | |||
| @@ -48,17 +38,17 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]", | |||
| w_shape[1], conv_prim->name()); | |||
| auto out_channel = conv_prim->GetOutputChannel(); | |||
| auto out_channel = conv_prim->get_output_channel(); | |||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); | |||
| std::vector<int> temp_w; | |||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, | |||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w, | |||
| conv_prim->name()); | |||
| auto kernel_size_h = w_shape[2]; | |||
| auto kernel_size_w = w_shape[3]; | |||
| auto stride = conv_prim->GetStride(); | |||
| auto dilation = conv_prim->GetDilation(); | |||
| auto stride = conv_prim->get_stride(); | |||
| auto dilation = conv_prim->get_dilation(); | |||
| auto stride_h = stride[2]; | |||
| auto stride_w = stride[3]; | |||
| auto dilation_h = dilation[2]; | |||
| @@ -66,7 +56,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| int h_out = -1; | |||
| int w_out = -1; | |||
| std::vector<int> pad_list(4, 0); | |||
| auto pad_mode = conv_prim->GetPadMode(); | |||
| auto pad_mode = conv_prim->get_pad_mode(); | |||
| if (pad_mode == "valid") { | |||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | |||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | |||
| @@ -82,18 +72,18 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| pad_list.emplace_back(pad_left); | |||
| pad_list.emplace_back(pad_needed_h - pad_left); | |||
| } else if (pad_mode == "pad") { | |||
| std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); | |||
| auto pad_top = conv_prim->GetPad()[0]; | |||
| auto pad_bottom = conv_prim->GetPad()[1]; | |||
| auto pad_right = conv_prim->GetPad()[2]; | |||
| auto pad_left = conv_prim->GetPad()[3]; | |||
| std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | |||
| auto pad_top = conv_prim->get_pad()[0]; | |||
| auto pad_bottom = conv_prim->get_pad()[1]; | |||
| auto pad_right = conv_prim->get_pad()[2]; | |||
| auto pad_left = conv_prim->get_pad()[3]; | |||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | |||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | |||
| h_out = floor(h_out); | |||
| w_out = floor(w_out); | |||
| } | |||
| conv_prim->SetPadList(pad_list); | |||
| conv_prim->set_pad_list(pad_list); | |||
| std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| @@ -122,44 +112,45 @@ void Conv2D::Init(int out_channel, const std::vector<int> &kernel_size, int mode | |||
| auto prim_name = this->name(); | |||
| this->AddAttr("data_format", MakeValue("NCHW")); | |||
| this->AddAttr("offset_a", MakeValue(0)); | |||
| this->SetKernelSize(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | |||
| this->SetStride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true)); | |||
| this->SetDilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true)); | |||
| this->SetPadMode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); | |||
| CheckAndConvertUtils::CheckInteger("pad size", pad.size(), kEqual, 4, prim_name); | |||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | |||
| this->set_stride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true)); | |||
| this->set_dilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true)); | |||
| this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name); | |||
| if (pad_mode == "pad") { | |||
| for (auto item : pad) { | |||
| CheckAndConvertUtils::Check("pad item", item, kGreaterEqual, "zeros list", 0, prim_name); | |||
| CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name); | |||
| } | |||
| } else { | |||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros list", {0, 0, 0, 0}, prim_name); | |||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | |||
| } | |||
| this->SetPad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||
| this->SetMode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name)); | |||
| this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); | |||
| this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||
| this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name)); | |||
| this->set_out_channel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); | |||
| this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||
| } | |||
| std::vector<int> Conv2D::GetKernelSize() const { | |||
| std::vector<int> Conv2D::get_kernel_size() const { | |||
| auto value_ptr = GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> Conv2D::GetStride() const { | |||
| std::vector<int> Conv2D::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> Conv2D::GetDilation() const { | |||
| std::vector<int> Conv2D::get_dilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::string Conv2D::GetPadMode() const { | |||
| std::string Conv2D::get_pad_mode() const { | |||
| auto value_ptr = this->GetAttr(kPadMode); | |||
| return GetValue<string>(value_ptr); | |||
| } | |||
| std::vector<int> Conv2D::GetPad() const { | |||
| std::vector<int> Conv2D::get_pad() const { | |||
| auto value_ptr = this->GetAttr(kPad); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| int Conv2D::GetMode() const { | |||
| int Conv2D::get_mode() const { | |||
| auto value_ptr = this->GetAttr(kMode); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| @@ -168,20 +159,22 @@ int Conv2D::get_group() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| int Conv2D::GetOutputChannel() const { | |||
| int Conv2D::get_output_channel() const { | |||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| void Conv2D::SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } | |||
| void Conv2D::SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| void Conv2D::SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| void Conv2D::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||
| void Conv2D::SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||
| void Conv2D::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||
| void Conv2D::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| void Conv2D::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||
| void Conv2D::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||
| void Conv2D::set_kernel_size(const std::vector<int> &kernel_size) { | |||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | |||
| } | |||
| void Conv2D::set_stride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| void Conv2D::set_dilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| void Conv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||
| void Conv2D::set_pad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||
| void Conv2D::set_mode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||
| void Conv2D::set_group(int group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| void Conv2D::set_out_channel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||
| void Conv2D::set_pad_list(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } | |||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| @@ -21,10 +21,12 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/op_utils.h" | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kConv2DName = "Conv2D"; | |||
| class Conv2D : public PrimitiveC { | |||
| public: | |||
| Conv2D(); | |||
| @@ -33,23 +35,23 @@ class Conv2D : public PrimitiveC { | |||
| void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", | |||
| const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, | |||
| const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); | |||
| std::vector<int> GetKernelSize() const; | |||
| std::vector<int> GetStride() const; | |||
| std::vector<int> GetDilation() const; | |||
| std::string GetPadMode() const; | |||
| std::vector<int> GetPad() const; | |||
| int GetMode() const; | |||
| std::vector<int> get_kernel_size() const; | |||
| std::vector<int> get_stride() const; | |||
| std::vector<int> get_dilation() const; | |||
| std::string get_pad_mode() const; | |||
| std::vector<int> get_pad() const; | |||
| int get_mode() const; | |||
| int get_group() const; | |||
| int GetOutputChannel() const; | |||
| void SetKernelSize(const std::vector<int> &kernel_size); | |||
| void SetStride(const std::vector<int> &stride); | |||
| void SetDilation(const std::vector<int> &dilation); | |||
| void SetPadMode(const std::string &pad_mode); | |||
| void SetPad(const std::vector<int> &pad); | |||
| void SetMode(int mode); | |||
| void SetGroup(int group); | |||
| void SetOutChannel(int output_channel); | |||
| void SetPadList(const std::vector<int> &pad_list); | |||
| int get_output_channel() const; | |||
| void set_kernel_size(const std::vector<int> &kernel_size); | |||
| void set_stride(const std::vector<int> &stride); | |||
| void set_dilation(const std::vector<int> &dilation); | |||
| void set_pad_mode(const std::string &pad_mode); | |||
| void set_pad(const std::vector<int> &pad); | |||
| void set_mode(int mode); | |||
| void set_group(int group); | |||
| void set_out_channel(int output_channel); | |||
| void set_pad_list(const std::vector<int> &pad_list); | |||
| }; | |||
| AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| @@ -0,0 +1,199 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/depthwise_conv2d.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void DepthWiseConv2D::Init(int channel_multiplier, const std::vector<int> &kernel_size, int mode, | |||
| const std::string &pad_mode, const std::vector<int> &pad, const std::vector<int> &stride, | |||
| const std::vector<int> &dilation, int group) { | |||
| auto prim_name = this->name(); | |||
| this->AddAttr("data_format", MakeValue("NCHW")); | |||
| this->AddAttr("offset_a", MakeValue(0)); | |||
| this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); | |||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | |||
| auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); | |||
| if (strides[0] != strides[1]) { | |||
| MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] | |||
| << ", width " << strides[1]; | |||
| } | |||
| this->set_stride(strides); | |||
| auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); | |||
| if (dilations[0] != dilations[1]) { | |||
| MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] | |||
| << ", width " << dilations[1]; | |||
| } | |||
| this->set_dilation(dilations); | |||
| this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name); | |||
| if (pad_mode == "pad") { | |||
| for (auto item : pad) { | |||
| CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name); | |||
| } | |||
| } else { | |||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | |||
| } | |||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||
| this->set_out_channel( | |||
| CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); | |||
| this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); | |||
| } | |||
| std::vector<int> DepthWiseConv2D::get_kernel_size() const { | |||
| auto value_ptr = GetAttr(kKernelSize); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> DepthWiseConv2D::get_stride() const { | |||
| auto value_ptr = GetAttr(kStride); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::vector<int> DepthWiseConv2D::get_dilation() const { | |||
| auto value_ptr = GetAttr(kDilation); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| std::string DepthWiseConv2D::get_pad_mode() const { | |||
| auto value_ptr = this->GetAttr(kPadMode); | |||
| return GetValue<string>(value_ptr); | |||
| } | |||
| std::vector<int> DepthWiseConv2D::get_pad() const { | |||
| auto value_ptr = this->GetAttr(kPad); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| int DepthWiseConv2D::get_mode() const { | |||
| auto value_ptr = this->GetAttr(kMode); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| int DepthWiseConv2D::get_group() const { | |||
| auto value_ptr = this->GetAttr(kGroup); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| int DepthWiseConv2D::get_output_channel() const { | |||
| auto value_ptr = this->GetAttr(kOutputChannel); | |||
| return GetValue<int>(value_ptr); | |||
| } | |||
| void DepthWiseConv2D::set_kernel_size(const std::vector<int> &kernel_size) { | |||
| this->AddAttr(kKernelSize, MakeValue(kernel_size)); | |||
| } | |||
| void DepthWiseConv2D::set_stride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); } | |||
| void DepthWiseConv2D::set_dilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } | |||
| void DepthWiseConv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } | |||
| void DepthWiseConv2D::set_pad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); } | |||
| void DepthWiseConv2D::set_mode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } | |||
| void DepthWiseConv2D::set_group(int group) { this->AddAttr(kGroup, MakeValue(group)); } | |||
| void DepthWiseConv2D::set_out_channel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } | |||
| void DepthWiseConv2D::set_pads(const std::vector<int> &pad_list) { this->AddAttr(kPads, MakeValue(pad_list)); } | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto conv_prim = primitive->cast<PrimDepthWiseConv2DPtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv_prim); | |||
| auto prim_name = conv_prim->name(); | |||
| CheckAndConvertUtils::CheckInRange("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); | |||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); | |||
| CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], conv_prim->name()); | |||
| auto out_channel = conv_prim->get_output_channel(); | |||
| std::vector<int> temp_w; | |||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
| CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w, | |||
| conv_prim->name()); | |||
| auto kernel_size_n = w_shape[0]; | |||
| if (kernel_size_n != 1) { | |||
| MS_EXCEPTION(ValueError) << "The batch of input weeight should be 1, but got " << kernel_size_n; | |||
| } | |||
| auto kernel_size_h = w_shape[2]; | |||
| auto kernel_size_w = w_shape[3]; | |||
| auto stride = conv_prim->get_stride(); | |||
| auto dilation = conv_prim->get_dilation(); | |||
| auto stride_h = stride[2]; | |||
| auto stride_w = stride[3]; | |||
| auto dilation_h = dilation[2]; | |||
| auto dilation_w = dilation[3]; | |||
| int h_out = -1; | |||
| int w_out = -1; | |||
| std::vector<int> pad_list(4, 0); | |||
| auto pad_mode = conv_prim->get_pad_mode(); | |||
| if (pad_mode == "valid") { | |||
| h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); | |||
| w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); | |||
| } else if (pad_mode == "same") { | |||
| h_out = ceil(x_shape[2] / stride_h); | |||
| w_out = ceil(x_shape[3] / stride_w); | |||
| auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); | |||
| pad_list.emplace_back(floor(pad_needed_h / 2)); | |||
| pad_list.emplace_back(pad_needed_h / 2); | |||
| auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); | |||
| auto pad_left = floor(pad_needed_w / 2); | |||
| pad_list.emplace_back(pad_left); | |||
| pad_list.emplace_back(pad_needed_h - pad_left); | |||
| } else if (pad_mode == "pad") { | |||
| std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); | |||
| auto pad_top = conv_prim->get_pad()[0]; | |||
| auto pad_bottom = conv_prim->get_pad()[1]; | |||
| auto pad_right = conv_prim->get_pad()[2]; | |||
| auto pad_left = conv_prim->get_pad()[3]; | |||
| h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; | |||
| w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; | |||
| h_out = floor(h_out); | |||
| w_out = floor(w_out); | |||
| } | |||
| conv_prim->set_pads(pad_list); | |||
| std::vector<int> out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| types.emplace("w", input_args[1]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); | |||
| if (infer_type == kNumberTypeInt8) { | |||
| return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32)); | |||
| } | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_DEPTHWISE_CONV2D_H | |||
| #define MINDSPORE_CORE_C_OPS_DEPTHWISE_CONV2D_H | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameDepthWiseConv2D = "DepthwiseConv2dNative"; | |||
| class DepthWiseConv2D : public PrimitiveC { | |||
| public: | |||
| DepthWiseConv2D() : PrimitiveC(kNameDepthWiseConv2D) { InitIOName({"x", "w"}, {"output"}); } | |||
| ~DepthWiseConv2D() = default; | |||
| MS_DECLARE_PARENT(DepthWiseConv2D, PrimitiveC); | |||
| void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid", | |||
| const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1}, | |||
| const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1); | |||
| std::vector<int> get_kernel_size() const; | |||
| std::vector<int> get_stride() const; | |||
| std::vector<int> get_dilation() const; | |||
| std::string get_pad_mode() const; | |||
| std::vector<int> get_pad() const; | |||
| int get_mode() const; | |||
| int get_group() const; | |||
| int get_output_channel() const; | |||
| void set_kernel_size(const std::vector<int> &kernel_size); | |||
| void set_stride(const std::vector<int> &stride); | |||
| void set_dilation(const std::vector<int> &dilation); | |||
| void set_pad_mode(const std::string &pad_mode); | |||
| void set_pad(const std::vector<int> &pad); | |||
| void set_mode(int mode); | |||
| void set_group(int group); | |||
| void set_out_channel(int output_channel); | |||
| void set_pads(const std::vector<int> &pad_list); | |||
| }; | |||
| AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimDepthWiseConv2DPtr = std::shared_ptr<DepthWiseConv2D>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_DEPTHWISE_CONV2D_H | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_CONV_UTILS_H | |||
| #define MINDSPORE_CORE_C_OPS_CONV_UTILS_H | |||
| #include <string> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kKernelSize = "kernel_size"; | |||
| constexpr auto kStride = "stride"; | |||
| constexpr auto kDilation = "dilation"; | |||
| constexpr auto kPadMode = "pad_mode"; | |||
| constexpr auto kPad = "pad"; | |||
| constexpr auto kPads = "pads"; | |||
| constexpr auto kMode = "mode"; | |||
| constexpr auto kGroup = "group"; | |||
| constexpr auto kOutputChannel = "output_channel"; | |||
| constexpr auto kPadList = "pad_list"; | |||
| constexpr auto kAxis = "axis"; | |||
| const std::set<TypeId> common_valid_types = { | |||
| kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, | |||
| kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | |||
| abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_CONV_UTILS_H | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "c_ops/op_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_LOG(INFO) << "Do infer shape for op " << op_name; | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); | |||
| auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name); | |||
| if (x_shape == y_shape) { | |||
| return std::make_shared<abstract::Shape>(x_shape); | |||
| } | |||
| auto x_length = x_shape.size(); | |||
| auto y_length = y_shape.size(); | |||
| auto length = x_length < y_length ? x_length : y_length; | |||
| std::vector<int> broadcast_shape; | |||
| if (x_length == length) { | |||
| std::copy(y_shape.begin(), y_shape.end() - length, std::back_inserter(broadcast_shape)); | |||
| } else { | |||
| std::copy(x_shape.begin(), x_shape.end() - length, std::back_inserter(broadcast_shape)); | |||
| } | |||
| for (int i = -length; i < 0; i++) { | |||
| if (x_shape[x_length + i] == 1) { | |||
| broadcast_shape.push_back(y_shape[y_length + i]); | |||
| } else if (y_shape[y_length + i] == 1) { | |||
| broadcast_shape.push_back(x_shape[x_length + i]); | |||
| } else if (x_shape[x_length + i] == y_shape[y_length + i]) { | |||
| broadcast_shape.push_back(x_shape[x_length + i]); | |||
| } else { | |||
| MS_EXCEPTION(ValueError) << "For op " << op_name << ", the two input can not broadcast"; | |||
| } | |||
| } | |||
| return std::make_shared<abstract::Shape>(broadcast_shape); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/relu6.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| abstract::ShapePtr Relu6InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto x = input_args[0]->GetShapeTrack(); | |||
| auto shape_element = x->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape_element); | |||
| return shape_element; | |||
| } | |||
| TypePtr Relu6InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(Relu6InferType(primitive, input_args), | |||
| Relu6InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_RELU6_H_ | |||
| #define MINDSPORE_CORE_C_OPS_RELU6_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameRelu6 = "Relu6"; | |||
| class Relu6 : public PrimitiveC { | |||
| public: | |||
| Relu6() : PrimitiveC(kNameRelu6) { InitIOName({"x"}, {"output"}); } | |||
| ~Relu6() = default; | |||
| MS_DECLARE_PARENT(Relu6, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimRelu6Ptr = std::shared_ptr<Relu6>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_RELU6_H_ | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/reshape.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| // to do | |||
| return nullptr; | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| // to do | |||
| return nullptr; | |||
| } | |||
| AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_RESHAPE_H_ | |||
| #define MINDSPORE_CORE_C_OPS_RESHAPE_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameReshape = "Reshape"; | |||
| class Reshape : public PrimitiveC { | |||
| public: | |||
| Reshape() : PrimitiveC(kNameReshape) { InitIOName({"tensor", "shape"}, {"output"}); } | |||
| ~Reshape() = default; | |||
| MS_DECLARE_PARENT(Reshape, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimTensorAddPtr = std::shared_ptr<Reshape>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_ | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/softmax.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void Softmax::set_axis(const std::vector<int> &axis) { this->set_attr(kAxis, MakeValue(axis)); } | |||
| void Softmax::Init(int axis) { | |||
| auto op_name = this->name(); | |||
| std::vector<int> axis_vec = {axis}; | |||
| CheckAndConvertUtils::CheckInteger("axis_len", axis_vec.size(), kEqual, 1, op_name); | |||
| auto rank = axis_vec.size(); | |||
| for (auto &item : axis_vec) { | |||
| CheckAndConvertUtils::CheckInRange("axis", item, kIncludeLeft, {-rank, rank}, op_name); | |||
| } | |||
| this->set_axis(axis_vec); | |||
| } | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto softmax_prim = primitive->cast<PrimSoftmaxPtr>(); | |||
| MS_EXCEPTION_IF_NULL(softmax_prim); | |||
| auto op_name = softmax_prim->name(); | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); | |||
| return std::make_shared<abstract::Shape>(in_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64}; | |||
| types.emplace("x", input_args[0]->BuildType()); | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| return TypeIdToType(infer_type); | |||
| } | |||
| AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_SOFTMAX_H_ | |||
| #define MINDSPORE_CORE_C_OPS_SOFTMAX_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameSoftmax = "Softmax"; | |||
| class Softmax : public PrimitiveC { | |||
| public: | |||
| Softmax() : PrimitiveC(kNameSoftmax) { InitIOName({"x"}, {"output"}); } | |||
| ~Softmax() = default; | |||
| MS_DECLARE_PARENT(Softmax, PrimitiveC); | |||
| void Init(int axis = 1); | |||
| void set_axis(const std::vector<int> &axis); | |||
| }; | |||
| AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimSoftmaxPtr = std::shared_ptr<Softmax>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_SOFTMAX_H_ | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "c_ops/squeeze.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "c_ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| void Squeeze::set_axis(const std::vector<int> &axis) { this->set_attr(kAxis, MakeValue(axis)); } | |||
| void Squeeze::Init(const std::vector<int> &axis) { this->set_axis(axis); } | |||
| std::vector<int> Squeeze::get_axis() const { | |||
| auto value_ptr = this->GetAttr(kAxis); | |||
| return GetValue<std::vector<int>>(value_ptr); | |||
| } | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto squeeze_prim = primitive->cast<PrimSqueezePtr>(); | |||
| MS_EXCEPTION_IF_NULL(squeeze_prim); | |||
| auto op_name = squeeze_prim->name(); | |||
| auto axis = squeeze_prim->get_axis(); | |||
| std::vector<int> infer_shape; | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); | |||
| auto len = in_shape.size(); | |||
| if (axis.empty()) { | |||
| std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape), | |||
| [](int value) { return value != 1; }); | |||
| } else { | |||
| for (auto &item : axis) { | |||
| CheckAndConvertUtils::CheckInRange("axis_or_elememt", item, kIncludeBoth, {-len, len + 1}, op_name); | |||
| auto idx = item >= 0 ? item : len + item; | |||
| if (in_shape[idx] != 1) { | |||
| MS_EXCEPTION(ValueError) << "Cannot select an axis to squeeze out which has size not equal to one."; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < len; i++) { | |||
| auto it = std::find(axis.begin(), axis.end(), i); | |||
| auto it2 = std::find(axis.begin(), axis.end(), i - len); | |||
| if (!(it != axis.end() || it2 != axis.end())) { | |||
| infer_shape.push_back(in_shape[i]); | |||
| } | |||
| } | |||
| } | |||
| return std::make_shared<abstract::Shape>(infer_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| return input_args[0]->BuildType(); | |||
| } | |||
| AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_SQUEEZE_H_ | |||
| #define MINDSPORE_CORE_C_OPS_SQUEEZE_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "c_ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| constexpr auto kNameSqueeze = "Squeeze"; | |||
| class Squeeze : public PrimitiveC { | |||
| public: | |||
| Squeeze() : PrimitiveC(kNameSqueeze) { InitIOName({"x"}, {"output"}); } | |||
| ~Squeeze() = default; | |||
| MS_DECLARE_PARENT(Squeeze, PrimitiveC); | |||
| void Init(const std::vector<int> &axis = {}); | |||
| void set_axis(const std::vector<int> &axis); | |||
| std::vector<int> get_axis() const; | |||
| }; | |||
| AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimSqueezePtr = std::shared_ptr<Squeeze>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_SQUEEZE_H_ | |||