|
|
|
@@ -18,6 +18,7 @@ |
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ..._checkparam import Rel |
|
|
|
|
|
|
|
|
|
|
|
class BondForce(PrimitiveWithInfer): |
|
|
|
@@ -50,12 +51,30 @@ class BondForce(PrimitiveWithInfer): |
|
|
|
``GPU`` |
|
|
|
Examples: |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, bond_numbers): |
|
|
|
self.bond_numbers = bond_numbers |
|
|
|
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'], |
|
|
|
outputs=['frc_f']) |
|
|
|
self.add_prim_attr('bond_numbers', self.bond_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape): |
|
|
|
cls_name = self.name |
|
|
|
# N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) |
|
|
|
return uint_crd_f_shape |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('scaler_f_type', scaler_f_type, [mstype.float32], self.name) |
|
|
|
@@ -81,6 +100,11 @@ class BondEnergy(PrimitiveWithInfer): |
|
|
|
Inputs: |
|
|
|
Same as operator BondForce(). |
|
|
|
|
|
|
|
.. math:: |
|
|
|
|
|
|
|
dr = (x_1-x_2, y_1-y_2, z_1-z_2) |
|
|
|
E = k*(|dr| - r_0)^2 |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **bond_ene** (Tensor, float32) - [M, 1], the harmonic potential energy |
|
|
|
for each bond. |
|
|
|
@@ -89,12 +113,31 @@ class BondEnergy(PrimitiveWithInfer): |
|
|
|
``GPU`` |
|
|
|
Examples: |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, bond_numbers): |
|
|
|
self.bond_numbers = bond_numbers |
|
|
|
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'], |
|
|
|
outputs=['bond_ene']) |
|
|
|
self.add_prim_attr('bond_numbers', self.bond_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape): |
|
|
|
cls_name = self.name |
|
|
|
# N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) |
|
|
|
|
|
|
|
return bond_k_shape |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('scaler_f_type', scaler_f_type, [mstype.float32], self.name) |
|
|
|
@@ -125,12 +168,30 @@ class BondAtomEnergy(PrimitiveWithInfer): |
|
|
|
``GPU`` |
|
|
|
Examples: |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, bond_numbers): |
|
|
|
self.bond_numbers = bond_numbers |
|
|
|
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'], |
|
|
|
outputs=['atom_ene']) |
|
|
|
self.add_prim_attr('bond_numbers', self.bond_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) |
|
|
|
return [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('scaler_f_type', scaler_f_type, [mstype.float32], self.name) |
|
|
|
@@ -167,13 +228,28 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'], |
|
|
|
outputs=['frc_f', 'atom_e']) |
|
|
|
self.add_prim_attr('bond_numbers', self.bond_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) |
|
|
|
return uint_crd_f_shape, [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('scaler_f_type', scaler_f_type, [mstype.float32], self.name) |
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid('atom_a_type', atom_a_type, [mstype.int32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('atom_b_type', atom_b_type, [mstype.int32], self.name) |
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid('bond_k_type', bond_k_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('bond_r0_type', bond_r0_type, [mstype.float32], self.name) |
|
|
|
return bond_r0_type, bond_r0_type |
|
|
|
@@ -213,17 +289,33 @@ class BondForceWithAtomVirial(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'bond_k', 'bond_r0'], |
|
|
|
outputs=['frc_f', 'atom_v']) |
|
|
|
self.add_prim_attr('bond_numbers', self.bond_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, bond_k_shape, bond_r0_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name) |
|
|
|
return uint_crd_f_shape, [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('scaler_f_type', scaler_f_type, [mstype.float32], self.name) |
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid('atom_a_type', atom_a_type, [mstype.int32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('atom_b_type', atom_b_type, [mstype.int32], self.name) |
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid('bond_k_type', bond_k_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('bond_r0_type', bond_r0_type, [mstype.float32], self.name) |
|
|
|
return bond_r0_type, bond_r0_type |
|
|
|
|
|
|
|
|
|
|
|
class DihedralForce(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
DihedralForce: |
|
|
|
@@ -259,18 +351,18 @@ class DihedralForce(PrimitiveWithInfer): |
|
|
|
Inputs: |
|
|
|
- **uint_crd_f** (Tensor, uint32) - [N, 3], the unsigned int coordinates |
|
|
|
value of each atom. |
|
|
|
- **scalar_f** (Tensor, float32) - [3, 1], the 3-D scale factor between |
|
|
|
- **scalar_f** (Tensor, float32) - [3, ], the 3-D scale factor between |
|
|
|
the real space float coordinates and the unsigned int coordinates. |
|
|
|
- **atom_a** (Tensor, int32) - [M, 1], the 1st atom index of each dihedral. |
|
|
|
- **atom_b** (Tensor, int32) - [M, 1], the 2nd atom index of each dihedral. |
|
|
|
- **atom_c** (Tensor, int32) - [M, 1], the 3rd atom index of each dihedral. |
|
|
|
- **atom_d** (Tensor, int32) - [M, 1], the 4th atom index of each dihedral. |
|
|
|
- **atom_a** (Tensor, int32) - [M, ], the 1st atom index of each dihedral. |
|
|
|
- **atom_b** (Tensor, int32) - [M, ], the 2nd atom index of each dihedral. |
|
|
|
- **atom_c** (Tensor, int32) - [M, ], the 3rd atom index of each dihedral. |
|
|
|
- **atom_d** (Tensor, int32) - [M, ], the 4th atom index of each dihedral. |
|
|
|
4 atoms are connected in the form a-b-c-d. |
|
|
|
- **ipn** (Tensor, int32) - [M, 1], the period of dihedral angle of each dihedral. |
|
|
|
- **pk** (Tensor, float32) - [M, 1], the force constant of each dihedral. |
|
|
|
- **gamc** (Tensor, float32) - [M, 1], k*cos(phi_0) of each dihedral. |
|
|
|
- **gams** (Tensor, float32) - [M, 1], k*sin(phi_0) of each dihedral. |
|
|
|
- **pn** (Tensor, float32) - [M, 1], the floating point form of ipn. |
|
|
|
- **ipn** (Tensor, int32) - [M, ], the period of dihedral angle of each dihedral. |
|
|
|
- **pk** (Tensor, float32) - [M, ], the force constant of each dihedral. |
|
|
|
- **gamc** (Tensor, float32) - [M, ], k*cos(phi_0) of each dihedral. |
|
|
|
- **gams** (Tensor, float32) - [M, ], k*sin(phi_0) of each dihedral. |
|
|
|
- **pn** (Tensor, float32) - [M, ], the floating point form of ipn. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **frc_f** (Tensor, float32) - [N, 3], the force felt by each atom. |
|
|
|
@@ -289,6 +381,29 @@ class DihedralForce(PrimitiveWithInfer): |
|
|
|
outputs=['frc_f']) |
|
|
|
self.add_prim_attr('dihedral_numbers', self.dihedral_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape, |
|
|
|
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): |
|
|
|
cls_name = self.name |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) |
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) |
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) |
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) |
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) |
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) |
|
|
|
return uint_crd_f_shape |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, |
|
|
|
ipn_type, pk_type, gamc_type, gams_type, pn_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -302,7 +417,6 @@ class DihedralForce(PrimitiveWithInfer): |
|
|
|
validator.check_tensor_dtype_valid('gamc_type', gamc_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('gams_type', gams_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('pn_type', pn_type, [mstype.float32], self.name) |
|
|
|
|
|
|
|
return pn_type |
|
|
|
|
|
|
|
|
|
|
|
@@ -321,7 +435,7 @@ class DihedralEnergy(PrimitiveWithInfer): |
|
|
|
Same as operator DihedralForce(). |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **ene** (Tensor, float32) - [M, 1], the potential energy for each |
|
|
|
- **ene** (Tensor, float32) - [M, ], the potential energy for each |
|
|
|
dihedral term. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
@@ -338,6 +452,29 @@ class DihedralEnergy(PrimitiveWithInfer): |
|
|
|
outputs=['ene']) |
|
|
|
self.add_prim_attr('dihedral_numbers', self.dihedral_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape, |
|
|
|
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): |
|
|
|
cls_name = self.name |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) |
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) |
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) |
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) |
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) |
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) |
|
|
|
return [M,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, |
|
|
|
ipn_type, pk_type, gamc_type, gams_type, pn_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -351,7 +488,6 @@ class DihedralEnergy(PrimitiveWithInfer): |
|
|
|
validator.check_tensor_dtype_valid('gamc_type', gamc_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('gams_type', gams_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('pn_type', pn_type, [mstype.float32], self.name) |
|
|
|
|
|
|
|
return pn_type |
|
|
|
|
|
|
|
|
|
|
|
@@ -368,7 +504,7 @@ class DihedralAtomEnergy(PrimitiveWithInfer): |
|
|
|
Same as operator DihedralEnergy(). |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **ene** (Tensor, float32) - [N, 1], the accumulated potential |
|
|
|
- **ene** (Tensor, float32) - [N, ], the accumulated potential |
|
|
|
energy for each atom. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
@@ -385,6 +521,30 @@ class DihedralAtomEnergy(PrimitiveWithInfer): |
|
|
|
outputs=['ene']) |
|
|
|
self.add_prim_attr('dihedral_numbers', self.dihedral_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape, |
|
|
|
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) |
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) |
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) |
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) |
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) |
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) |
|
|
|
return [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, |
|
|
|
ipn_type, pk_type, gamc_type, gams_type, pn_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -398,7 +558,6 @@ class DihedralAtomEnergy(PrimitiveWithInfer): |
|
|
|
validator.check_tensor_dtype_valid('gamc_type', gamc_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('gams_type', gams_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('pn_type', pn_type, [mstype.float32], self.name) |
|
|
|
|
|
|
|
return pn_type |
|
|
|
|
|
|
|
|
|
|
|
@@ -415,7 +574,7 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **frc_f** (Tensor, float32) - [N, 3], same as operator DihedralForce(). |
|
|
|
- **ene** (Tensor, float32) - [N, 1], same as operator DihedralAtomEnergy(). |
|
|
|
- **ene** (Tensor, float32) - [N, ], same as operator DihedralAtomEnergy(). |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``GPU`` |
|
|
|
@@ -431,6 +590,30 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
outputs=['frc_f', 'ene']) |
|
|
|
self.add_prim_attr('dihedral_numbers', self.dihedral_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, atom_d_shape, |
|
|
|
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name) |
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name) |
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name) |
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name) |
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name) |
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name) |
|
|
|
return uint_crd_f_shape, [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type, |
|
|
|
ipn_type, pk_type, gamc_type, gams_type, pn_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -444,7 +627,6 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
validator.check_tensor_dtype_valid('gamc_type', gamc_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('gams_type', gams_type, [mstype.float32], self.name) |
|
|
|
validator.check_tensor_dtype_valid('pn_type', pn_type, [mstype.float32], self.name) |
|
|
|
|
|
|
|
return pn_type, pn_type |
|
|
|
|
|
|
|
|
|
|
|
@@ -470,14 +652,14 @@ class AngleForce(PrimitiveWithInfer): |
|
|
|
Inputs: |
|
|
|
- **uint_crd_f** (Tensor, uint32) - [N, 3], the unsigned int coordinate |
|
|
|
value of each atom. |
|
|
|
- **scaler_f** (Tensor, float32) - [3, 1], the 3-D scale factor between |
|
|
|
- **scaler_f** (Tensor, float32) - [3, ], the 3-D scale factor between |
|
|
|
the real space float coordinates and the unsigned int coordinates. |
|
|
|
- **atom_a** (Tensor, int32) - [M, 1], the 1st atom index of each angle. |
|
|
|
- **atom_b** (Tensor, int32) - [M, 1], the 2nd and the central atom index |
|
|
|
- **atom_a** (Tensor, int32) - [M, ], the 1st atom index of each angle. |
|
|
|
- **atom_b** (Tensor, int32) - [M, ], the 2nd and the central atom index |
|
|
|
of each angle. |
|
|
|
- **atom_c** (Tensor, int32) - [M, 1], the 3rd atom index of each angle. |
|
|
|
- **angle_k** (Tensor, float32) - [M, 1], the force constant for each angle. |
|
|
|
- **angle_theta0** (Tensor, float32) - [M, 1], the equilibrium position value |
|
|
|
- **atom_c** (Tensor, int32) - [M, ], the 3rd atom index of each angle. |
|
|
|
- **angle_k** (Tensor, float32) - [M, ], the force constant for each angle. |
|
|
|
- **angle_theta0** (Tensor, float32) - [M, ], the equilibrium position value |
|
|
|
for each angle. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
@@ -497,6 +679,26 @@ class AngleForce(PrimitiveWithInfer): |
|
|
|
outputs=['frc_f']) |
|
|
|
self.add_prim_attr('angle_numbers', self.angle_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape, |
|
|
|
angle_theta0_shape): |
|
|
|
cls_name = self.name |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) |
|
|
|
return uint_crd_f_shape |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, |
|
|
|
angle_theta0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -526,7 +728,7 @@ class AngleEnergy(PrimitiveWithInfer): |
|
|
|
Same as operator AngleForce(). |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **ene** (Tensor, float32) - [M, 1], the potential energy for |
|
|
|
- **ene** (Tensor, float32) - [M, ], the potential energy for |
|
|
|
each angle term. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
@@ -543,6 +745,26 @@ class AngleEnergy(PrimitiveWithInfer): |
|
|
|
outputs=['ene']) |
|
|
|
self.add_prim_attr('angle_numbers', self.angle_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape, |
|
|
|
angle_theta0_shape): |
|
|
|
cls_name = self.name |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) |
|
|
|
return [M,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, |
|
|
|
angle_theta0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -568,7 +790,7 @@ class AngleAtomEnergy(PrimitiveWithInfer): |
|
|
|
Same as operator AngleForce(). |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **ene** (Tensor, float32) - [N, 1], the accumulated potential energy |
|
|
|
- **ene** (Tensor, float32) - [N, ], the accumulated potential energy |
|
|
|
for each atom. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
@@ -585,6 +807,27 @@ class AngleAtomEnergy(PrimitiveWithInfer): |
|
|
|
outputs=['ene']) |
|
|
|
self.add_prim_attr('angle_numbers', self.angle_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape, |
|
|
|
angle_theta0_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) |
|
|
|
return [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, |
|
|
|
angle_theta0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
@@ -610,7 +853,7 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **frc_f** (Tensor, float32) - [N, 3], same as operator AngleForce(). |
|
|
|
- **ene** (Tensor, float) - [N, 1], same as operator AngleAtomEnergy(). |
|
|
|
- **ene** (Tensor, float) - [N, ], same as operator AngleAtomEnergy(). |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``GPU`` |
|
|
|
@@ -626,6 +869,27 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
outputs=['frc_f', 'ene']) |
|
|
|
self.add_prim_attr('angle_numbers', self.angle_numbers) |
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_f_shape, scaler_f_shape, atom_a_shape, atom_b_shape, atom_c_shape, angle_k_shape, |
|
|
|
angle_theta0_shape): |
|
|
|
cls_name = self.name |
|
|
|
N = uint_crd_f_shape[0] |
|
|
|
M = atom_a_shape[0] |
|
|
|
validator.check_int( |
|
|
|
uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name) |
|
|
|
validator.check_int( |
|
|
|
angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name) |
|
|
|
return uint_crd_f_shape, [N,] |
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type, |
|
|
|
angle_theta0_type): |
|
|
|
validator.check_tensor_dtype_valid('uint_crd_f_dtype', uint_crd_f_dtype, [mstype.uint32], self.name) |
|
|
|
|