| @@ -26,6 +26,7 @@ | |||||
| #include "backend/kernel_compiler/common_utils.h" | #include "backend/kernel_compiler/common_utils.h" | ||||
| #include "backend/kernel_compiler/kernel_build_info.h" | #include "backend/kernel_compiler/kernel_build_info.h" | ||||
| #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" | ||||
| #include "backend/optimizer/graph_kernel/split_umonad.h" | |||||
| #include "backend/optimizer/graph_kernel/substitute_dropout.h" | #include "backend/optimizer/graph_kernel/substitute_dropout.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "mindspore/core/ir/graph_utils.h" | #include "mindspore/core/ir/graph_utils.h" | ||||
| @@ -37,10 +38,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| constexpr size_t kAssignInputIdx = 1; | |||||
| constexpr size_t kLambInputIdx = 12; | |||||
| std::vector<PrimitivePtr> GetExpandOps() { | std::vector<PrimitivePtr> GetExpandOps() { | ||||
| std::vector<PrimitivePtr> expand_ops = { | std::vector<PrimitivePtr> expand_ops = { | ||||
| prim::kPrimSquare, | prim::kPrimSquare, | ||||
| prim::kPrimGeLUGrad, | prim::kPrimGeLUGrad, | ||||
| prim::kPrimAssignAdd, | |||||
| #if ENABLE_D | #if ENABLE_D | ||||
| prim::kPrimTile, | prim::kPrimTile, | ||||
| prim::kPrimSqrtGrad, | prim::kPrimSqrtGrad, | ||||
| @@ -69,7 +74,6 @@ std::vector<PrimitivePtr> GetExpandOps() { | |||||
| prim::kPrimSigmoidCrossEntropyWithLogits, | prim::kPrimSigmoidCrossEntropyWithLogits, | ||||
| prim::kPrimSigmoidCrossEntropyWithLogitsGrad, | prim::kPrimSigmoidCrossEntropyWithLogitsGrad, | ||||
| prim::kPrimSoftmaxCrossEntropyWithLogits, | prim::kPrimSoftmaxCrossEntropyWithLogits, | ||||
| prim::kPrimAssignAdd, | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| const auto &flags = context::GraphKernelFlags::GetInstance(); | const auto &flags = context::GraphKernelFlags::GetInstance(); | ||||
| @@ -167,6 +171,22 @@ AnfNodePtr DefaultExpander::Run(const AnfNodePtr &node) { | |||||
| return graph_kernel_node; | return graph_kernel_node; | ||||
| } | } | ||||
| ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) { | |||||
| std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = { | |||||
| {prim::kPrimDropout, std::make_shared<DropoutExpander>()}, | |||||
| {prim::kPrimAssignAdd, std::make_shared<OpUMonadExpander>(kAssignInputIdx)}, | |||||
| {prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)}, | |||||
| {prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambInputIdx)}, | |||||
| }; | |||||
| for (auto &e : expanders) { | |||||
| if (IsPrimitiveCNode(node, e.first)) { | |||||
| return e.second; | |||||
| } | |||||
| } | |||||
| return std::make_shared<DefaultExpander>(); | |||||
| } | |||||
| bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | ||||
| bool changed = false; | bool changed = false; | ||||
| auto todos = TopoSort(func_graph->get_return()); | auto todos = TopoSort(func_graph->get_return()); | ||||
| @@ -192,18 +212,6 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) { | |||||
| return changed; | return changed; | ||||
| } | } | ||||
| ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) { | |||||
| std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = { | |||||
| {prim::kPrimDropout, std::make_shared<DropoutExpander>()}, | |||||
| }; | |||||
| for (auto &e : expanders) { | |||||
| if (IsPrimitiveCNode(node, e.first)) { | |||||
| return e.second; | |||||
| } | |||||
| } | |||||
| return std::make_shared<DefaultExpander>(); | |||||
| } | |||||
| bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { | ||||
| expand_ops_ = GetExpandOps(); | expand_ops_ = GetExpandOps(); | ||||
| return DoExpand(func_graph); | return DoExpand(func_graph); | ||||
| @@ -37,7 +37,7 @@ | |||||
| #include "backend/optimizer/graph_kernel/value_graph_binder.h" | #include "backend/optimizer/graph_kernel/value_graph_binder.h" | ||||
| #include "backend/optimizer/graph_kernel/parallel_fusion.h" | #include "backend/optimizer/graph_kernel/parallel_fusion.h" | ||||
| #include "backend/optimizer/graph_kernel/optimize_assign.h" | #include "backend/optimizer/graph_kernel/optimize_assign.h" | ||||
| #include "backend/optimizer/graph_kernel/split_assign.h" | |||||
| #include "backend/optimizer/graph_kernel/split_umonad.h" | |||||
| #include "backend/optimizer/graph_kernel/reorder_ops.h" | #include "backend/optimizer/graph_kernel/reorder_ops.h" | ||||
| #include "backend/optimizer/graph_kernel/update_state_formatter.h" | #include "backend/optimizer/graph_kernel/update_state_formatter.h" | ||||
| #include "backend/optimizer/graph_kernel/axis_normalizer.h" | #include "backend/optimizer/graph_kernel/axis_normalizer.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "backend/optimizer/graph_kernel/split_assign.h" | |||||
| #include "backend/optimizer/graph_kernel/split_umonad.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| @@ -35,31 +35,63 @@ const BaseRef SplitAssign::DefinePattern() const { | |||||
| return VectorRef({v, Xs, Us, UMonad}); | return VectorRef({v, Xs, Us, UMonad}); | ||||
| } | } | ||||
| bool CanSplit(const AnfNodePtr &node) { | |||||
| return IsPrimitiveCNode(node, prim::kPrimAssignAdd) || IsPrimitiveCNode(node, prim::kPrimAssign) || | |||||
| IsPrimitiveCNode(node, prim::kPrimAssignSub); | |||||
| } | |||||
// Only plain Assign nodes are split by this pass; AssignAdd/AssignSub are now
// handled by OpUMonadExpander in the graph kernel expander instead.
bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }
| const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||||
| AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!CanSplit(node)) return node; | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | CNodePtr cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| CheckCNodeInputSize(cnode, kAssignInputTensorNum); | |||||
| // Get original assign op's abstract and inputs | |||||
| // Get original op's abstract and inputs | |||||
| AbstractBasePtr original_abstract = cnode->abstract()->Clone(); | AbstractBasePtr original_abstract = cnode->abstract()->Clone(); | ||||
| auto original_inputs = cnode->inputs(); | auto original_inputs = cnode->inputs(); | ||||
| int input_node_size = cnode->size() - 1; | |||||
| // Create depend node | // Create depend node | ||||
| AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]}; | |||||
| AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx], | |||||
| original_inputs[input_node_size]}; | |||||
| auto depend_cnode = func_graph->NewCNode(depend_inputs); | auto depend_cnode = func_graph->NewCNode(depend_inputs); | ||||
| depend_cnode->set_abstract(original_inputs[1]->abstract()); | |||||
| depend_cnode->set_abstract(original_inputs[input_idx]->abstract()); | |||||
| depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>()); | depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>()); | ||||
| // Create new assign node, delete U from inputs. | |||||
| AnfNodePtrList new_assign_inputs = {cnode->input(0), depend_cnode, original_inputs[2]}; | |||||
| auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs); | |||||
| new_assign_cnode->set_abstract(original_abstract); | |||||
| new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr()); | |||||
| return new_assign_cnode; | |||||
| // Create new node, delete U from inputs. | |||||
| AnfNodePtrList new_inputs = {cnode->input(0)}; | |||||
| for (int i = 1; i < input_node_size; i++) { | |||||
| if (i == input_idx) { | |||||
| new_inputs.push_back(depend_cnode); | |||||
| } else { | |||||
| new_inputs.push_back(cnode->input(i)); | |||||
| } | |||||
| } | |||||
| auto new_cnode = func_graph->NewCNode(new_inputs); | |||||
| new_cnode->set_abstract(original_abstract); | |||||
| new_cnode->set_kernel_info(cnode->kernel_info_ptr()); | |||||
| return new_cnode; | |||||
| } | |||||
| const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!CanSplit(node)) return node; | |||||
| return ProcessNode(node->func_graph(), node, 1); | |||||
| } | |||||
| AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| bool has_umonad = false; | |||||
| for (unsigned int i = 1; i < cnode->size(); i++) { | |||||
| if (HasAbstractUMonad(cnode->input(i))) { | |||||
| has_umonad = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_umonad) { | |||||
| auto new_node = ProcessNode(node->func_graph(), node, input_idx_); | |||||
| return DefaultExpander::Run(new_node); | |||||
| } | |||||
| return DefaultExpander::Run(node); | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,11 +13,11 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_ | |||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class SplitAssign : public PatternProcessPass { | class SplitAssign : public PatternProcessPass { | ||||
| @@ -27,6 +27,16 @@ class SplitAssign : public PatternProcessPass { | |||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| }; | }; | ||||
// Expander for ops that carry a trailing UMonad input (e.g. AssignAdd,
// AssignSub, LambApplyOptimizerAssign): strips the UMonad through a Depend
// edge before running the default graph kernel expansion.
class OpUMonadExpander : public DefaultExpander {
 public:
  // input_idx: index of the op input that the Depend edge is attached to.
  explicit OpUMonadExpander(int input_idx) : input_idx_(input_idx) {}
  ~OpUMonadExpander() = default;
  AnfNodePtr Run(const AnfNodePtr &node) override;

 private:
  int input_idx_;
};
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_UMONAD_H_ | |||||
| @@ -219,7 +219,9 @@ bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, co | |||||
| auto mng = func_graph->manager(); | auto mng = func_graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| for (auto user : mng->node_users()[getitems_[index]]) { | for (auto user : mng->node_users()[getitems_[index]]) { | ||||
| user.first->cast<CNodePtr>()->set_input(user.second, new_node); | |||||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||||
| user.first->cast<CNodePtr>()->set_input(user.second, new_node); | |||||
| } | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -32,26 +32,38 @@ class AssignAdd(nn.Cell): | |||||
| self.add(self.var, y) | self.add(self.var, y) | ||||
| return self.var | return self.var | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_assign_add(): | |||||
| x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)) | |||||
| y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)) | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| enable_graph_kernel=True, device_target="GPU") | |||||
def get_output(x2, y2, enable_graph_kernel=False):
    """Run two chained AssignAdd steps and return both intermediate results."""
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    first_add = AssignAdd(x2)
    result_1 = first_add(y2)
    second_add = AssignAdd(result_1)
    result_2 = second_add(y2)
    return [result_1, result_2]
def assign_add():
    """Check AssignAdd gives the same results with graph kernel on and off."""
    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
    expect_1, expect_2 = get_output(x2, y2, False)
    actual_1, actual_2 = get_output(x2, y2, True)
    assert np.allclose(actual_1.asnumpy(), expect_1.asnumpy())
    assert np.allclose(actual_2.asnumpy(), expect_2.asnumpy())
| context.set_context(mode=context.GRAPH_MODE, | |||||
| enable_graph_kernel=False, device_target="GPU") | |||||
| add_beta = AssignAdd(x2) | |||||
| result_gk_off_1 = add_beta(y2) | |||||
| add_beta_2 = AssignAdd(result_gk_off_1) | |||||
| result_gk_off_2 = add_beta_2(y2) | |||||
| assert (result_gk_on_1.asnumpy() == result_gk_off_1.asnumpy()).all() | |||||
| assert (result_gk_on_2.asnumpy() == result_gk_off_2.asnumpy()).all() | |||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_assign_add_gpu():
    """AssignAdd on GPU: graph-kernel result must match the non-graph-kernel result."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    assign_add()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_add_ascend():
    """AssignAdd on Ascend: graph-kernel result must match the non-graph-kernel result."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    assign_add()
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -67,6 +68,10 @@ def lamb_apply_optimizer_assign(): | |||||
| assert np.allclose(o2.asnumpy(), e2.asnumpy()) | assert np.allclose(o2.asnumpy(), e2.asnumpy()) | ||||
| assert np.allclose(o3.asnumpy(), e3.asnumpy()) | assert np.allclose(o3.asnumpy(), e3.asnumpy()) | ||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lamb_apply_optimizer_assign_ascend():
    """LambApplyOptimizerAssign on Ascend: graph-kernel result must match baseline."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    lamb_apply_optimizer_assign()