|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import akg.tvm
- from akg.tvm import build_module
-
-
def test_copy_prop_case0():
    """
    Check that CopyPropagation leaves no trace of a pure copy tensor.

    Tensors built before the pass:

        B  = input
        B1 = B
        C  = B1 * 2

    After the pass the statement is walked, and the test fails if any
    Realize, AttrStmt or Call node still refers to the name "B1".
    """
    tensor_shape = (8, 128)
    elem_type = "float32"

    A = akg.tvm.placeholder(tensor_shape, name="input", dtype=elem_type)
    B = akg.tvm.compute(tensor_shape, lambda *idx: A(*idx), name="B")
    B1 = akg.tvm.compute(tensor_shape, lambda *idx: B(*idx), name="B1")
    C = akg.tvm.compute(
        tensor_shape,
        lambda *idx: B1(*idx) * akg.tvm.const(2.0, elem_type),
        name="C",
    )

    sched = akg.tvm.create_schedule(C.op)

    # Bind only the placeholder and the final result.
    binds, _ = build_module.get_binds([A, C])
    inferred_bounds = akg.tvm.schedule.InferBound(sched)
    lowered = akg.tvm.schedule.ScheduleOps(sched, inferred_bounds)

    lowered = akg.tvm.ir_pass.CopyPropagation(lowered, binds)

    # One flag per node kind; set when a node of that kind mentions "B1".
    leftover = {"realize": False, "attr": False, "call": False}

    def inspect(node):
        if isinstance(node, akg.tvm.stmt.Realize) and node.func.name == "B1":
            leftover["realize"] = True
        if isinstance(node, akg.tvm.stmt.AttrStmt) and node.node.name == "B1":
            leftover["attr"] = True
        if isinstance(node, akg.tvm.expr.Call) and node.name == "B1":
            leftover["call"] = True

    akg.tvm.ir_pass.PostOrderVisit(lowered, inspect)

    assert not leftover["realize"]
    assert not leftover["attr"]
    assert not leftover["call"]
-
-
def test_copy_prop_case1():
    """
    Check CopyPropagation when the copied name is reused for the output.

    Tensors built before the pass:

        B  = input
        B1 = B
        C  = B1 * 2
        B  = C        (a second compute, also named "B")

    After the pass the statement is walked, and the test fails if any
    Realize, AttrStmt or Call node still refers to the name "B1".
    """
    tensor_shape = (8, 128)
    elem_type = "float32"

    A = akg.tvm.placeholder(tensor_shape, name="input", dtype=elem_type)
    B = akg.tvm.compute(tensor_shape, lambda *idx: A(*idx), name="B")
    B1 = akg.tvm.compute(tensor_shape, lambda *idx: B(*idx), name="B1")
    C = akg.tvm.compute(
        tensor_shape,
        lambda *idx: B1(*idx) * akg.tvm.const(2.0, elem_type),
        name="C",
    )
    # Rebind B: the schedule below is rooted at this second "B" compute.
    B = akg.tvm.compute(tensor_shape, lambda *idx: C(*idx), name="B")

    sched = akg.tvm.create_schedule(B.op)
    binds, _ = build_module.get_binds([A, C])
    inferred_bounds = akg.tvm.schedule.InferBound(sched)
    lowered = akg.tvm.schedule.ScheduleOps(sched, inferred_bounds)

    lowered = akg.tvm.ir_pass.CopyPropagation(lowered, binds)

    # One flag per node kind; set when a node of that kind mentions "B1".
    leftover = {"realize": False, "attr": False, "call": False}

    def inspect(node):
        if isinstance(node, akg.tvm.stmt.Realize) and node.func.name == "B1":
            leftover["realize"] = True
        if isinstance(node, akg.tvm.stmt.AttrStmt) and node.node.name == "B1":
            leftover["attr"] = True
        if isinstance(node, akg.tvm.expr.Call) and node.name == "B1":
            leftover["call"] = True

    akg.tvm.ir_pass.PostOrderVisit(lowered, inspect)

    assert not leftover["realize"]
    assert not leftover["attr"]
    assert not leftover["call"]
-
-
if __name__ == '__main__':
    # Run every copy-propagation case in order when executed as a script.
    for run_case in (test_copy_prop_case0, test_copy_prop_case1):
        run_case()
|