|
|
|
@@ -69,13 +69,19 @@ class BondForce(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.bond_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
|
|
|
return uint_crd_f_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
|
|
|
|
@@ -136,13 +142,19 @@ class BondEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.bond_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
|
|
|
|
|
|
|
return bond_k_shape
|
|
|
|
|
|
|
|
@@ -198,13 +210,19 @@ class BondAtomEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.bond_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
|
|
|
return [N,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, bond_k_type, bond_r0_type):
|
|
|
|
@@ -259,13 +277,19 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.bond_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
|
|
|
|
|
|
|
return uint_crd_f_shape, [N,]
|
|
|
|
|
|
|
|
@@ -333,13 +357,19 @@ class BondForceWithAtomVirial(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.bond_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_k_shape), 1, Rel.EQ, "bond_k_dim", cls_name)
|
|
|
|
validator.check_int(len(bond_r0_shape), 1, Rel.EQ, "bond_r0_dim", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "uint_crd_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(bond_k_shape[0], M, Rel.EQ, "bond_k_shape", cls_name)
|
|
|
|
validator.check_int(bond_r0_shape[0], M, Rel.EQ, "bond_r0_shape", cls_name)
|
|
|
|
|
|
|
|
return uint_crd_f_shape, [N,]
|
|
|
|
|
|
|
|
@@ -436,17 +466,29 @@ class DihedralForce(PrimitiveWithInfer): |
|
|
|
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
M = self.dihedral_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
|
|
|
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
|
|
|
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
|
|
|
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
|
|
|
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
|
|
|
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
|
|
|
return uint_crd_f_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
|
|
|
@@ -516,17 +558,29 @@ class DihedralEnergy(PrimitiveWithInfer): |
|
|
|
ipn_shape, pk_shape, gamc_shape, gams_shape, pn_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
M = self.dihedral_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
|
|
|
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
|
|
|
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
|
|
|
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
|
|
|
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
|
|
|
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
|
|
|
return [M,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
|
|
|
@@ -595,17 +649,29 @@ class DihedralAtomEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = uint_crd_f_shape[0]
|
|
|
|
M = self.dihedral_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
|
|
|
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
|
|
|
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
|
|
|
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
|
|
|
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
|
|
|
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
|
|
|
return [N,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
|
|
|
@@ -673,17 +739,29 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = uint_crd_f_shape[0]
|
|
|
|
M = self.dihedral_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_d_shape), 1, Rel.EQ, "atom_d_dim", cls_name)
|
|
|
|
validator.check_int(len(ipn_shape), 1, Rel.EQ, "ipn_dim", cls_name)
|
|
|
|
validator.check_int(len(pk_shape), 1, Rel.EQ, "pk_dim", cls_name)
|
|
|
|
validator.check_int(len(gamc_shape), 1, Rel.EQ, "gamc_dim", cls_name)
|
|
|
|
validator.check_int(len(gams_shape), 1, Rel.EQ, "gams_dim", cls_name)
|
|
|
|
validator.check_int(len(pn_shape), 1, Rel.EQ, "pn_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(atom_d_shape[0], M, Rel.EQ, "atom_d_shape", cls_name)
|
|
|
|
validator.check_int(ipn_shape[0], M, Rel.EQ, "ipn_shape", cls_name)
|
|
|
|
validator.check_int(pk_shape[0], M, Rel.EQ, "pk_shape", cls_name)
|
|
|
|
validator.check_int(gamc_shape[0], M, Rel.EQ, "gamc_shape", cls_name)
|
|
|
|
validator.check_int(gams_shape[0], M, Rel.EQ, "gams_shape", cls_name)
|
|
|
|
validator.check_int(pn_shape[0], M, Rel.EQ, "pn_shape", cls_name)
|
|
|
|
return uint_crd_f_shape, [N,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, atom_d_type,
|
|
|
|
@@ -757,13 +835,21 @@ class AngleForce(PrimitiveWithInfer): |
|
|
|
angle_theta0_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
M = self.angle_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
|
|
|
return uint_crd_f_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
|
|
|
@@ -825,13 +911,21 @@ class AngleEnergy(PrimitiveWithInfer): |
|
|
|
angle_theta0_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
M = self.angle_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
|
|
|
return [M,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
|
|
|
@@ -888,13 +982,21 @@ class AngleAtomEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = uint_crd_f_shape[0]
|
|
|
|
M = self.angle_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
|
|
|
return [N,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
|
|
|
@@ -951,13 +1053,21 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = uint_crd_f_shape[0]
|
|
|
|
M = self.angle_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler_f_shape), 1, Rel.EQ, "scaler_f_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_a_shape), 1, Rel.EQ, "atom_a_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_b_shape), 1, Rel.EQ, "atom_b_dim", cls_name)
|
|
|
|
validator.check_int(len(atom_c_shape), 1, Rel.EQ, "atom_c_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_k_shape), 1, Rel.EQ, "angle_k_dim", cls_name)
|
|
|
|
validator.check_int(len(angle_theta0_shape), 1, Rel.EQ, "angle_theta0_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler_f_shape[0], 3, Rel.EQ, "scaler_f_shape", cls_name)
|
|
|
|
validator.check_int(atom_a_shape[0], M, Rel.EQ, "atom_a_shape", cls_name)
|
|
|
|
validator.check_int(atom_b_shape[0], M, Rel.EQ, "atom_b_shape", cls_name)
|
|
|
|
validator.check_int(atom_c_shape[0], M, Rel.EQ, "atom_c_shape", cls_name)
|
|
|
|
validator.check_int(angle_k_shape[0], M, Rel.EQ, "angle_k_shape", cls_name)
|
|
|
|
validator.check_int(angle_theta0_shape[0], M, Rel.EQ, "angle_theta0_shape", cls_name)
|
|
|
|
return uint_crd_f_shape, [N,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, scaler_f_type, atom_a_type, atom_b_type, atom_c_type, angle_k_type,
|
|
|
|
@@ -1029,16 +1139,21 @@ class Dihedral14LJForce(PrimitiveWithInfer): |
|
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
Q = LJ_type_A_shape[0]
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
|
|
|
return uint_crd_f_shape
|
|
|
|
|
|
|
|
@@ -1112,16 +1227,21 @@ class Dihedral14LJEnergy(PrimitiveWithInfer): |
|
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
Q = LJ_type_A_shape[0]
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
|
|
|
return [self.dihedral_14_numbers,]
|
|
|
|
|
|
|
|
@@ -1201,16 +1321,22 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer): |
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
Q = LJ_type_A_shape[0]
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
|
|
|
|
return [self.atom_numbers, 3]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
|
|
|
@@ -1286,18 +1412,23 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
|
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
Q = LJ_type_A_shape[0]
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
|
|
|
|
return uint_crd_f_shape, charge_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
|
|
|
@@ -1366,17 +1497,22 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer): |
|
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
Q = LJ_type_A_shape[0]
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
|
|
|
|
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
|
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name)
|
|
|
|
return LJtype_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
|
|
|
@@ -1445,15 +1581,19 @@ class Dihedral14CFEnergy(PrimitiveWithInfer): |
|
|
|
cf_scale_factor_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
|
|
|
return [self.dihedral_14_numbers,]
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
|
|
|
@@ -1516,15 +1656,19 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer): |
|
|
|
cf_scale_factor_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
M = self.dihedral_14_numbers
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], 3, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14", cls_name)
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14", cls_name)
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_f_shape), 1, Rel.EQ, "boxlength_f_dim", cls_name)
|
|
|
|
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
|
|
|
|
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
|
|
|
|
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
|
|
|
|
return LJtype_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
|
|
|
|
@@ -1672,10 +1816,14 @@ class PMEReciprocalForce(PrimitiveWithInfer): |
|
|
|
def infer_shape(self, boxlength_shape, uint_crd_shape, charge_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
|
|
|
validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
|
|
|
validator.check_int(len(boxlength_shape), 1, Rel.EQ, "boxlength_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
|
|
|
validator.check_int(boxlength_shape[0], 3, Rel.EQ, "boxlength_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
return uint_crd_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, boxlength_type, uint_crd_type, charge_type):
|
|
|
|
@@ -1693,6 +1841,7 @@ class PMEExcludedForce(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Args:
|
|
|
|
atom_numbers(int32): the number of atoms, N.
|
|
|
|
excluded_numbers(int32): the length of excluded list, E.
|
|
|
|
beta(float32): the PME beta parameter, determined by the
|
|
|
|
non-bond cutoff value and simulation precision tolerance.
|
|
|
|
|
|
|
|
@@ -1716,27 +1865,36 @@ class PMEExcludedForce(PrimitiveWithInfer): |
|
|
|
"""
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
def __init__(self, atom_numbers, beta):
|
|
|
|
def __init__(self, atom_numbers, excluded_numbers, beta):
|
|
|
|
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
|
|
|
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
|
|
|
|
validator.check_value_type('beta', beta, (float), self.name)
|
|
|
|
self.atom_numbers = atom_numbers
|
|
|
|
self.excluded_numbers = excluded_numbers
|
|
|
|
self.beta = beta
|
|
|
|
self.init_prim_io_names(
|
|
|
|
inputs=['uint_crd', 'sacler', 'charge', 'excluded_list_start', 'excluded_list', 'excluded_atom_numbers'],
|
|
|
|
outputs=['force'])
|
|
|
|
self.add_prim_attr('atom_numbers', self.atom_numbers)
|
|
|
|
self.add_prim_attr('excluded_numbers', self.excluded_numbers)
|
|
|
|
self.add_prim_attr('beta', self.beta)
|
|
|
|
|
|
|
|
def infer_shape(self, uint_crd_shape, sacler_shape, charge_shape, excluded_list_start_shape, excluded_list_shape,
|
|
|
|
excluded_atom_numbers_shape):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
|
|
|
validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start", cls_name)
|
|
|
|
validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers", cls_name)
|
|
|
|
validator.check_int(len(uint_crd_shape), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
|
|
|
validator.check_int(len(sacler_shape), 1, Rel.EQ, "sacler_dim", cls_name)
|
|
|
|
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(excluded_list_start_shape), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
|
|
|
|
validator.check_int(len(excluded_atom_numbers_shape), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd_shape[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
|
|
|
validator.check_int(sacler_shape[0], 3, Rel.EQ, "sacler_shape", cls_name)
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(excluded_list_start_shape[0], N, Rel.EQ, "excluded_list_start_shape", cls_name)
|
|
|
|
validator.check_int(excluded_atom_numbers_shape[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
|
|
|
|
return uint_crd_shape
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_type, sacler_type, charge_type, excluded_list_start_type, excluded_list_type,
|
|
|
|
@@ -1763,6 +1921,7 @@ class PMEEnergy(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Args:
|
|
|
|
atom_numbers(int32): the number of atoms, N.
|
|
|
|
excluded_numbers(int32): the length of excluded list, E.
|
|
|
|
beta(float32): the PME beta parameter, determined by the
|
|
|
|
non-bond cutoff value and simulation precision tolerance.
|
|
|
|
fftx(int32): the number of points for Fourier transform in dimension X.
|
|
|
|
@@ -1795,13 +1954,15 @@ class PMEEnergy(PrimitiveWithInfer): |
|
|
|
"""
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
def __init__(self, atom_numbers, beta, fftx, ffty, fftz):
|
|
|
|
def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz):
|
|
|
|
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
|
|
|
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
|
|
|
|
validator.check_value_type('beta', beta, (float), self.name)
|
|
|
|
validator.check_value_type('fftx', fftx, (int), self.name)
|
|
|
|
validator.check_value_type('ffty', ffty, (int), self.name)
|
|
|
|
validator.check_value_type('fftz', fftz, (int), self.name)
|
|
|
|
self.atom_numbers = atom_numbers
|
|
|
|
self.excluded_numbers = excluded_numbers
|
|
|
|
self.beta = beta
|
|
|
|
self.fftx = fftx
|
|
|
|
self.ffty = ffty
|
|
|
|
@@ -1811,6 +1972,7 @@ class PMEEnergy(PrimitiveWithInfer): |
|
|
|
'excluded_list', 'excluded_atom_numbers'],
|
|
|
|
outputs=['reciprocal_ene', 'self_ene', 'direct_ene', 'correction_ene'])
|
|
|
|
self.add_prim_attr('atom_numbers', self.atom_numbers)
|
|
|
|
self.add_prim_attr('excluded_numbers', self.excluded_numbers)
|
|
|
|
self.add_prim_attr('beta', self.beta)
|
|
|
|
self.add_prim_attr('fftx', self.fftx)
|
|
|
|
self.add_prim_attr('ffty', self.ffty)
|
|
|
|
@@ -1820,16 +1982,25 @@ class PMEEnergy(PrimitiveWithInfer): |
|
|
|
excluded_list, excluded_atom_numbers):
|
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
|
|
|
validator.check_int(box_length[0], 3, Rel.EQ, "box_length", cls_name)
|
|
|
|
validator.check_int(charge[0], N, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
|
|
|
validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start", cls_name)
|
|
|
|
validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers", cls_name)
|
|
|
|
validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list", cls_name)
|
|
|
|
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
|
|
|
validator.check_int(len(box_length), 1, Rel.EQ, "sacler_dim", cls_name)
|
|
|
|
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_serial), 2, Rel.LE, "nl_serial_dim", cls_name)
|
|
|
|
validator.check_int(len(excluded_list_start), 1, Rel.EQ, "excluded_list_start_dim", cls_name)
|
|
|
|
validator.check_int(len(excluded_atom_numbers), 1, Rel.EQ, "excluded_atom_numbers_dim", cls_name)
|
|
|
|
validator.check_int(len(excluded_list), 1, Rel.GE, "excluded_list", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
|
|
|
validator.check_int(box_length[0], 3, Rel.EQ, "box_length_shape", cls_name)
|
|
|
|
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.LE, "nl_serial_shape[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
|
|
|
validator.check_int(excluded_list_start[0], N, Rel.EQ, "excluded_list_start_shape", cls_name)
|
|
|
|
validator.check_int(excluded_atom_numbers[0], N, Rel.EQ, "excluded_atom_numbers_shape", cls_name)
|
|
|
|
validator.check_int(excluded_list[0], 0, Rel.GE, "excluded_list_shape", cls_name)
|
|
|
|
return (1,), (1,), (1,), (1,)
|
|
|
|
|
|
|
|
def infer_dtype(self, box_length, uint_crd, charge, nl_numbers, nl_serial, scaler, excluded_list_start,
|
|
|
|
@@ -1906,16 +2077,25 @@ class LJEnergy(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
Q = d_LJ_A[0]
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
|
|
|
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
|
|
|
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name)
|
|
|
|
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
|
|
|
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
|
|
|
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
|
|
|
|
return charge
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
|
|
|
|
@@ -1983,16 +2163,25 @@ class LJForce(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
Q = d_LJ_A[0]
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
|
|
|
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
|
|
|
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name)
|
|
|
|
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
|
|
|
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
|
|
|
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
|
|
|
|
return uint_crd
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
|
|
|
|
@@ -2058,16 +2247,25 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer): |
|
|
|
cls_name = self.name
|
|
|
|
N = self.atom_numbers
|
|
|
|
Q = d_LJ_A[0]
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd[1]", cls_name)
|
|
|
|
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype", cls_name)
|
|
|
|
validator.check_int(charge[0], 3, Rel.EQ, "charge", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial[1]", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler", cls_name)
|
|
|
|
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B[0]", cls_name)
|
|
|
|
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
|
|
|
|
validator.check_int(len(LJtype), 1, Rel.EQ, "LJtype_dim", cls_name)
|
|
|
|
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
|
|
|
|
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
|
|
|
|
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
|
|
|
|
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
|
|
|
|
|
|
|
|
validator.check_int(uint_crd[0], N, Rel.EQ, "uint_crd_shape[0]", cls_name)
|
|
|
|
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
|
|
|
|
validator.check_int(LJtype[0], N, Rel.EQ, "LJtype_shape", cls_name)
|
|
|
|
validator.check_int(charge[0], N, Rel.EQ, "charge_shape", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
|
|
|
validator.check_int(nl_numbers[0], N, Rel.EQ, "nl_numbers_shape", cls_name)
|
|
|
|
validator.check_int(nl_serial[0], N, Rel.EQ, "nl_serial_shape[0]", cls_name)
|
|
|
|
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
|
|
|
|
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
|
|
|
|
validator.check_int(d_LJ_B[0], Q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
|
|
|
|
return uint_crd
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd, LJtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
|
|
|
|
|