You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gkdropout.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for GkDropout"""
  16. from ._utils import Expander, ExpanderInfoValidator as VLD
  17. @VLD.check_all_formats_same
  18. @VLD.check_attrs('keep_prob')
  19. class GkDropout(Expander):
  20. """GkDropout expander"""
  21. def _expand(self, graph_builder):
  22. input_x, input_mask = self.inputs
  23. keep_prob = self.attrs['keep_prob']
  24. r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
  25. keep_prob = graph_builder.value(input_x.dtype, keep_prob)
  26. if input_mask.dtype != input_x.dtype:
  27. input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
  28. mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type
  29. mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
  30. # compute result
  31. result = graph_builder.emit('Mul', [r_keep_prob, input_x])
  32. result = graph_builder.emit('Mul', [result, mask])
  33. return result, mask