|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """bond class"""
-
- import numpy as np
- import mindspore.common.dtype as mstype
- from mindspore import Tensor, nn
- from mindspore.ops import operations as P
-
-
class Bond(nn.Cell):
    """Bond energy/force term whose parameters are read from an AMBER parameter file.

    After construction the cell holds, for every bond (bonds involving hydrogen
    first, then the remaining bonds):

    * ``atom_a`` / ``atom_b`` -- int32 tensors with the zero-based indices of the
      two bonded atoms;
    * ``bond_k`` / ``bond_r0`` -- float32 tensors with the force constant and the
      equilibrium value of the bond's type.
    """

    def __init__(self, controller, md_info):
        """Build the bond tensors from ``controller.amber_parm``.

        Args:
            controller: object whose ``amber_parm`` attribute is the path of the
                AMBER parameter file.
            md_info: object providing ``atom_numbers`` (total number of atoms).

        Raises:
            ValueError: if ``controller.amber_parm`` is None, or if the file lacks
                a required section (see ``read_information_from_amberfile``).
        """
        super().__init__()

        self.atom_numbers = md_info.atom_numbers

        if controller.amber_parm is None:
            # Every bond list below comes from the parameter file; without it the
            # tensors cannot be built.
            raise ValueError("Bond requires an AMBER parameter file, but controller.amber_parm is None")
        self.read_information_from_amberfile(controller.amber_parm)

        self.atom_a = Tensor(np.asarray(self.h_atom_a, np.int32), mstype.int32)
        self.atom_b = Tensor(np.asarray(self.h_atom_b, np.int32), mstype.int32)
        self.bond_k = Tensor(np.asarray(self.h_k, np.float32), mstype.float32)
        self.bond_r0 = Tensor(np.asarray(self.h_r0, np.float32), mstype.float32)

    @staticmethod
    def _read_section(context, flag, count, cast):
        """Return the first ``count`` values of the ``%FLAG <flag>`` section.

        Directive lines starting with ``%`` (``%FORMAT``, ``%COMMENT``) are skipped;
        the remaining lines are split on whitespace and each token is converted
        with ``cast``.

        Args:
            context: lines of an AMBER parameter file (as returned by ``readlines()``).
            flag: section name, e.g. ``"POINTERS"``.
            count: number of values wanted; ``count <= 0`` returns ``[]`` without
                searching for the section.
            cast: converter for each token, e.g. ``int`` or ``float``.

        Returns:
            A list of exactly ``count`` converted values.

        Raises:
            ValueError: if the section is absent, or ends (next ``%FLAG`` line or
                end of file) before ``count`` values were read.
        """
        if count <= 0:
            return []

        start = None
        for idx, line in enumerate(context):
            if line.split()[:2] == ["%FLAG", flag]:
                start = idx + 1
                break
        if start is None:
            raise ValueError(f"%FLAG {flag} not found in AMBER parameter file")

        values = []
        for line in context[start:]:
            stripped = line.lstrip()
            if stripped.startswith("%FLAG"):
                break  # reached the next section before collecting enough values
            if stripped.startswith("%"):
                continue  # %FORMAT / %COMMENT directives carry no data
            # NOTE(review): whitespace splitting assumes adjacent fixed-width fields
            # never run together (an 8-digit value in a 10I8 field would); confirm
            # for very large systems.
            values.extend(cast(token) for token in stripped.split())
            if len(values) >= count:
                return values[:count]
        raise ValueError(f"%FLAG {flag} holds {len(values)} values, expected {count}")

    def process1(self, context):
        """Read bond counts and bond-type force constants from parameter-file lines.

        Args:
            context: lines of an AMBER parameter file (as returned by ``readlines()``).

        Sets ``bond_with_hydrogen``, ``bond_numbers`` (bonds with plus without
        hydrogen), ``bond_type_numbers`` and ``bond_type_k`` (one force constant
        per bond type), and prints the bond count and the bond-type count.

        Raises:
            ValueError: if ``%FLAG POINTERS`` or ``%FLAG BOND_FORCE_CONSTANT`` is
                missing or truncated.
        """
        # POINTERS entries used here (AMBER prmtop spec): [2] NBONH = bonds
        # containing hydrogen, [3] MBONA = bonds without hydrogen,
        # [15] NUMBND = number of distinct bond types.
        pointers = self._read_section(context, "POINTERS", 16, int)
        self.bond_with_hydrogen = pointers[2]
        self.bond_numbers = pointers[3] + self.bond_with_hydrogen
        print(self.bond_numbers)
        self.bond_type_numbers = pointers[15]
        print("bond type numbers ", self.bond_type_numbers)

        self.bond_type_k = self._read_section(context, "BOND_FORCE_CONSTANT", self.bond_type_numbers, float)

    def read_information_from_amberfile(self, file_path):
        """Parse all bond parameters from the AMBER parameter file at ``file_path``.

        Besides the attributes set by ``process1``, this sets:

        * ``bond_type_r`` -- equilibrium value of each bond type
          (``%FLAG BOND_EQUIL_VALUE``);
        * ``h_atom_a`` / ``h_atom_b`` -- zero-based indices of the two atoms of
          every bond;
        * ``h_k`` / ``h_r0`` -- force constant and equilibrium value of every bond.

        Each per-bond list has ``bond_numbers`` entries: the bonds from
        ``%FLAG BONDS_INC_HYDROGEN`` first, then those from
        ``%FLAG BONDS_WITHOUT_HYDROGEN``.

        Raises:
            ValueError: if a required section is missing or truncated.
        """
        with open(file_path, 'r') as file:
            context = file.readlines()
        self.process1(context)

        self.bond_type_r = self._read_section(context, "BOND_EQUIL_VALUE", self.bond_type_numbers, float)

        n_without_hydrogen = self.bond_numbers - self.bond_with_hydrogen
        # Each bond is stored as a triple (atom_a, atom_b, bond_type).
        triples = (self._read_section(context, "BONDS_INC_HYDROGEN", 3 * self.bond_with_hydrogen, int)
                   + self._read_section(context, "BONDS_WITHOUT_HYDROGEN", 3 * n_without_hydrogen, int))

        # Atom entries are offsets into a flattened xyz coordinate array
        # (3 * zero-based atom index); floor division keeps them integers.
        self.h_atom_a = [offset // 3 for offset in triples[0::3]]
        self.h_atom_b = [offset // 3 for offset in triples[1::3]]
        # Bond types are 1-based in the file.
        type_indices = [bond_type - 1 for bond_type in triples[2::3]]
        self.h_k = [self.bond_type_k[t] for t in type_indices]
        self.h_r0 = [self.bond_type_r[t] for t in type_indices]

    def Bond_Energy(self, uint_crd, uint_dr_to_dr_cof):
        """Compute the summed bond energy with the ``P.BondEnergy`` operator.

        The raw operator output is stored in ``self.bond_energy`` and its
        ``ReduceSum`` in ``self.sigma_of_bond_ene``, which is also returned.

        Args:
            uint_crd: atom coordinates in the form ``P.BondEnergy`` expects
                (presumably unsigned-integer coordinates -- confirm against the
                operator documentation).
            uint_dr_to_dr_cof: presumably the factor converting coordinate
                differences to real distances (inferred from the name -- confirm).
        """
        self.bond_energy = P.BondEnergy(self.bond_numbers, self.atom_numbers)(uint_crd, uint_dr_to_dr_cof, self.atom_a,
                                                                              self.atom_b, self.bond_k, self.bond_r0)
        self.sigma_of_bond_ene = P.ReduceSum()(self.bond_energy)
        return self.sigma_of_bond_ene

    def Bond_Force_With_Atom_Energy(self, uint_crd, scaler):
        """Run ``P.BondForceWithAtomEnergy`` over all bonds.

        The operator instance is kept in ``self.bfatomenergy``.

        Args:
            uint_crd: atom coordinates in the form the operator expects.
            scaler: scale argument forwarded to the operator (semantics defined by
                ``P.BondForceWithAtomEnergy`` -- confirm there).

        Returns:
            The operator's two outputs ``(frc, atom_energy)``.
        """
        self.bfatomenergy = P.BondForceWithAtomEnergy(bond_numbers=self.bond_numbers,
                                                      atom_numbers=self.atom_numbers)
        frc, atom_energy = self.bfatomenergy(uint_crd, scaler, self.atom_a, self.atom_b, self.bond_k, self.bond_r0)
        return frc, atom_energy
|