You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pynative_embeddinglookup.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_pynative_embeddinglookup """
  16. import pytest
  17. import numpy as np
  18. import mindspore.ops.operations as op
  19. from mindspore import Tensor, context
  20. from mindspore.nn import Cell
  def setup_module():
      """Pytest module-level hook: run every test here in PyNative mode on an Ascend device."""
      run_mode = context.PYNATIVE_MODE
      context.set_context(mode=run_mode, device_target="Ascend")
  class MetaFactory:
      """Base factory that captures the currently configured MindSpore device target."""

      def __init__(self):
          # Read the device target from the global MindSpore context.
          self.device_target = context.get_context('device_target')
          # Presumably reserved for distributed runs; never populated in this file.
          self.rank_size = self.device_id = self.global_rank_id = None
  class OpsFactory(MetaFactory):
      """Factory base that stores ``dtype`` and a dtype-dependent value in ``self.loss``.

      ``self.loss`` is 1e-3 for float16, 1e-4 for float32, 1e-5 for float64 and 0
      for any other dtype (presumably a numeric comparison tolerance).
      """

      def __init__(self, dtype=np.float16):
          super().__init__()
          self.dtype = dtype
          # Ordered ``==`` scan rather than a dict lookup: np.dtype objects compare
          # equal to the scalar types but do not hash the same way.
          per_dtype = (
              (np.float16, 1e-3),
              (np.float32, 1e-4),
              (np.float64, 1e-5),
          )
          self.loss = next((tol for kind, tol in per_dtype if self.dtype == kind), 0)
  class EmbeddingLookup(Cell):
      """Cell that applies ``op.EmbeddingLookup`` with an offset fixed at construction time."""

      def __init__(self, offset):
          super().__init__()
          self.op = op.EmbeddingLookup()
          # Forwarded unchanged as the op's third argument on every call.
          self.offset = offset

      def construct(self, params, indices):
          """Return ``op.EmbeddingLookup(params, indices, self.offset)``."""
          return self.op(params, indices, self.offset)
  class EmbeddingLookupFactory(OpsFactory):
      """Generates random inputs for an EmbeddingLookup case and runs them through MindSpore."""

      def __init__(self, params_shape, indices_shape, offset=0, low=0, high=2,
                   dtype=np.float32, ids_type=np.int32):
          super().__init__(dtype=dtype)
          # Normal-distributed table of shape ``params_shape``, cast to ``dtype``.
          # Drawn before the ids, so the global NumPy RNG is consumed in the same order.
          table = np.random.randn(*params_shape)
          self.input_np = table.astype(dtype)
          # Integer ids drawn uniformly from [low, high), shape ``indices_shape``.
          ids = np.random.randint(low, high, size=indices_shape)
          self.indices_np = ids.astype(ids_type)
          self.offset = offset
          self.output_grad_np = None  # not filled in by any code visible in this file

      def forward_mindspore_impl(self):
          """Run the EmbeddingLookup cell on the stored inputs; return the output as ndarray."""
          lookup = EmbeddingLookup(self.offset)
          params_tensor, indices_tensor = Tensor(self.input_np), Tensor(self.indices_np)
          return lookup(params_tensor, indices_tensor).asnumpy()
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_onecard
  def test_embeddinglookup_indices_outrange():
      """Out-of-range lookups (after applying ``offset``) must yield an all-zero output.

      Indices are drawn from [1, 3) and ``offset=10`` is passed to the op against a
      (2, 4) table. The expected result is zeros of shape (2, 3, 4), i.e.
      indices_shape + params_shape[1:]. Per the MindSpore EmbeddingLookup docs, the
      effective index is ``indices - offset`` and out-of-range ids give zero rows.
      """
      fact = EmbeddingLookupFactory(params_shape=(2, 4), indices_shape=(2, 3), low=1, high=3,
                                    offset=10, dtype=np.int8)
      out = fact.forward_mindspore_impl()
      out_expect = np.zeros((2, 3, 4))
      # np.allclose broadcasts its operands, so pin the shape explicitly first.
      assert out.shape == out_expect.shape
      # np.allclose only returns a bool; without ``assert`` this test could never fail.
      assert np.allclose(out_expect, out)