You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

dihedral.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """dihedral class"""
  16. import math
  17. import numpy as np
  18. import mindspore.common.dtype as mstype
  19. from mindspore import Tensor, nn
  20. from mindspore.ops import operations as P
  21. class Dihedral(nn.Cell):
  22. """dihedral class"""
  23. def __init__(self, controller):
  24. super(Dihedral, self).__init__()
  25. self.CONSTANT_Pi = 3.1415926535897932
  26. if controller.amber_parm is not None:
  27. file_path = controller.amber_parm
  28. self.read_information_from_amberfile(file_path)
  29. self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
  30. self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
  31. self.atom_c = Tensor(np.asarray(self.h_atom_c, np.int32), mstype.int32)
  32. self.atom_d = Tensor(np.asarray(self.h_atom_d, np.int32), mstype.int32)
  33. self.pk = Tensor(np.asarray(self.pk, np.float32), mstype.float32)
  34. self.gamc = Tensor(np.asarray(self.gamc, np.float32), mstype.float32)
  35. self.gams = Tensor(np.asarray(self.gams, np.float32), mstype.float32)
  36. self.pn = Tensor(np.asarray(self.pn, np.float32), mstype.float32)
  37. self.ipn = Tensor(np.asarray(self.ipn, np.int32), mstype.int32)
  38. def process1(self, context):
  39. """process1: read information from amberfile"""
  40. for idx, val in enumerate(context):
  41. if idx < len(context) - 1:
  42. if "%FLAG POINTERS" in val + context[idx + 1] and "%FORMAT(10I8)" in val + context[idx + 1]:
  43. start_idx = idx + 2
  44. count = 0
  45. value = list(map(int, context[start_idx].strip().split()))
  46. self.dihedral_with_hydrogen = value[6]
  47. self.dihedral_numbers = value[7]
  48. self.dihedral_numbers += self.dihedral_with_hydrogen
  49. information = []
  50. information.extend(value)
  51. while count < 15:
  52. start_idx += 1
  53. value = list(map(int, context[start_idx].strip().split()))
  54. information.extend(value)
  55. count += len(value)
  56. self.dihedral_type_numbers = information[17]
  57. print("dihedral type numbers ", self.dihedral_type_numbers)
  58. break
  59. self.phase_type = [0] * self.dihedral_type_numbers
  60. self.pk_type = [0] * self.dihedral_type_numbers
  61. self.pn_type = [0] * self.dihedral_type_numbers
  62. for idx, val in enumerate(context):
  63. if "%FLAG DIHEDRAL_FORCE_CONSTANT" in val:
  64. count = 0
  65. start_idx = idx
  66. information = []
  67. while count < self.dihedral_type_numbers:
  68. start_idx += 1
  69. if "%FORMAT" in context[start_idx]:
  70. continue
  71. else:
  72. value = list(map(float, context[start_idx].strip().split()))
  73. information.extend(value)
  74. count += len(value)
  75. self.pk_type = information[:self.dihedral_type_numbers]
  76. break
  77. for idx, val in enumerate(context):
  78. if "%FLAG DIHEDRAL_PHASE" in val:
  79. count = 0
  80. start_idx = idx
  81. information = []
  82. while count < self.dihedral_type_numbers:
  83. start_idx += 1
  84. if "%FORMAT" in context[start_idx]:
  85. continue
  86. else:
  87. value = list(map(float, context[start_idx].strip().split()))
  88. information.extend(value)
  89. count += len(value)
  90. self.phase_type = information[:self.dihedral_type_numbers]
  91. break
  92. for idx, val in enumerate(context):
  93. if "%FLAG DIHEDRAL_PERIODICITY" in val:
  94. count = 0
  95. start_idx = idx
  96. information = []
  97. while count < self.dihedral_type_numbers:
  98. start_idx += 1
  99. if "%FORMAT" in context[start_idx]:
  100. continue
  101. else:
  102. value = list(map(float, context[start_idx].strip().split()))
  103. information.extend(value)
  104. count += len(value)
  105. self.pn_type = information[:self.dihedral_type_numbers]
  106. break
  107. def read_information_from_amberfile(self, file_path):
  108. """read information from amberfile"""
  109. file = open(file_path, 'r')
  110. context = file.readlines()
  111. file.close()
  112. self.process1(context)
  113. self.h_atom_a = [0] * self.dihedral_numbers
  114. self.h_atom_b = [0] * self.dihedral_numbers
  115. self.h_atom_c = [0] * self.dihedral_numbers
  116. self.h_atom_d = [0] * self.dihedral_numbers
  117. self.pk = []
  118. self.gamc = []
  119. self.gams = []
  120. self.pn = []
  121. self.ipn = []
  122. for idx, val in enumerate(context):
  123. if "%FLAG DIHEDRALS_INC_HYDROGEN" in val:
  124. count = 0
  125. start_idx = idx
  126. information = []
  127. while count < 5 * self.dihedral_with_hydrogen:
  128. start_idx += 1
  129. if "%FORMAT" in context[start_idx]:
  130. continue
  131. else:
  132. value = list(map(int, context[start_idx].strip().split()))
  133. information.extend(value)
  134. count += len(value)
  135. for i in range(self.dihedral_with_hydrogen):
  136. self.h_atom_a[i] = information[i * 5 + 0] / 3
  137. self.h_atom_b[i] = information[i * 5 + 1] / 3
  138. self.h_atom_c[i] = information[i * 5 + 2] / 3
  139. self.h_atom_d[i] = abs(information[i * 5 + 3] / 3)
  140. tmpi = information[i * 5 + 4] - 1
  141. self.pk.append(self.pk_type[tmpi])
  142. tmpf = self.phase_type[tmpi]
  143. if abs(tmpf - self.CONSTANT_Pi) <= 0.001:
  144. tmpf = self.CONSTANT_Pi
  145. tmpf2 = math.cos(tmpf)
  146. if abs(tmpf2) < 1e-6:
  147. tmpf2 = 0
  148. self.gamc.append(tmpf2 * self.pk[i])
  149. tmpf2 = math.sin(tmpf)
  150. if abs(tmpf2) < 1e-6:
  151. tmpf2 = 0
  152. self.gams.append(tmpf2 * self.pk[i])
  153. self.pn.append(abs(self.pn_type[tmpi]))
  154. self.ipn.append(int(self.pn[i] + 0.001))
  155. break
  156. for idx, val in enumerate(context):
  157. if "%FLAG DIHEDRALS_WITHOUT_HYDROGEN" in val:
  158. count = 0
  159. start_idx = idx
  160. information = []
  161. while count < 5 * (self.dihedral_numbers - self.dihedral_with_hydrogen):
  162. start_idx += 1
  163. if "%FORMAT" in context[start_idx]:
  164. continue
  165. else:
  166. value = list(map(int, context[start_idx].strip().split()))
  167. information.extend(value)
  168. count += len(value)
  169. for i in range(self.dihedral_with_hydrogen, self.dihedral_numbers):
  170. self.h_atom_a[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 0] / 3
  171. self.h_atom_b[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 1] / 3
  172. self.h_atom_c[i] = information[(i - self.dihedral_with_hydrogen) * 5 + 2] / 3
  173. self.h_atom_d[i] = abs(information[(i - self.dihedral_with_hydrogen) * 5 + 3] / 3)
  174. tmpi = information[(i - self.dihedral_with_hydrogen) * 5 + 4] - 1
  175. self.pk.append(self.pk_type[tmpi])
  176. tmpf = self.phase_type[tmpi]
  177. if abs(tmpf - self.CONSTANT_Pi) <= 0.001:
  178. tmpf = self.CONSTANT_Pi
  179. tmpf2 = math.cos(tmpf)
  180. if abs(tmpf2) < 1e-6:
  181. tmpf2 = 0
  182. self.gamc.append(tmpf2 * self.pk[i])
  183. tmpf2 = math.sin(tmpf)
  184. if abs(tmpf2) < 1e-6:
  185. tmpf2 = 0
  186. self.gams.append(tmpf2 * self.pk[i])
  187. self.pn.append(abs(self.pn_type[tmpi]))
  188. self.ipn.append(int(self.pn[i] + 0.001))
  189. break
  190. for i in range(self.dihedral_numbers):
  191. if self.h_atom_c[i] < 0:
  192. self.h_atom_c[i] *= -1
  193. def Dihedral_Engergy(self, uint_crd, uint_dr_to_dr_cof):
  194. """compute dihedral energy"""
  195. self.dihedral_energy = P.DihedralEnergy(self.dihedral_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a,
  196. self.atom_b, self.atom_c, self.atom_d, self.ipn,
  197. self.pk, self.gamc, self.gams, self.pn)
  198. self.sigma_of_dihedral_ene = P.ReduceSum()(self.dihedral_energy)
  199. return self.sigma_of_dihedral_ene
  200. def Dihedral_Force_With_Atom_Energy(self, uint_crd, scaler):
  201. """compute dihedral force and atom energy"""
  202. self.dfae = P.DihedralForceWithAtomEnergy(dihedral_numbers=self.dihedral_numbers)
  203. self.frc, self.ene = self.dfae(uint_crd, scaler, self.atom_a, self.atom_b, self.atom_c, self.atom_d,
  204. self.ipn, self.pk, self.gamc, self.gams, self.pn)
  205. return self.frc, self.ene