From: @lianliguang Reviewed-by: @zhoufeng54,@chujinjin Signed-off-by: @chujinjinpull/14442/MERGE
| @@ -18,6 +18,8 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| @@ -73,26 +75,7 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr | |||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node); | AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node); | ||||
| auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node); | auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node); | ||||
| MS_EXCEPTION_IF_NULL(used_cast_node_list); | MS_EXCEPTION_IF_NULL(used_cast_node_list); | ||||
| std::unordered_map<string, size_t> format_counter; | |||||
| for (const auto &node_info : *used_cast_node_list) { | |||||
| MS_EXCEPTION_IF_NULL(node_info.first); | |||||
| auto cast_out_node = node_info.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cast_out_node); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node); | |||||
| for (size_t index = 0; index < input_num; ++index) { | |||||
| if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first != | |||||
| cast_node) { | |||||
| continue; | |||||
| } | |||||
| auto format = AnfAlgo::GetInputFormat(cast_out_node, index); | |||||
| auto it = format_counter.find(format); | |||||
| if (it == format_counter.end()) { | |||||
| format_counter[format] = 1; | |||||
| } else { | |||||
| it->second++; | |||||
| } | |||||
| } | |||||
| } | |||||
| std::unordered_map<string, size_t> format_counter = CalculateFormat(used_cast_node_list, cast_node); | |||||
| auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0); | auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0); | ||||
| string convert_format = kOpFormat_DEFAULT; | string convert_format = kOpFormat_DEFAULT; | ||||
| if (cast_input_format == kOpFormat_DEFAULT) { | if (cast_input_format == kOpFormat_DEFAULT) { | ||||
| @@ -121,5 +104,33 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr | |||||
| SetCastFormat(cast_node, convert_format); | SetCastFormat(cast_node, convert_format); | ||||
| } | } | ||||
| } | } | ||||
| std::unordered_map<string, size_t> ConvertCastFormat::CalculateFormat( | |||||
| const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list, | |||||
| const CNodePtr &cast_node) const { | |||||
| MS_EXCEPTION_IF_NULL(used_cast_node_list); | |||||
| MS_EXCEPTION_IF_NULL(cast_node); | |||||
| std::unordered_map<string, size_t> format_counter; | |||||
| for (const auto &node_info : *used_cast_node_list) { | |||||
| MS_EXCEPTION_IF_NULL(node_info.first); | |||||
| auto cast_out_node = node_info.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cast_out_node); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node); | |||||
| for (size_t index = 0; index < input_num; ++index) { | |||||
| if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first != | |||||
| cast_node) { | |||||
| continue; | |||||
| } | |||||
| auto format = AnfAlgo::GetInputFormat(cast_out_node, index); | |||||
| auto it = format_counter.find(format); | |||||
| if (it == format_counter.end()) { | |||||
| format_counter[format] = 1; | |||||
| } else { | |||||
| it->second++; | |||||
| } | |||||
| } | |||||
| } | |||||
| return format_counter; | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,6 +17,10 @@ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_ | #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| @@ -30,6 +34,9 @@ class ConvertCastFormat : public PatternProcessPass { | |||||
| const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | private: | ||||
| std::unordered_map<string, size_t> CalculateFormat( | |||||
| const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list, | |||||
| const CNodePtr &cast_node) const; | |||||
| void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const; | void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const; | ||||
| void SetCastFormat(const CNodePtr &cast_node, const string &format) const; | void SetCastFormat(const CNodePtr &cast_node, const string &format) const; | ||||
| }; | }; | ||||
| @@ -143,18 +143,22 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_ | |||||
| } | } | ||||
| // insert depend | // insert depend | ||||
| if (origin_format != cur_format || origin_type != cur_type) { | if (origin_format != cur_format || origin_type != cur_type) { | ||||
| std::vector<AnfNodePtr> depend_nodes; | |||||
| if (get_item.get() != nullptr) { | |||||
| depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node}; | |||||
| } else { | |||||
| depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node}; | |||||
| } | |||||
| final_node = func_graph->NewCNode(depend_nodes); | |||||
| final_node = MakeDependency(get_item, final_node, cnode, func_graph); | |||||
| MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString(); | MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString(); | ||||
| } | } | ||||
| return final_node; | return final_node; | ||||
| } | } | ||||
| CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, | |||||
| const CNodePtr &cnode, const FuncGraphPtr &func_graph) const { | |||||
| std::vector<AnfNodePtr> depend_nodes; | |||||
| if (get_item != nullptr) { | |||||
| depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node}; | |||||
| } else { | |||||
| depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node}; | |||||
| } | |||||
| return func_graph->NewCNode(depend_nodes); | |||||
| } | |||||
| CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | ||||
| const std::shared_ptr<kernel::OpInfo> &op_info) const { | const std::shared_ptr<kernel::OpInfo> &op_info) const { | ||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| @@ -33,6 +33,8 @@ class DealRefTransAndCast : public TransDataSplit { | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | private: | ||||
| CNodePtr MakeDependency(const CNodePtr &getitem, const CNodePtr &final_node, const CNodePtr &cnode, | |||||
| const FuncGraphPtr &func_graph) const; | |||||
| CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | ||||
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | ||||
| CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | ||||
| @@ -97,7 +97,6 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{ | |||||
| {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | ||||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, | {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, | ||||
| {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}}; | {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}}; | ||||
| } // namespace trans | } // namespace trans | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -182,32 +182,36 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj | |||||
| } | } | ||||
| } | } | ||||
| BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const { | |||||
| SyncData(py_args); | |||||
| auto size = py_args.size(); | |||||
| py::tuple input_args(size - 2); | |||||
| for (size_t i = 0; i < size - 2; ++i) { | |||||
| input_args[i] = py_args[i]; | |||||
| } | |||||
| py::tuple convert_args(py_args.size()); | |||||
| ConvertCTensorToPyTensor(py_args, &convert_args); | |||||
| auto inst = pynative::PynativeExecutor::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(inst); | |||||
| try { | |||||
| MS_LOG(DEBUG) << "Run bprop function start"; | |||||
| inst->NewGraph(hook_, input_args.cast<py::args>()); | |||||
| py::object grads_obj = hook_(*convert_args); | |||||
| py::tuple grads = check_bprop_out(grads_obj, py_args); | |||||
| inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>()); | |||||
| MS_LOG(DEBUG) << "Run bprop function end"; | |||||
| return std::make_shared<PyObjectRef>(grads); | |||||
| } catch (std::exception &bt) { | |||||
| inst->ClearRes(); | |||||
| std::rethrow_exception(std::current_exception()); | |||||
| } | |||||
| } | |||||
| BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { | BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { | ||||
| py::tuple py_args = ConvertDatatoPyTuple(args); | py::tuple py_args = ConvertDatatoPyTuple(args); | ||||
| bool is_bprop = this->HasAttr(kBpropAttrName); | bool is_bprop = this->HasAttr(kBpropAttrName); | ||||
| if (is_bprop) { | if (is_bprop) { | ||||
| SyncData(py_args); | |||||
| auto size = py_args.size(); | |||||
| py::tuple input_args(size - 2); | |||||
| for (size_t i = 0; i < size - 2; ++i) { | |||||
| input_args[i] = py_args[i]; | |||||
| } | |||||
| py::tuple convert_args(py_args.size()); | |||||
| ConvertCTensorToPyTensor(py_args, &convert_args); | |||||
| auto inst = pynative::PynativeExecutor::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(inst); | |||||
| try { | |||||
| MS_LOG(DEBUG) << "Run bprop function start"; | |||||
| inst->NewGraph(hook_, input_args.cast<py::args>()); | |||||
| py::object grads_obj = hook_(*convert_args); | |||||
| py::tuple grads = check_bprop_out(grads_obj, py_args); | |||||
| inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>()); | |||||
| MS_LOG(DEBUG) << "Run bprop function end"; | |||||
| return std::make_shared<PyObjectRef>(grads); | |||||
| } catch (std::exception &bt) { | |||||
| inst->ClearRes(); | |||||
| std::rethrow_exception(std::current_exception()); | |||||
| } | |||||
| return RunBpropHookFunction(py_args); | |||||
| } | } | ||||
| SyncData(py_args[2]); | SyncData(py_args[2]); | ||||
| bool is_cell = this->HasAttr(kCellHookAttrName); | bool is_cell = this->HasAttr(kCellHookAttrName); | ||||
| @@ -55,6 +55,7 @@ class PrimitivePy : public Primitive { | |||||
| void set_hook(const py::function &hook) { hook_ = hook; } | void set_hook(const py::function &hook) { hook_ = hook; } | ||||
| py::function hook() const { return hook_; } | py::function hook() const { return hook_; } | ||||
| BaseRef RunHookFunction(const VectorRef &args) const override; | BaseRef RunHookFunction(const VectorRef &args) const override; | ||||
| BaseRef RunBpropHookFunction(const py::tuple &py_args) const; | |||||
| BaseRef RunComputeFunction(const VectorRef &args) const override; | BaseRef RunComputeFunction(const VectorRef &args) const override; | ||||
| py::object RunPyComputeFunction(const py::tuple &py_args) const; | py::object RunPyComputeFunction(const py::tuple &py_args) const; | ||||
| bool HasComputeFunction() const; | bool HasComputeFunction() const; | ||||
| @@ -28,29 +28,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| namespace { | namespace { | ||||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | |||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual, | |||||
| "w_shape[1]", w_shape[1], prim_name); | |||||
| auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel)); | |||||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name); | |||||
| std::vector<int64_t> temp_w; | |||||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||||
| CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual, | |||||
| "w_shape[2:4]", temp_w, prim_name); | |||||
| std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape, | |||||
| const std::vector<int64_t> &x_shape, const int64_t &out_channel) { | |||||
| auto kernel_size_h = w_shape[2]; | auto kernel_size_h = w_shape[2]; | ||||
| auto kernel_size_w = w_shape[3]; | auto kernel_size_w = w_shape[3]; | ||||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | ||||
| @@ -92,13 +71,36 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| h_out = floor(h_out); | h_out = floor(h_out); | ||||
| w_out = floor(w_out); | w_out = floor(w_out); | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); | |||||
| primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name))); | |||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name()); | |||||
| primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name()))); | |||||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | ||||
| return out_shape; | |||||
| } | |||||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| out_shape = {x_shape[0], h_out, w_out, out_channel}; | |||||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | |||||
| } | |||||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual, | |||||
| "w_shape[1]", w_shape[1], prim_name); | |||||
| auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel)); | |||||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name); | |||||
| std::vector<int64_t> temp_w; | |||||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||||
| CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual, | |||||
| "w_shape[2:4]", temp_w, prim_name); | |||||
| auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel); | |||||
| if (format == NHWC) { | |||||
| out_shape = {out_shape[0], out_shape[3], out_shape[1], out_shape[2]}; | |||||
| } | } | ||||
| return std::make_shared<abstract::Shape>(out_shape); | return std::make_shared<abstract::Shape>(out_shape); | ||||
| } | } | ||||
| @@ -23,32 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto doutput = input_args[0]; | |||||
| auto x_size = input_args[2]; | |||||
| auto x_size_value = x_size->GetValueTrack(); | |||||
| MS_EXCEPTION_IF_NULL(x_size); | |||||
| auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value); | |||||
| // infer dtype | |||||
| auto dtype = doutput->BuildType(); | |||||
| if (!dtype->isa<TensorType>()) { | |||||
| MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); | |||||
| } | |||||
| auto input_tensor_type = dtype->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor_type); | |||||
| auto element = input_tensor_type->element(); | |||||
| // infer shape | |||||
| auto dout_shape = doutput->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(doutput); | |||||
| auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>(); | |||||
| auto dout_shape_norm = dout_shapeptr->shape(); | |||||
| void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm, | |||||
| const std::vector<int64_t> &x_size_v) { | |||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | ||||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | ||||
| auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | ||||
| @@ -76,6 +52,34 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co | |||||
| pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); | pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); | ||||
| } | } | ||||
| primitive->AddAttr(kPadList, MakeValue(pad_list)); | primitive->AddAttr(kPadList, MakeValue(pad_list)); | ||||
| } | |||||
| AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto prim_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); | |||||
| for (const auto &item : input_args) { | |||||
| MS_EXCEPTION_IF_NULL(item); | |||||
| } | |||||
| auto doutput = input_args[0]; | |||||
| auto x_size = input_args[2]; | |||||
| auto x_size_value = x_size->GetValueTrack(); | |||||
| MS_EXCEPTION_IF_NULL(x_size); | |||||
| auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value); | |||||
| // infer dtype | |||||
| auto dtype = doutput->BuildType(); | |||||
| if (!dtype->isa<TensorType>()) { | |||||
| MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); | |||||
| } | |||||
| auto input_tensor_type = dtype->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor_type); | |||||
| auto element = input_tensor_type->element(); | |||||
| // infer shape | |||||
| auto dout_shape = doutput->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(doutput); | |||||
| auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>(); | |||||
| auto dout_shape_norm = dout_shapeptr->shape(); | |||||
| SetPadList(primitive, dout_shape_norm, x_size_v); | |||||
| return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v)); | return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v)); | ||||
| } | } | ||||
| @@ -102,8 +102,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | ||||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | ||||
| auto pad_mode_value = (primitive->GetAttr(kPadMode)); | auto pad_mode_value = (primitive->GetAttr(kPadMode)); | ||||
| PadMode pad_mode = PAD; | |||||
| pad_mode = PadMode(GetValue<int64_t>(pad_mode_value)); | |||||
| PadMode pad_mode = PadMode(GetValue<int64_t>(pad_mode_value)); | |||||
| auto batch = in_shape[0]; | auto batch = in_shape[0]; | ||||
| auto channel = in_shape[1]; | auto channel = in_shape[1]; | ||||
| auto in_h = in_shape[2]; | auto in_h = in_shape[2]; | ||||