|
|
|
@@ -22,8 +22,8 @@ from mindspore import Tensor |
|
|
|
from mindspore import nn
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from virial import ProdVirialSeA
|
|
|
|
from descriptor import DescriptorSeA
|
|
|
|
from .virial import ProdVirialSeA
|
|
|
|
from .descriptor import DescriptorSeA
|
|
|
|
|
|
|
|
natoms = [192, 192, 64, 128]
|
|
|
|
rcut_a = -1
|
|
|
|
@@ -231,11 +231,12 @@ class Network(nn.Cell): |
|
|
|
self.descrpt_se_a = DescriptorSeA()
|
|
|
|
self.process = Processing()
|
|
|
|
self.prod_virial_se_a = ProdVirialSeA()
|
|
|
|
self.prod_force_se_a = P.ProdForceSeA()
|
|
|
|
|
|
|
|
def construct(self, coord, nlist, frames, avg, std, atype):
|
|
|
|
def construct(self, d_coord, d_nlist, frames, avg, std, atype, nlist):
|
|
|
|
"""construct function."""
|
|
|
|
rij, descrpt, descrpt_deriv = \
|
|
|
|
self.descrpt_se_a(coord, nlist, frames, avg, std, atype)
|
|
|
|
self.descrpt_se_a(d_coord, d_nlist, frames, avg, std, atype)
|
|
|
|
# calculate energy and atom_ener
|
|
|
|
atom_ener = self.mdnet(descrpt)
|
|
|
|
energy_raw = atom_ener
|
|
|
|
@@ -247,4 +248,6 @@ class Network(nn.Cell): |
|
|
|
descrpt_deriv_reshape = self.reshape(descrpt_deriv, (-1, natoms[0], ndescrpt, 3))
|
|
|
|
# calculate virial
|
|
|
|
virial = self.prod_virial_se_a(net_deriv_reshape, descrpt_deriv_reshape, rij, nlist)
|
|
|
|
return energy, atom_ener, virial
|
|
|
|
# calculate force
|
|
|
|
force = self.prod_force_se_a(net_deriv_reshape, descrpt_deriv_reshape, nlist)
|
|
|
|
return energy, atom_ener, force, virial
|