Browse Source

hybrid inplace assign

feature/build-system-rewrite
Yang Jiao 4 years ago
parent
commit
b35cedeb62
7 changed files with 297 additions and 13 deletions
  1. +1
    -1
      akg
  2. +3
    -1
      mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc
  3. +135
    -0
      mindspore/ccsrc/backend/common/pass/insert_assign_for_custom_op.cc
  4. +36
    -0
      mindspore/ccsrc/backend/common/pass/insert_assign_for_custom_op.h
  5. +1
    -0
      mindspore/ccsrc/utils/utils.h
  6. +17
    -0
      mindspore/python/mindspore/ops/operations/custom_ops.py
  7. +104
    -11
      tests/st/ops/graph_kernel/custom/test_custom_akg.py

+ 1
- 1
akg

@@ -1 +1 @@
Subproject commit 42f537f5a0163ddbd57d3ec2173d1a51eed8f7a0
Subproject commit a9cbf642063fb1086a93e8bc6be6feb145689817

+ 3
- 1
mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@
#include "backend/common/pass/add_dynamic_shape_attr.h"
#include "backend/common/pass/add_akg_kernel_attrs.h"
#include "backend/common/pass/sparse_process.h"
#include "backend/common/pass/insert_assign_for_custom_op.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"

@@ -118,6 +119,7 @@ void CommonUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph
auto pm = std::make_shared<PassManager>("common_unify_mindir_pm");
pm->AddPass(std::make_shared<ConvTransposeToConvBackpropInputPass>());
pm->AddPass(std::make_shared<CustomOpRegInfoToAttr>());
pm->AddPass(std::make_shared<InsertAssignForCustomOp>());
opt->AddPassManager(pm);
(void)opt->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();


+ 135
- 0
mindspore/ccsrc/backend/common/pass/insert_assign_for_custom_op.cc View File

@@ -0,0 +1,135 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/pass/insert_assign_for_custom_op.h"

#include <memory>
#include <vector>
#include <string>
#include <regex>
#include "backend/common/optimizer/helper.h"
#include "backend/common/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
constexpr auto kCustomOutput = 0;
constexpr auto kCustomInput = 1;
constexpr auto kCustomAttrInplaceAssignOutput = "inplace_assign_output";

// Parse the Custom op's "inplace_assign_output" attr into (output_index,
// input_index) pairs. The attr is a space-separated flat integer list, e.g.
// "0 0 1 -1" means output 0 reuses input 0 and output 1 reuses no input.
// Returns an empty vector when the node is not a hybrid-type Custom op, the
// attr is absent, or the attr string is malformed.
std::vector<std::vector<int64_t>> GetHybridInplaceIndex(const CNodePtr &cnode) {
  if (AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType) != kCustomTypeHybrid) {
    return {};
  }

  if (!AnfAlgo::HasNodeAttr(kCustomAttrInplaceAssignOutput, cnode)) {
    return {};
  }
  auto inplace_index_str = AnfAlgo::GetNodeAttr<std::string>(cnode, kCustomAttrInplaceAssignOutput);
  std::regex delimiters(" ");
  std::vector<std::string> index(
    std::sregex_token_iterator(inplace_index_str.begin(), inplace_index_str.end(), delimiters, -1),
    std::sregex_token_iterator());
  // Tokens must form complete (output, input) pairs; an odd count means the
  // attr is malformed. Bail out instead of silently dropping the last token.
  if ((index.size() & 1) != 0) {
    return {};
  }
  // Accept only well-formed (optionally signed) decimal integers so that a
  // malformed attr degrades to "no inplace info" instead of std::stol throwing.
  auto is_integer = [](const std::string &s) {
    size_t pos = (!s.empty() && (s[0] == '-' || s[0] == '+')) ? 1 : 0;
    if (pos >= s.size()) {
      return false;
    }
    for (; pos < s.size(); ++pos) {
      if (s[pos] < '0' || s[pos] > '9') {
        return false;
      }
    }
    return true;
  };
  std::vector<std::vector<int64_t>> inplace_index;
  for (size_t i = 0; i + 1 < index.size(); i += 2) {
    if (!is_integer(index[i]) || !is_integer(index[i + 1])) {
      return {};
    }
    (void)inplace_index.emplace_back(std::vector<int64_t>{std::stol(index[i]), std::stol(index[i + 1])});
  }
  return inplace_index;
}

// Build the effect chain that writes `src` into `dst` under monad ordering:
//   UMonad -> Assign(dst, src, UMonad) -> UpdateState -> Load(dst)
// and return the trailing Load node, which exposes dst's refreshed value.
CNodePtr InsertAssign(const FuncGraphPtr &func_graph, const AnfNodePtr &src, const CNodePtr &dst) {
  // A UMonad value node anchors the ordering of the inserted side effect.
  auto monad = NewValueNode(kUMonad);
  monad->set_abstract(kUMonad->ToAbstract());

  // Assign(dst, src, monad): write src's value into dst.
  auto assign_node = func_graph->NewCNode({NewValueNode(prim::kPrimAssign), dst, src, monad});
  assign_node->set_abstract(dst->abstract());

  // UpdateState(monad, assign): record that the assign has taken effect.
  auto update_state_node = func_graph->NewCNode({NewValueNode(prim::kPrimUpdateState), monad, assign_node});
  update_state_node->set_abstract(kUMonad->ToAbstract());

  // Load(dst, update_state): read dst after the assign.
  auto load_node = func_graph->NewCNode({NewValueNode(prim::kPrimLoad), dst, update_state_node});
  load_node->set_abstract(dst->abstract());

  return load_node;
}

// For a single-output hybrid Custom op whose output inplace-reuses one of its
// inputs, insert the Assign/UpdateState/Load chain after the op. Returns the
// new Load node, or nullptr when no insertion is needed.
CNodePtr InsertAssignAfterCustom(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto inplace_info = GetHybridInplaceIndex(cnode);
  // Only the single-output case is handled here; multi-output Custom ops are
  // handled per-output through their TupleGetItem users.
  if (inplace_info.size() != 1) {
    return nullptr;
  }
  // The input index is -1 when the output does not reuse any input. Guard the
  // negative case before the unsigned LongToSize conversion, consistent with
  // InsertAssignAfterTupleGetItem.
  if (inplace_info[0][kCustomInput] < 0) {
    return nullptr;
  }
  auto input_size = AnfAlgo::GetInputTensorNum(cnode);
  if (auto i = LongToSize(inplace_info[0][kCustomInput]); i < input_size) {
    // cnode->input(0) is the primitive, so real inputs start at offset 1.
    return InsertAssign(func_graph, cnode->input(i + 1), cnode);
  }
  return nullptr;
}

// For a TupleGetItem extracting output `gt_idx` of a hybrid Custom op, insert
// the Assign/UpdateState/Load chain when that output inplace-reuses one of the
// Custom op's inputs. Returns the new Load node, or nullptr if not applicable.
CNodePtr InsertAssignAfterTupleGetItem(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto input_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(input_node);
  auto real_input = dyn_cast<CNode>(input_node);
  if (real_input == nullptr) {
    return nullptr;
  }
  // The second input of TupleGetItem is the constant output index.
  auto value_ptr = GetValueNode(cnode->input(kInputNodeOutputIndexInTupleGetItem));
  MS_EXCEPTION_IF_NULL(value_ptr);
  auto gt_idx = GetValue<int64_t>(value_ptr);
  if (IsPrimitiveCNode(real_input, prim::kPrimCustom)) {
    auto inplace_info = GetHybridInplaceIndex(real_input);
    // Iterate by const reference to avoid copying each {output, input} vector.
    for (const auto &index : inplace_info) {
      // index[kCustomInput] is -1 when the output does not reuse any input.
      if (index[kCustomOutput] == gt_idx && index[kCustomInput] >= 0) {
        auto custom_input_size = AnfAlgo::GetInputTensorNum(real_input);
        if (auto i = LongToSize(index[kCustomInput]); i < custom_input_size) {
          // real_input->input(0) is the primitive; real inputs start at 1.
          return InsertAssign(func_graph, real_input->input(i + 1), cnode);
        }
      }
    }
  }
  return nullptr;
}

// Pass entry point: for each not-yet-visited Custom or TupleGetItem node,
// insert the inplace-assign pattern and return the replacement node (the new
// Load), or nullptr when the node is left untouched.
const AnfNodePtr InsertAssignForCustomOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return nullptr;
  }

  const bool is_custom = IsPrimitiveCNode(cnode, prim::kPrimCustom);
  const bool is_getitem = IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem);
  // Other node kinds are ignored; already-visited nodes are skipped so the
  // same Assign pattern is never inserted twice.
  if ((!is_custom && !is_getitem) || visited_.find(cnode) != visited_.end()) {
    return nullptr;
  }
  visited_.insert(cnode);
  return is_custom ? InsertAssignAfterCustom(func_graph, cnode) : InsertAssignAfterTupleGetItem(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore

+ 36
- 0
mindspore/ccsrc/backend/common/pass/insert_assign_for_custom_op.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INSERT_ASSIGN_FOR_CUSTOM_OP_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INSERT_ASSIGN_FOR_CUSTOM_OP_H_
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"

namespace mindspore {
namespace opt {
// Backend pass that materializes the inplace-assign semantics of hybrid-type
// Custom ops: when a Custom output is declared (via the "inplace_assign_output"
// attr) to reuse an input tensor, the pass inserts an Assign/UpdateState/Load
// chain after the op (or after the TupleGetItem extracting that output).
class InsertAssignForCustomOp : public PatternProcessPass {
public:
explicit InsertAssignForCustomOp(bool multigraph = true)
: PatternProcessPass("insert_assign_for_custom_op", multigraph) {}
~InsertAssignForCustomOp() override = default;
// Handles one candidate node; returns the inserted Load node or nullptr.
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
// Nodes already handled by Process; mutable because Process() is const.
mutable mindspore::HashSet<CNodePtr> visited_{};
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INSERT_ASSIGN_FOR_CUSTOM_OP_H_

+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -531,6 +531,7 @@ constexpr auto kCustomTypeJULIA = "julia";
constexpr auto kCustomTypePyfunc = "pyfunc";
constexpr auto kCustomTypeTbe = "tbe";
constexpr auto kCustomTypeAICPU = "aicpu";
constexpr auto kCustomTypeHybrid = "hybrid";
const std::set<std::string> kCustomTypeAkg = {"ir_builder", "tvm_compute", "hybrid"};

// primal attr key name


+ 17
- 0
mindspore/python/mindspore/ops/operations/custom_ops.py View File

@@ -17,6 +17,7 @@
import os
import inspect
import json
import re
import hashlib
from mindspore import ops
from mindspore import log as logger
@@ -321,6 +322,7 @@ class Custom(ops.PrimitiveWithInfer):
self.func_type = "tvm_compute"
else:
self.func_type = "hybrid"
self._hybrid_func_analyser()
if not self.bprop:
self._hybrid_autodiff()
self.add_prim_attr("func_type", self.func_type)
@@ -653,3 +655,18 @@ class Custom(ops.PrimitiveWithInfer):
op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
func_type="akg", bprop=True)
self.bprop = grad_func(op)

def _hybrid_func_analyser(self):
"""analyze hybrid source string and add corresponding attrs."""
args = {val: idx for idx, val in enumerate(list(inspect.signature(self.func).parameters))}
if self.func_source_str.count('return') != 1:
logger.warning("Hybrid function code should have only one 'return' syntax.")
else:
sentences = [s for s in self.func_source_str.split('\n') if s.count("return") == 1]
symbols = re.sub(r"return|\s|\[|\]|\(|\)", "", sentences[-1]).split(',')
inplace_assign_output = [[idx, args[val]] if val in args else [idx, -1]
for idx, val in enumerate(symbols)]

if any(i[1] != -1 for i in inplace_assign_output):
self.add_prim_attr("inplace_assign_output", " ".join(
[str(j) for i in inplace_assign_output for j in i]))

+ 104
- 11
tests/st/ops/graph_kernel/custom/test_custom_akg.py View File

@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,6 +44,38 @@ def cube(a):
return c


# Hybrid DSL kernel with two outputs: elementwise c = a + b and d = a * b.
# NOTE(review): kept free of a docstring on purpose — the hybrid compiler and
# the source analyser parse this function's source text.
def multioutput(a, b):
    c = output_tensor(a.shape, a.dtype)
    d = output_tensor(a.shape, a.dtype)
    for i0 in range(a.shape[0]):
        for i1 in range(a.shape[1]):
            c[i0, i1] = a[i0, i1] + b[i0, i1]
            d[i0, i1] = a[i0, i1] * b[i0, i1]
    return c, d


# Hybrid DSL kernel that updates input `a` in place to (a + b) * b and returns
# it, exercising the single-output inplace-assign path. ("signle" is a typo in
# the existing name; kept because callers reference it.)
def custom_inplace_assign_signle_output(a, b):
    c = allocate(a.shape, a.dtype, 'local')
    for i0 in range(a.shape[0]):
        for i1 in range(a.shape[1]):
            c[i0, i1] = a[i0, i1] + b[i0, i1]
            a[i0, i1] = c[i0, i1] * b[i0, i1]
    return a


# Hybrid DSL kernel with two outputs: input `a` is updated in place to
# (a + b) * b, and `d` is a fresh output tensor holding a + b. Exercises the
# mixed inplace / normal output path.
def custom_inplace_assign_two_outputs(a, b):
    c = allocate(a.shape, a.dtype, 'local')
    d = output_tensor(b.shape, b.dtype)
    for i0 in range(a.shape[0]):
        for i1 in range(a.shape[1]):
            c[i0, i1] = a[i0, i1] + b[i0, i1]
            a[i0, i1] = c[i0, i1] * b[i0, i1]
    for j0 in range(b.shape[0]):
        for j1 in range(b.shape[1]):
            d[j0, j1] = c[j0, j1]
    return a, d


class TestHybridTwoInputs(Cell):
"""Net definition"""

@@ -68,6 +100,22 @@ class TestHybridOneInput(Cell):
return self.program(x)


class TestHybridTwoOutputs(Cell):
    """Net wrapping a two-output akg Custom op.

    construct(x, y) returns res2 + res1 * y, where (res1, res2) are the two
    outputs of the wrapped Custom op applied to (x, y).
    """

    def __init__(self, func, out_shape, out_dtype):
        super(TestHybridTwoOutputs, self).__init__()

        # The Custom op under test; func is a hybrid DSL kernel with 2 outputs.
        self.program = ops.Custom(func, out_shape=out_shape, out_dtype=out_dtype, func_type="akg")
        self.add = ops.Add()
        self.mul = ops.Mul()

    def construct(self, x, y):
        res1, res2 = self.program(x, y)
        res3 = self.mul(res1, y)
        return self.add(res2, res3)


class MatMulNN(Cell):
"""Net definition"""

@@ -130,6 +178,45 @@ def hybrid_pow_autodiff():
raise ValueError("Precision error, compare result: {}".format(compare_res))


def hybrid_multioutput_autodiff():
    """Check autodiff of the two-output hybrid op against numpy expectations."""
    x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    y = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    sens = np.random.normal(0, 1, [4, 4]).astype(np.float32)

    net = TestHybridTwoOutputs(multioutput, lambda x, _: (x, x), lambda x, _: (x, x))
    grad_net = ops.GradOperation(sens_param=True, get_all=True)(net)
    dx, dy = grad_net(Tensor(x), Tensor(y), Tensor(sens))

    expect_dx = y * sens * 2.0
    expect_dy = x * sens * 2.0 + y * sens * 2.0
    ok = np.allclose(expect_dx, dx.asnumpy(), 0.001, 0.001)
    ok &= np.allclose(expect_dy, dy.asnumpy(), 0.001, 0.001)
    if not ok:
        raise ValueError("Precision error, compare result: {}".format(ok))


def hybrid_custom_inplace_assign_one_output():
    """Run the single-output inplace-assign hybrid op and verify its result."""
    x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

    net = TestHybridTwoInputs(custom_inplace_assign_signle_output, lambda x, _: x, lambda x, _: x)
    out = net(Tensor(x), Tensor(y))

    # a := (a + b) * b  =>  x * y + y * y
    expect = x * y + y * y
    ok = np.allclose(expect, out.asnumpy(), 0.001, 0.001)
    if not ok:
        raise ValueError("Precision error, compare result: {}".format(ok))


def hybrid_custom_inplace_assign_two_outputs():
    """Run the two-output (one inplace) hybrid op and verify the net result."""
    x = np.random.normal(0, 1, [4, 4]).astype(np.float32)
    y = np.random.normal(0, 1, [4, 4]).astype(np.float32)

    net = TestHybridTwoOutputs(custom_inplace_assign_two_outputs, lambda x, y: (x, y), lambda x, y: (x, y))
    out = net(Tensor(x), Tensor(y))

    # res1 = (x + y) * y (inplace into x), res2 = x + y;
    # net returns res2 + res1 * y = x*y^2 + y^3 + x + y.
    expect = x * (y**2) + y**3 + x + y
    ok = np.allclose(expect, out.asnumpy(), 0.001, 0.001)
    if not ok:
        raise ValueError("Precision error, compare result: {}".format(ok))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -171,6 +258,9 @@ def test_hybrid_gpu_graph_mode():
hybrid_outer_product()
hybrid_outer_product_autodiff()
hybrid_pow_autodiff()
hybrid_multioutput_autodiff()
hybrid_custom_inplace_assign_one_output()
hybrid_custom_inplace_assign_two_outputs()


@pytest.mark.level0
@@ -186,20 +276,23 @@ def test_hybrid_gpu_pynative_mode():
hybrid_outer_product()
hybrid_outer_product_autodiff()
hybrid_pow_autodiff()
hybrid_multioutput_autodiff()
hybrid_custom_inplace_assign_one_output()
hybrid_custom_inplace_assign_two_outputs()


v_add_ascend_info = CustomRegOp() \
.input(0, "x", "dynamic") \
.output(0, "y") \
.dtype_format(DataType.None_None, DataType.None_None) \
.target("Ascend") \
v_add_ascend_info = CustomRegOp()\
.input(0, "x", "dynamic")\
.output(0, "y")\
.dtype_format(DataType.None_None, DataType.None_None)\
.target("Ascend")\
.get_op_info()

v_add_gpu_info = CustomRegOp() \
.input(0, "x", "dynamic") \
.output(0, "y") \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.target("GPU") \
v_add_gpu_info = CustomRegOp()\
.input(0, "x", "dynamic")\
.output(0, "y")\
.dtype_format(DataType.F16_None, DataType.F16_None)\
.target("GPU")\
.get_op_info()




Loading…
Cancel
Save