|
|
|
@@ -0,0 +1,149 @@ |
|
|
|
# Copyright 2019 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore as ms |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore import context |
|
|
|
from mindspore.common.api import _executor |
|
|
|
from mindspore.context import set_auto_parallel_context |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from tests.ut.python.ops.test_math_ops import VirtualLoss |
|
|
|
|
|
|
|
|
|
|
|
grad_all = C.GradOperation(get_all=True) |
|
|
|
|
|
|
|
|
|
|
|
class NetWithLoss(nn.Cell): |
|
|
|
def __init__(self, network): |
|
|
|
super(NetWithLoss, self).__init__() |
|
|
|
self.loss = VirtualLoss() |
|
|
|
self.network = network |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
predict = self.network(x) |
|
|
|
return self.loss(predict) |
|
|
|
|
|
|
|
|
|
|
|
class GradWrap(nn.Cell): |
|
|
|
def __init__(self, network): |
|
|
|
super(GradWrap, self).__init__() |
|
|
|
self.network = network |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
return grad_all(self.network)(x) |
|
|
|
|
|
|
|
|
|
|
|
def compile_net(net, x): |
|
|
|
net.set_auto_parallel() |
|
|
|
_executor.compile(net, x) |
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5): |
|
|
|
super().__init__() |
|
|
|
self.query_w = Parameter(initializer( |
|
|
|
"normal", [8, 16], ms.float32), name='query') |
|
|
|
self.query = P.MatMul().shard(strategy1) |
|
|
|
|
|
|
|
self.key_w = Parameter(initializer( |
|
|
|
"normal", [8, 16], ms.float32), name='key') |
|
|
|
self.key = P.MatMul().shard(strategy2) |
|
|
|
|
|
|
|
self.value_w = Parameter(initializer( |
|
|
|
"normal", [8, 16], ms.float32), name='value') |
|
|
|
self.value = P.MatMul().shard(strategy3) |
|
|
|
|
|
|
|
self.score = P.MatMul().shard(strategy4) |
|
|
|
self.context = P.MatMul().shard(strategy5) |
|
|
|
self.transpose1 = P.Transpose() |
|
|
|
self.transpose2 = P.Transpose() |
|
|
|
self.relu = P.ReLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
q = self.query(x, self.query_w) |
|
|
|
k = self.key(x, self.key_w) |
|
|
|
v = self.value(x, self.value_w) |
|
|
|
|
|
|
|
k = self.transpose1(k, (1, 0)) |
|
|
|
s = self.score(q, k) |
|
|
|
|
|
|
|
v = self.transpose2(v, (1, 0)) |
|
|
|
c = self.context(v, s) |
|
|
|
out = self.relu(c) |
|
|
|
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def test_self_attention_standalone(): |
|
|
|
set_auto_parallel_context(device_num=8, global_rank=0) |
|
|
|
context.set_auto_parallel_context(parallel_mode="stand_alone") |
|
|
|
net = GradWrap(NetWithLoss( |
|
|
|
Net(None, None, None, None, None))) |
|
|
|
|
|
|
|
x = Tensor(np.ones([32, 8]), dtype=ms.float32) |
|
|
|
|
|
|
|
compile_net(net, x) |
|
|
|
|
|
|
|
|
|
|
|
def test_self_attention_semi(): |
|
|
|
set_auto_parallel_context(device_num=8, global_rank=0) |
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") |
|
|
|
|
|
|
|
strategy1 = ((2, 2), (2, 2)) |
|
|
|
strategy2 = ((2, 2), (2, 2)) |
|
|
|
strategy3 = ((2, 2), (2, 2)) |
|
|
|
strategy4 = ((2, 4), (4, 1)) |
|
|
|
strategy5 = ((2, 1), (1, 4)) |
|
|
|
|
|
|
|
net = GradWrap(NetWithLoss( |
|
|
|
Net(strategy1, strategy2, strategy3, strategy4, strategy5))) |
|
|
|
|
|
|
|
x = Tensor(np.ones([32, 8]), dtype=ms.float32) |
|
|
|
|
|
|
|
compile_net(net, x) |
|
|
|
|
|
|
|
|
|
|
|
def test_self_attention_dp(): |
|
|
|
set_auto_parallel_context(device_num=8, global_rank=0) |
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") |
|
|
|
|
|
|
|
strategy1 = ((8, 1), (1, 1)) |
|
|
|
strategy2 = ((8, 1), (1, 1)) |
|
|
|
strategy3 = ((8, 1), (1, 1)) |
|
|
|
strategy4 = ((8, 1), (1, 1)) |
|
|
|
strategy5 = ((8, 1), (1, 1)) |
|
|
|
|
|
|
|
net = GradWrap(NetWithLoss( |
|
|
|
Net(strategy1, strategy2, strategy3, strategy4, strategy5))) |
|
|
|
|
|
|
|
x = Tensor(np.ones([32, 8]), dtype=ms.float32) |
|
|
|
|
|
|
|
compile_net(net, x) |
|
|
|
|
|
|
|
|
|
|
|
def test_self_attention_auto(): |
|
|
|
set_auto_parallel_context(device_num=8, global_rank=0) |
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel") |
|
|
|
net = GradWrap(NetWithLoss( |
|
|
|
Net(None, None, None, None, None))) |
|
|
|
|
|
|
|
x = Tensor(np.ones([32, 8]), dtype=ms.float32) |
|
|
|
|
|
|
|
compile_net(net, x) |