@@ -321,9 +321,11 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
   MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
   for (size_t i = 0; i < inputs_static.size(); i++) {
     inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
+    inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
   }
   for (size_t j = 0; j < outputs_static.size(); j++) {
     outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
+    outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
   }
   op_info_new_ptr->set_inputs_ptr(inputs_dyn);
   op_info_new_ptr->set_outputs_ptr(outputs_dyn);
@@ -335,6 +337,29 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
   op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
 }
+
+bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
+  for (const auto &c : reshape_type_str) {
+    switch (c) {
+      case 'N':
+        reshape_type_vec->push_back(kernel::N);
+        break;
+      case 'C':
+        reshape_type_vec->push_back(kernel::C);
+        break;
+      case 'H':
+        reshape_type_vec->push_back(kernel::H);
+        break;
+      case 'W':
+        reshape_type_vec->push_back(kernel::W);
+        break;
+      default:
+        MS_LOG(ERROR) << "Unknown axis " << c << " in reshape type.";
+        return false;
+    }
+  }
+  return true;
+}
 
 bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                                size_t builder_idex, const std::vector<int> &dyn_input_sizes,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
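For clarity, the mapping the new `StringToAxisVector` helper performs can be exercised standalone: each character of a registered reshape_type string names one axis, so "NC" parses to {N, C}. Below is a minimal sketch with a stand-in `Axis` enum (the real `kernel::Axis` and `MS_LOG` come from MindSpore's kernel headers; this is illustrative only, not the shipped code):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stand-in for kernel::Axis; the real enum is defined in MindSpore's headers.
enum Axis { N, C, H, W };

// Mirrors the new StringToAxisVector: one axis per character, reject anything
// outside N/C/H/W.
bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *out) {
  for (const auto &c : reshape_type_str) {
    switch (c) {
      case 'N': out->push_back(N); break;
      case 'C': out->push_back(C); break;
      case 'H': out->push_back(H); break;
      case 'W': out->push_back(W); break;
      default: return false;  // unknown axis character
    }
  }
  return true;
}

int main() {
  std::vector<Axis> axes;
  // "NC" is the string the BNTraining* registrations below attach to their
  // 2-D tensors; it parses to {N, C}.
  if (StringToAxisVector("NC", &axes)) {
    std::cout << "parsed " << axes.size() << " axes\n";  // prints: parsed 2 axes
  }
  return 0;
}
```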
@@ -347,6 +372,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
   MS_EXCEPTION_IF_NULL(inputs[0]);
   size_t kernel_info_cnt = inputs[0]->dtypes().size();
+  std::vector<std::vector<Axis>> reshape_types;
   for (const auto &input : inputs) {
     MS_EXCEPTION_IF_NULL(input);
     std::string param_type = input->param_type();
@@ -384,8 +410,14 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
         inputs_format.push_back(formats[builder_idex]);
       }
     }
+    std::vector<Axis> reshape_type;
+    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
+      return false;
+    }
+    reshape_types.push_back(reshape_type);
   }
+  builder->SetInputReshapeType(reshape_types);
   builder->SetInputsDeviceType(inputs_device_type);
   builder->SetInputsFormat(inputs_format);
   return true;
@@ -403,6 +435,7 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
   MS_EXCEPTION_IF_NULL(outputs[0]);
   size_t kernel_info_cnt = outputs[0]->dtypes().size();
+  std::vector<std::vector<Axis>> reshape_types;
   for (const auto &output : outputs) {
     MS_EXCEPTION_IF_NULL(output);
     if (output_idx >= real_output_num) {
@@ -436,8 +469,14 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
       outputs_format.push_back(formats[builder_idex]);
       output_idx++;
     }
+    std::vector<Axis> reshape_type;
+    if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
+      return false;
+    }
+    reshape_types.push_back(reshape_type);
   }
+  builder->SetOutputReshapeType(reshape_types);
   builder->SetOutputsFormat(outputs_format);
   builder->SetOutputsDeviceType(outputs_device_type);
   return true;
@@ -515,7 +554,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                                kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
                                                kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
-                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
+                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
   // if format is default, it means all formats are supported
   if (kOpFormatList.find(format) == kOpFormatList.end()) {
@@ -528,13 +567,13 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   if (shape.empty()) {
     return true;
   }
-  if (shape.size() > kShapeSupportFormatMap.size()) {
+  if (shape.size() > kShape4dDims) {
     return false;
   }
-  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
-    return true;
+  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
+    return false;
   }
-  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
+  return true;
 }
 
 bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
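The rewritten tail of `IsShapeMatchFormat` now reduces to three rules: an empty (scalar) shape matches any supported format, ranks above four never match, and FRACTAL_NZ additionally needs at least two dimensions. A condensed sketch of just that tail, assuming `kShape4dDims` is 4 and omitting the whitelist/default-format checks that precede it:

```cpp
#include <cstddef>
#include <string>
#include <vector>

constexpr size_t kShape4dDims = 4;  // assumed value of the MindSpore constant
const char kOpFormat_FRAC_NZ[] = "FRACTAL_NZ";

// Condensed tail of the new IsShapeMatchFormat.
bool ShapeMatches(const std::vector<size_t> &shape, const std::string &format) {
  if (shape.empty()) {
    return true;  // scalars match every format
  }
  if (shape.size() > kShape4dDims) {
    return false;  // more than 4-D is never supported
  }
  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
    return false;  // FRACTAL_NZ needs at least a 2-D shape
  }
  return true;
}
```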
@@ -55,12 +55,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   trans_inputs.push_back(input);
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
+  std::vector<kernel::Axis> padding_axis;
+  if (AnfAlgo::IsRealKernel(input)) {
+    padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
+  } else {
+    padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
+  }
   if (need_padding) {
     // if need padding we should set the transdata node's shape to the padding shape
-    AnfAlgo::SetOutputInferTypeAndShape(
-      {AnfAlgo::GetOutputInferDataType(input, 0)},
-      {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
-      trans_node.get());
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
+                                        {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
+                                        trans_node.get());
   } else {
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
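The `padding_axis` threaded into `trans::PaddingShapeTo4d` tells the TransData insertion which of N/C/H/W each existing dimension stands for when a sub-4-D shape is promoted. A sketch of that idea, under the assumption that the axis list pairs one-to-one with the input dimensions (the real `PaddingShapeTo4d` also covers empty axis lists and other fallbacks):

```cpp
#include <vector>

enum Axis { N = 0, C = 1, H = 2, W = 3 };  // stand-in for kernel::Axis

// Place each input dimension at the 4-D slot its axis names; unnamed slots
// pad to 1. E.g. shape {32, 64} with axes {N, C} -> {32, 64, 1, 1}, which is
// why the BNTraining* registrations below tag their 2-D tensors with "NC".
std::vector<size_t> PaddingShapeTo4dSketch(const std::vector<size_t> &shape,
                                           const std::vector<Axis> &padding_axis) {
  std::vector<size_t> padded(4, 1);
  if (padding_axis.size() != shape.size()) {
    return padded;  // the shipped code applies a default layout in this case
  }
  for (size_t i = 0; i < shape.size(); ++i) {
    padded[padding_axis[i]] = shape[i];
  }
  return padded;
}
```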
@@ -194,8 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
   }
-  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
-                       op_name == kTransDataOpName);
+  bool need_padding = false;
+  if (is_insert_input) {
+    need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                    op_name == kTransDataOpName);
+  } else {
+    need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                    op_name == kTransDataOpName);
+  }
   if (!need_padding) {
     // don't need padding, insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
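The fix makes the padding decision direction-aware: a TransData inserted on an input edge converts the data into `dest_format`, while one inserted on an output edge converts it out of `origin_format`, so that side's layout decides whether 4-D padding applies. Reduced to a hypothetical helper, for illustration only:

```cpp
#include <string>

// Hypothetical condensation of the new need_padding branch: the format whose
// layout governs padding depends on where the TransData node sits.
std::string PaddingDecidingFormat(bool is_insert_input, const std::string &origin_format,
                                  const std::string &dest_format) {
  // Input edge: data is converted to dest_format before the kernel consumes it.
  // Output edge: data leaves the kernel in origin_format.
  return is_insert_input ? dest_format : origin_format;
}
```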
@@ -86,7 +86,6 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
   (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
 }
-
 }  // namespace
 
 const BaseRef BatchNormGradSplit::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
@@ -344,7 +344,7 @@ bool IsNopNode(const AnfNodePtr &node) {
   return true;
 }
 
-bool IsAllNopNode(session::KernelGraph *const graph) {
+bool IsAllNopNode(const session::KernelGraph *const graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto execution_order = graph->execution_order();
   for (auto &cnode : execution_order) {
@@ -347,6 +347,11 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
   return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
 }
 
+std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
+  return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
+}
+
 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
   abstract::BaseShapePtr base_shape = node->Shape();
@@ -95,6 +95,8 @@ class AnfRuntimeAlgorithm {
   static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
   // get output format from prev node, input_index is the input index of current node related to prev node
   static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
+  // get the reshape_type of the output of the previous node
+  static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
   // get output shapes inferred by ME from input nodes.
   static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
   // get input shapes inferred by ME from input nodes.
@@ -204,6 +204,7 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ";
 constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
 constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
 constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
+constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
 const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
                                                 kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
                                                 kOpFormat_C1HWNCoC0};
@@ -225,8 +226,9 @@ const std::set<std::string> kOptOperatorSet = {
   kApplyRMSPropOpName,
 };
 
-const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
+                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
+                                                   kOpFormat_FRACTAL_Z_C04};
 
 static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
   if (access(file_name.c_str(), F_OK) != 0) {
@@ -23,7 +23,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
     .compute_cost(10) \
     .kernel_name("bn_training_reduce") \
     .partial_flag(True) \
-    .input(0, "x", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NC") \
     .output(0, "sum", False, "required", "all") \
     .output(1, "square_sum", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
@@ -24,14 +24,14 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
     .kernel_name("bn_training_reduce_grad") \
     .partial_flag(True) \
     .attr("epsilon", "optional", "float", "all") \
-    .input(0, "grads", False, "required", "all") \
-    .input(1, "x_norm", False, "required", "all") \
+    .input(0, "grads", False, "required", "all", reshape_type="NC") \
+    .input(1, "x_norm", False, "required", "all", reshape_type="NC") \
     .input(2, "diff_scale", False, "required", "all") \
     .input(3, "diff_offset", False, "required", "all") \
     .input(4, "scale", False, "required", "all") \
     .input(5, "batch_mean", False, "required", "all") \
     .input(6, "batch_variance", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .output(0, "y", False, "required", "all", reshape_type="NC") \
     .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
@@ -26,14 +26,14 @@ bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
     .attr("factor", "optional", "float", "all") \
     .attr("epsilon", "optional", "float", "all") \
     .attr("isRef", "optional", "bool", "all", "true") \
-    .input(0, "x", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NC") \
     .input(1, "sum", False, "required", "all") \
     .input(2, "square_sum", False, "required", "all") \
     .input(3, "scale", False, "required", "all") \
    .input(4, "offset", False, "required", "all") \
     .input(5, "mean", False, "required", "all") \
     .input(6, "variance", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .output(0, "y", False, "required", "all", reshape_type="NC") \
     .output(1, "mean", False, "required", "all") \
     .output(2, "variance", False, "required", "all") \
     .output(3, "batch_mean", False, "required", "all") \
@@ -24,8 +24,8 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
     .kernel_name("bn_training_update_grad") \
     .partial_flag(True) \
     .attr("epsilon", "optional", "float", "all") \
-    .input(0, "grads", False, "required", "all") \
-    .input(1, "x", False, "required", "all") \
+    .input(0, "grads", False, "required", "all", reshape_type="NC") \
+    .input(1, "x", False, "required", "all", reshape_type="NC") \
     .input(2, "batch_mean", False, "required", "all") \
     .input(3, "batch_variance", False, "required", "all") \
     .output(0, "diff_scale", False, "required", "all") \