From: @lianliguang Reviewed-by: @zhoufeng54, @chujinjin Signed-off-by: @chujinjin pull/14442/MERGE
| @@ -18,6 +18,8 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| @@ -73,26 +75,7 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node); | |||
| auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node); | |||
| MS_EXCEPTION_IF_NULL(used_cast_node_list); | |||
| std::unordered_map<string, size_t> format_counter; | |||
| for (const auto &node_info : *used_cast_node_list) { | |||
| MS_EXCEPTION_IF_NULL(node_info.first); | |||
| auto cast_out_node = node_info.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cast_out_node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node); | |||
| for (size_t index = 0; index < input_num; ++index) { | |||
| if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first != | |||
| cast_node) { | |||
| continue; | |||
| } | |||
| auto format = AnfAlgo::GetInputFormat(cast_out_node, index); | |||
| auto it = format_counter.find(format); | |||
| if (it == format_counter.end()) { | |||
| format_counter[format] = 1; | |||
| } else { | |||
| it->second++; | |||
| } | |||
| } | |||
| } | |||
| std::unordered_map<string, size_t> format_counter = CalculateFormat(used_cast_node_list, cast_node); | |||
| auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0); | |||
| string convert_format = kOpFormat_DEFAULT; | |||
| if (cast_input_format == kOpFormat_DEFAULT) { | |||
| @@ -121,5 +104,33 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr | |||
| SetCastFormat(cast_node, convert_format); | |||
| } | |||
| } | |||
| // Tally the input formats requested by the consumers of a cast node. | |||
| // used_cast_node_list: nodes that use the cast; the first element of each entry must be a CNode. | |||
| // cast_node: the cast node whose consumers are inspected. | |||
| // Returns a map from input format to the number of consumer inputs whose producer, as resolved by | |||
| // AnfAlgo::VisitKernelWithReturnType, is cast_node. | |||
| std::unordered_map<string, size_t> ConvertCastFormat::CalculateFormat( | |||
|     const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list, | |||
|     const CNodePtr &cast_node) const { | |||
|   MS_EXCEPTION_IF_NULL(used_cast_node_list); | |||
|   MS_EXCEPTION_IF_NULL(cast_node); | |||
|   std::unordered_map<string, size_t> format_counter; | |||
|   for (const auto &user_info : *used_cast_node_list) { | |||
|     MS_EXCEPTION_IF_NULL(user_info.first); | |||
|     auto user_cnode = user_info.first->cast<CNodePtr>(); | |||
|     MS_EXCEPTION_IF_NULL(user_cnode); | |||
|     const size_t input_num = AnfAlgo::GetInputTensorNum(user_cnode); | |||
|     for (size_t index = 0; index < input_num; ++index) { | |||
|       auto real_input = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(user_cnode, index), 0).first; | |||
|       if (real_input == cast_node) { | |||
|         // operator[] value-initializes a missing counter to 0, so the first hit yields 1. | |||
|         ++format_counter[AnfAlgo::GetInputFormat(user_cnode, index)]; | |||
|       } | |||
|     } | |||
|   } | |||
|   return format_counter; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -17,6 +17,10 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| @@ -30,6 +34,9 @@ class ConvertCastFormat : public PatternProcessPass { | |||
| const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| std::unordered_map<string, size_t> CalculateFormat( | |||
| const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list, | |||
| const CNodePtr &cast_node) const; | |||
| void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const; | |||
| void SetCastFormat(const CNodePtr &cast_node, const string &format) const; | |||
| }; | |||
| @@ -143,18 +143,22 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_ | |||
| } | |||
| // insert depend | |||
| if (origin_format != cur_format || origin_type != cur_type) { | |||
| std::vector<AnfNodePtr> depend_nodes; | |||
| if (get_item.get() != nullptr) { | |||
| depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node}; | |||
| } else { | |||
| depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node}; | |||
| } | |||
| final_node = func_graph->NewCNode(depend_nodes); | |||
| final_node = MakeDependency(get_item, final_node, cnode, func_graph); | |||
| MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString(); | |||
| } | |||
| return final_node; | |||
| } | |||
| // Create a Depend CNode in func_graph with inputs {Depend primitive, source, final_node}. | |||
| // The source is get_item when it is non-null, otherwise cnode. | |||
| CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, | |||
|                                              const CNodePtr &cnode, const FuncGraphPtr &func_graph) const { | |||
|   const AnfNodePtr depend_source = (get_item != nullptr) ? get_item : cnode; | |||
|   std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), depend_source, final_node}; | |||
|   return func_graph->NewCNode(depend_inputs); | |||
| } | |||
| CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| @@ -33,6 +33,8 @@ class DealRefTransAndCast : public TransDataSplit { | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| CNodePtr MakeDependency(const CNodePtr &getitem, const CNodePtr &final_node, const CNodePtr &cnode, | |||
| const FuncGraphPtr &func_graph) const; | |||
| CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | |||
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | |||
| CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| @@ -97,7 +97,6 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{ | |||
| {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | |||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, | |||
| {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}}; | |||
| } // namespace trans | |||
| } // namespace mindspore | |||
| @@ -182,32 +182,36 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj | |||
| } | |||
| } | |||
| // Run the Python bprop hook held in hook_ and wrap the checked gradients in a PyObjectRef. | |||
| // py_args: arguments for the hook; all of them are passed to the hook (after ConvertCTensorToPyTensor), | |||
| //          while only the leading size - 2 entries are handed to NewGraph/EndGraph. | |||
| // Returns a PyObjectRef holding the tuple produced by check_bprop_out. | |||
| // If a std::exception escapes the hook, executor resources are cleared and the exception is rethrown. | |||
| BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const { | |||
|   // The last two entries of py_args are not part of the graph inputs. | |||
|   constexpr size_t kTrailingArgNum = 2; | |||
|   const size_t size = py_args.size(); | |||
|   if (size < kTrailingArgNum) { | |||
|     // Without this guard the unsigned subtraction below would wrap around to a huge tuple size. | |||
|     MS_LOG(EXCEPTION) << "Bprop hook expects at least " << kTrailingArgNum << " arguments, but got " << size; | |||
|   } | |||
|   SyncData(py_args); | |||
|   const size_t input_num = size - kTrailingArgNum; | |||
|   py::tuple input_args(input_num); | |||
|   for (size_t i = 0; i < input_num; ++i) { | |||
|     input_args[i] = py_args[i]; | |||
|   } | |||
|   py::tuple convert_args(size); | |||
|   ConvertCTensorToPyTensor(py_args, &convert_args); | |||
|   auto inst = pynative::PynativeExecutor::GetInstance(); | |||
|   MS_EXCEPTION_IF_NULL(inst); | |||
|   try { | |||
|     MS_LOG(DEBUG) << "Run bprop function start"; | |||
|     inst->NewGraph(hook_, input_args.cast<py::args>()); | |||
|     py::object grads_obj = hook_(*convert_args); | |||
|     py::tuple grads = check_bprop_out(grads_obj, py_args); | |||
|     inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>()); | |||
|     MS_LOG(DEBUG) << "Run bprop function end"; | |||
|     return std::make_shared<PyObjectRef>(grads); | |||
|   } catch (const std::exception &) { | |||
|     inst->ClearRes(); | |||
|     // Rethrow the in-flight exception unchanged. | |||
|     throw; | |||
|   } | |||
| } | |||
| BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { | |||
| py::tuple py_args = ConvertDatatoPyTuple(args); | |||
| bool is_bprop = this->HasAttr(kBpropAttrName); | |||
| if (is_bprop) { | |||
| SyncData(py_args); | |||
| auto size = py_args.size(); | |||
| py::tuple input_args(size - 2); | |||
| for (size_t i = 0; i < size - 2; ++i) { | |||
| input_args[i] = py_args[i]; | |||
| } | |||
| py::tuple convert_args(py_args.size()); | |||
| ConvertCTensorToPyTensor(py_args, &convert_args); | |||
| auto inst = pynative::PynativeExecutor::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(inst); | |||
| try { | |||
| MS_LOG(DEBUG) << "Run bprop function start"; | |||
| inst->NewGraph(hook_, input_args.cast<py::args>()); | |||
| py::object grads_obj = hook_(*convert_args); | |||
| py::tuple grads = check_bprop_out(grads_obj, py_args); | |||
| inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>()); | |||
| MS_LOG(DEBUG) << "Run bprop function end"; | |||
| return std::make_shared<PyObjectRef>(grads); | |||
| } catch (std::exception &bt) { | |||
| inst->ClearRes(); | |||
| std::rethrow_exception(std::current_exception()); | |||
| } | |||
| return RunBpropHookFunction(py_args); | |||
| } | |||
| SyncData(py_args[2]); | |||
| bool is_cell = this->HasAttr(kCellHookAttrName); | |||
| @@ -55,6 +55,7 @@ class PrimitivePy : public Primitive { | |||
| void set_hook(const py::function &hook) { hook_ = hook; } | |||
| py::function hook() const { return hook_; } | |||
| BaseRef RunHookFunction(const VectorRef &args) const override; | |||
| BaseRef RunBpropHookFunction(const py::tuple &py_args) const; | |||
| BaseRef RunComputeFunction(const VectorRef &args) const override; | |||
| py::object RunPyComputeFunction(const py::tuple &py_args) const; | |||
| bool HasComputeFunction() const; | |||
| @@ -28,29 +28,8 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
| auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||
| auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||
| if (format == NHWC) { | |||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||
| w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual, | |||
| "w_shape[1]", w_shape[1], prim_name); | |||
| auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel)); | |||
| CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name); | |||
| std::vector<int64_t> temp_w; | |||
| std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
| CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual, | |||
| "w_shape[2:4]", temp_w, prim_name); | |||
| std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape, | |||
| const std::vector<int64_t> &x_shape, const int64_t &out_channel) { | |||
| auto kernel_size_h = w_shape[2]; | |||
| auto kernel_size_w = w_shape[3]; | |||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||
| @@ -92,13 +71,36 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| h_out = floor(h_out); | |||
| w_out = floor(w_out); | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); | |||
| primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name))); | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name()); | |||
| primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name()))); | |||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | |||
| return out_shape; | |||
| } | |||
| // Infer the output shape of Conv2D from the input (input_args[0]) and weight (input_args[1]) shapes. | |||
| // For NHWC the shapes are first permuted to NCHW order, validated and passed to SetPadList, and the | |||
| // resulting NCHW output shape is permuted back to NHWC before it is returned. | |||
| abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
|   MS_EXCEPTION_IF_NULL(primitive); | |||
|   auto prim_name = primitive->name(); | |||
|   CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); | |||
|   MS_EXCEPTION_IF_NULL(input_args[0]); | |||
|   MS_EXCEPTION_IF_NULL(input_args[1]); | |||
|   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | |||
|   auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name); | |||
|   // Validate the ranks before any element is indexed; the NHWC permutation below reads index 3. | |||
|   CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); | |||
|   CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
|   auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); | |||
|   if (format == NHWC) { | |||
|     // Reorder NHWC -> NCHW so the checks and SetPadList can use a single layout. | |||
|     x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||
|     w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]}; | |||
|   } | |||
|   CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual, | |||
|                               "w_shape[1]", w_shape[1], prim_name); | |||
|   auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel)); | |||
|   CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name); | |||
|   std::vector<int64_t> temp_w; | |||
|   std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); | |||
|   CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual, | |||
|                               "w_shape[2:4]", temp_w, prim_name); | |||
|   auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel); | |||
|   if (format == NHWC) { | |||
|     // SetPadList returns {N, C, H, W}; NHWC order is {N, H, W, C}. | |||
|     out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]}; | |||
|   } | |||
|   return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| @@ -23,32 +23,8 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto doutput = input_args[0]; | |||
| auto x_size = input_args[2]; | |||
| auto x_size_value = x_size->GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(x_size); | |||
| auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value); | |||
| // infer dtype | |||
| auto dtype = doutput->BuildType(); | |||
| if (!dtype->isa<TensorType>()) { | |||
| MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); | |||
| } | |||
| auto input_tensor_type = dtype->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_tensor_type); | |||
| auto element = input_tensor_type->element(); | |||
| // infer shape | |||
| auto dout_shape = doutput->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(doutput); | |||
| auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>(); | |||
| auto dout_shape_norm = dout_shapeptr->shape(); | |||
| void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm, | |||
| const std::vector<int64_t> &x_size_v) { | |||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||
| auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||
| auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride)); | |||
| @@ -76,6 +52,34 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co | |||
| pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); | |||
| } | |||
| primitive->AddAttr(kPadList, MakeValue(pad_list)); | |||
| } | |||
| // Infer the abstract of Conv2DBackpropInput. | |||
| // input_args must hold exactly three non-null entries: input_args[0] is doutput and input_args[2] carries | |||
| // the x size as a value. The result is a tensor with doutput's element type and the shape given by x size. | |||
| // As a side effect the pad list attribute is set on the primitive through SetPadList. | |||
| AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
|                                          const std::vector<AbstractBasePtr> &input_args) { | |||
|   MS_EXCEPTION_IF_NULL(primitive); | |||
|   auto prim_name = primitive->name(); | |||
|   CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); | |||
|   for (const auto &item : input_args) { | |||
|     MS_EXCEPTION_IF_NULL(item); | |||
|   } | |||
|   auto doutput = input_args[0]; | |||
|   auto x_size = input_args[2]; | |||
|   auto x_size_value = x_size->GetValueTrack(); | |||
|   // x_size itself was validated in the loop above; the value it tracks is what may still be null. | |||
|   MS_EXCEPTION_IF_NULL(x_size_value); | |||
|   auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value); | |||
|   // infer dtype | |||
|   auto dtype = doutput->BuildType(); | |||
|   MS_EXCEPTION_IF_NULL(dtype); | |||
|   if (!dtype->isa<TensorType>()) { | |||
|     MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString(); | |||
|   } | |||
|   auto input_tensor_type = dtype->cast<TensorTypePtr>(); | |||
|   MS_EXCEPTION_IF_NULL(input_tensor_type); | |||
|   auto element = input_tensor_type->element(); | |||
|   // infer shape | |||
|   auto dout_shape = doutput->BuildShape(); | |||
|   MS_EXCEPTION_IF_NULL(dout_shape); | |||
|   auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>(); | |||
|   // The cast result is dereferenced on the next line, so reject a null pointer here. | |||
|   MS_EXCEPTION_IF_NULL(dout_shapeptr); | |||
|   auto dout_shape_norm = dout_shapeptr->shape(); | |||
|   SetPadList(primitive, dout_shape_norm, x_size_v); | |||
|   return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v)); | |||
| } | |||
| @@ -102,8 +102,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | |||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||
| auto pad_mode_value = (primitive->GetAttr(kPadMode)); | |||
| PadMode pad_mode = PAD; | |||
| pad_mode = PadMode(GetValue<int64_t>(pad_mode_value)); | |||
| PadMode pad_mode = PadMode(GetValue<int64_t>(pad_mode_value)); | |||
| auto batch = in_shape[0]; | |||
| auto channel = in_shape[1]; | |||
| auto in_h = in_shape[2]; | |||