You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

_selected_ops.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ resolve ops """
  16. from mindspore.ops.op_selector import new_ops_selector
  17. op_selector = new_ops_selector(
  18. "mindspore.ops.operations", "mindspore.nn.graph_kernels")
  19. opt_selector = new_ops_selector(
  20. "mindspore.nn.optim", "mindspore.nn.graph_kernels")
  21. nn_selector = new_ops_selector(
  22. "mindspore.nn", "mindspore.nn.graph_kernels")
  23. @nn_selector
  24. class BatchNorm2d:
  25. def __call__(self, *args):
  26. pass
  27. @op_selector
  28. class ReLU:
  29. def __call__(self, *args):
  30. pass
  31. @op_selector
  32. class ReduceMean:
  33. def __call__(self, *args):
  34. pass
  35. @op_selector
  36. class BiasAdd:
  37. def __call__(self, *args):
  38. pass
  39. @op_selector
  40. class FusedBatchNorm:
  41. def __call__(self, *args):
  42. pass
  43. @op_selector
  44. class ApplyMomentum:
  45. def __call__(self, *args):
  46. pass
  47. @op_selector
  48. class SoftmaxCrossEntropyWithLogits:
  49. def __call__(self, *args):
  50. pass
  51. @op_selector
  52. class LogSoftmax:
  53. def __call__(self, *args):
  54. pass
  55. @op_selector
  56. class Tanh:
  57. def __call__(self, *args):
  58. pass
  59. @op_selector
  60. class Gelu:
  61. def __call__(self, *args):
  62. pass
  63. @op_selector
  64. class LayerNorm:
  65. def __call__(self, *args):
  66. pass
  67. @op_selector
  68. class Softmax:
  69. def __call__(self, *args):
  70. pass
  71. @op_selector
  72. class LambUpdateWithLR:
  73. def __call__(self, *args):
  74. pass
  75. @op_selector
  76. class LambNextMV:
  77. def __call__(self, *args):
  78. pass