|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import os
- import numpy as np
-
- import mindspore as ms
- import mindspore.communication.management as distributedTool
- import mindspore.context as context
- from mindspore.common.tensor import Tensor
- from mindspore.nn import Cell
- from mindspore.ops import operations as P
-
# Defaults used before setup_module() runs; setup_module() overwrites
# device_num and rank_id with values from the initialised communication group.
device_num = 2
# DEVICE_ID is normally exported by the distributed launcher. Fall back to 0
# so the module can still be imported (e.g. during pytest collection) when it
# is unset, instead of failing with int(None) -> TypeError.
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = 0
-
-
def setup_module():
    """Module-level setup hook: configure the Ascend context and distributed group.

    Seeds NumPy with 0 (presumably so every rank draws identical random data),
    selects graph mode on Ascend for ``device_id``, initialises communication,
    and stores the group size and this process's rank in the module globals
    ``device_num`` and ``rank_id``. Those values are then registered with the
    auto-parallel context.
    """
    global device_num, rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id)
    distributedTool.init()
    device_num, rank_id = distributedTool.get_group_size(), distributedTool.get_rank()
    context.set_auto_parallel_context(device_num=device_num, global_rank=rank_id)
-
-
def teardown_module():
    """Module-level teardown hook: release the communication set up by ``distributedTool.init()``."""
    return distributedTool.release()
-
-
class Onehot(Cell):
    """OneHot -> Transpose -> Sub network used to compare single and parallel runs.

    Args:
        axis: stored as ``self.axis`` but NOT passed to ``P.OneHot``, which is
            built with its default axis. NOTE(review): confirm whether this is
            intentional; ``construct`` hard-codes a (1, 0) transpose.
        depth: OneHot depth.
        on_value / off_value: wrapped as float32 Tensors for OneHot.
        strategy: optional shard strategy for OneHot. When truthy, its first
            entry is reused as the Transpose strategy; otherwise Transpose
            gets ``None``.
    """

    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None):
        super().__init__()
        transpose_strategy = (strategy[0],) if strategy else None
        self.onehot = P.OneHot().shard(strategy=strategy)
        self.depth = depth
        self.on_value = Tensor(on_value, ms.float32)
        self.off_value = Tensor(off_value, ms.float32)
        self.transpose = P.Transpose().shard(strategy=transpose_strategy)
        # Sub always uses an unsplit ((1, 1), (1, 1)) strategy, independent of `strategy`.
        self.sub = P.Sub().shard(strategy=((1, 1), (1, 1)))
        self.axis = axis

    def construct(self, input_, indices):
        # Encode the indices, swap the two axes (presumably so the result
        # matches input_'s layout), then subtract from input_.
        encoded = self.onehot(indices, self.depth, self.on_value, self.off_value)
        swapped = self.transpose(encoded, (1, 0))
        return self.sub(input_, swapped)
-
-
class DataGenerator():
    """Generate random test data together with this rank's slice of it."""

    def get_parallel_blocks(self, input_, strategy):
        """Split ``input_`` into blocks according to ``strategy``.

        ``strategy[i]`` is the number of equal parts axis ``i`` is cut into via
        ``np.split`` (which raises if the axis is not evenly divisible). Axes
        are split in order, so in the result the last split axis varies fastest.

        Returns:
            list of ndarrays; ``[input_]`` itself when ``strategy`` is empty.
        """
        blocks = [input_]
        for axis, parts in enumerate(strategy):
            # Each existing block is replaced, in order, by its pieces along `axis`.
            blocks = [piece
                      for block in blocks
                      for piece in np.split(block, parts, axis=axis)]
        return blocks

    def generate_data(self, shape):
        """Return float64 samples from the uniform distribution on [0, 1) with ``shape``."""
        return np.random.rand(*shape)

    def _split_full_and_part(self, data):
        """Return ``(Tensor(data), Tensor(slice for rank_id))``, splitting axis 0 into ``device_num`` parts."""
        stra = [1] * data.ndim
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])

    def input_data(self, shape):
        """Random float32 data in [0, 2) of ``shape``, plus this rank's axis-0 slice."""
        data = (self.generate_data(shape) * 2).astype(np.float32)
        return self._split_full_and_part(data)

    def label_data(self, shape, classes):
        """Random int32 labels of ``shape``, plus this rank's axis-0 slice.

        Values are ``int(u * (classes - 1))`` with ``u`` in [0, 1), so they lie in
        ``[0, classes - 2]`` (for ``classes >= 2``).
        NOTE(review): label ``classes - 1`` is never produced, so the last
        one-hot column is never exercised; confirm this is intended.
        """
        data = (self.generate_data(shape) * (classes - 1)).astype(np.int32)
        return self._split_full_and_part(data)
-
-
class OneHotFactory:
    """Run ``Onehot`` standalone and in semi_auto_parallel mode and compare outputs.

    Args:
        batch_size: number of labels; also the second dimension of the input.
        classes: OneHot depth; also the first dimension of the input.
        on_value / off_value: forwarded to ``Onehot``.
        axis: forwarded to ``Onehot``.
        strategy: shard strategy, used only by the parallel run.
    """

    def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None):
        data_gen = DataGenerator()
        # Both runs consume the full arrays; the per-rank "_part" slices are
        # kept on the instance but not used by the methods below.
        self.input_full, self.input_part = data_gen.input_data((classes, batch_size))
        self.label_full, self.label_part = data_gen.label_data((batch_size,), classes)
        self.depth = classes
        self.on_value = on_value
        self.off_value = off_value
        self.axis = axis
        self.strategy = strategy

    def forward_mindspore_single_impl(self):
        """Run the network without a shard strategy and return its output Tensor."""
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value)
        out = net(self.input_full, self.label_full)
        return out

    def forward_mindspore_parallel_impl(self):
        """Enable semi_auto_parallel, compile and run with ``self.strategy``, return the output.

        Leaves the auto-parallel context in semi_auto_parallel mode; callers
        are responsible for resetting it.
        """
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value, strategy=self.strategy)
        out = net.compile_and_run(self.input_full, self.label_full)
        return out

    def forward_cmp(self):
        """Assert standalone and parallel outputs agree within rtol=atol=1e-4.

        Raises:
            AssertionError: if the outputs differ beyond the tolerance.
        """
        out_mindspore_single = self.forward_mindspore_single_impl().asnumpy()
        context.reset_auto_parallel_context()
        try:
            out_mindspore_parallel = self.forward_mindspore_parallel_impl().asnumpy()
        finally:
            # Undo the semi_auto_parallel mode even if the parallel run raises,
            # so a failure here does not leak parallel settings into later tests.
            context.reset_auto_parallel_context()
        assert np.allclose(out_mindspore_single, out_mindspore_parallel, rtol=0.0001, atol=0.0001)
-
-
def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
    """Batch 128, depth 1024, OneHot shard strategy ((1, device_num), (), ())."""
    onehot_strategy = ((1, device_num), (), ())
    OneHotFactory(batch_size=128, classes=1024,
                  on_value=1.0, off_value=0.0,
                  axis=-1, strategy=onehot_strategy).forward_cmp()
-
-
def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
    """Batch 1024, depth 128, OneHot shard strategy ((1, device_num), (), ())."""
    onehot_strategy = ((1, device_num), (), ())
    OneHotFactory(batch_size=1024, classes=128,
                  on_value=1.0, off_value=0.0,
                  axis=-1, strategy=onehot_strategy).forward_cmp()
|