# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compare a OneHot -> Transpose -> Sub network run on a single device against
the same network run in semi-auto-parallel mode."""
import os

import pytest
import numpy as np

import mindspore as ms
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.context as context
import mindspore.communication.management as distributedTool

# Placeholder values; setup_module() overwrites them with the real group size
# and rank obtained from the communication layer.
device_num = 2
# Fall back to device 0 instead of raising TypeError (int(None)) when
# DEVICE_ID is not exported, so the module can at least be imported/collected.
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = 0


def setup_module():
    """Configure the Ascend graph-mode context and initialise communication.

    Also stores the group size and this process's rank in the module-level
    ``device_num`` / ``rank_id`` globals, which the data generators use to
    pick this rank's shard.
    """
    global device_num
    global rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=device_id)
    context.set_context(enable_ir_fusion=True)
    context.set_context(enable_loop_sink=False)
    distributedTool.init()
    device_num = distributedTool.get_group_size()
    rank_id = distributedTool.get_rank()
    context.set_auto_parallel_context(device_num=device_num, global_rank=rank_id)


def teardown_module():
    """Release the communication resources acquired in setup_module()."""
    distributedTool.release()


class Onehot(Cell):
    """Network under test: ``input - transpose(onehot(indices), (1, 0))``.

    Args:
        axis: one-hot axis forwarded to ``P.OneHot``; ``None`` keeps the
            operator's default.
        depth: number of one-hot classes.
        on_value: fill value for the hot position (stored as float32).
        off_value: fill value elsewhere (stored as float32).
        strategy: optional sharding strategy for OneHot. Its first entry is
            reused as the strategy for the single-input Transpose.
    """

    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None):
        super(Onehot, self).__init__()
        trans_stra = None
        if strategy:
            trans_stra = (strategy[0],)
        # `axis` used to be accepted but silently ignored; forward it when given.
        onehot = P.OneHot() if axis is None else P.OneHot(axis=axis)
        self.onehot = onehot.set_strategy(strategy=strategy)
        self.depth = depth
        self.on_value = Tensor(on_value, ms.float32)
        self.off_value = Tensor(off_value, ms.float32)
        self.transpose = P.Transpose().set_strategy(strategy=trans_stra)
        # Sub is deliberately left unsharded.
        self.sub = P.Sub().set_strategy(strategy=((1, 1), (1, 1)))

    def construct(self, input, indices):
        """Return ``input - transpose(onehot(indices), (1, 0))``."""
        x = self.onehot(indices, self.depth, self.on_value, self.off_value)
        # The permutation (1, 0) assumes the one-hot result is 2-D.
        x = self.transpose(x, (1, 0))
        x = self.sub(input, x)
        return x


class DataGenerator():
    """Build random test data together with the current rank's shard of it."""

    def get_parallel_blocks(self, input_, strategy):
        """Split ``input_`` into equal blocks following ``strategy``.

        ``strategy[i]`` is the number of pieces along axis ``i``. Blocks are
        returned in row-major order; callers index the result with
        ``rank_id`` to select this process's shard. ``np.split`` raises
        ``ValueError`` if an axis is not evenly divisible.
        """
        blocks = [input_]
        for axis, num_splits in enumerate(strategy):
            blocks = [piece
                      for block in blocks
                      for piece in np.split(block, num_splits, axis=axis)]
        return blocks

    def generate_data(self, shape):
        """Return uniform [0, 1) float64 samples of the given shape."""
        data = np.random.rand(*shape)
        return data

    def _full_and_shard(self, data):
        """Return ``(Tensor(data), Tensor(rank's shard))``, splitting only axis 0."""
        stra = [1] * data.ndim
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])

    def input_data(self, shape):
        """Random float32 values in [0, 2): the full tensor and this rank's shard."""
        data = (self.generate_data(shape) * 2).astype(np.float32)
        return self._full_and_shard(data)

    def label_data(self, shape, classes):
        """Random int32 labels in [0, classes - 1): full tensor and this rank's shard.

        Truncation by ``astype`` means the value ``classes - 1`` is never produced.
        """
        data = (self.generate_data(shape) * (classes - 1)).astype(np.int32)
        return self._full_and_shard(data)


class OneHotFactory:
    """Build inputs for one test case and compare single-device and parallel runs."""

    def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None,
                 strategy=None):
        dataGen = DataGenerator()
        self.input_full, self.input_part = dataGen.input_data((classes, batch_size))
        self.label_full, self.label_part = dataGen.label_data((batch_size,), classes)
        self.depth = classes
        self.on_value = on_value
        self.off_value = off_value
        self.axis = axis
        self.strategy = strategy

    def forward_mindspore_single_impl(self):
        """Run the network without any sharding strategy on the full inputs."""
        net = Onehot(axis=self.axis, depth=self.depth, on_value=self.on_value,
                     off_value=self.off_value)
        out = net(self.input_full, self.label_full)
        return out

    def forward_mindspore_parallel_impl(self):
        """Run the network in semi_auto_parallel mode with ``self.strategy``."""
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = Onehot(axis=self.axis, depth=self.depth, on_value=self.on_value,
                     off_value=self.off_value, strategy=self.strategy)
        out = net.compile_and_run(self.input_full, self.label_full)
        return out

    def forward_cmp(self):
        """Assert single-device and parallel outputs agree (rtol = atol = 1e-4)."""
        out_mindspore_single = self.forward_mindspore_single_impl().asnumpy()
        context.reset_auto_parallel_context()
        try:
            out_mindspore_parallel = self.forward_mindspore_parallel_impl().asnumpy()
        finally:
            # Always undo semi_auto_parallel mode, even if the parallel run
            # fails, so it does not leak into subsequent tests.
            context.reset_auto_parallel_context()
        assert np.allclose(out_mindspore_single, out_mindspore_parallel, 0.0001, 0.0001)


def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
    fact = OneHotFactory(batch_size=128, classes=1024, on_value=1.000000, off_value=0.000000,
                         axis=-1, strategy=((1, device_num), (), ()))
    fact.forward_cmp()


def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
    fact = OneHotFactory(batch_size=1024, classes=128, on_value=1.000000, off_value=0.000000,
                         axis=-1, strategy=((1, device_num), (), ()))
    fact.forward_cmp()