@@ -1138,11 +1138,11 @@ class Dihedral14LJForce(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
 
     def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
-                    lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
+                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
         cls_name = self.name
         n = self.atom_numbers
         m = self.dihedral_14_numbers
-        q = LJ_type_A_shape[0]
+        q = lj_type_a_shape[0]
         validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
         validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1150,21 +1150,21 @@ class Dihedral14LJForce(PrimitiveWithInfer):
         validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
         validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
         validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
-        validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
+        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
 
         validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
         validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
         validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
         validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
         validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
-        validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
+        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
         validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
         validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
         validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
         return uint_crd_f_shape
 
     def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
-                    lj_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
+                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
         validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1174,9 +1174,9 @@ class Dihedral14LJForce(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
 
         validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
-        return LJ_type_B_type
+        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
+        return lj_type_b_type
 
 
 class Dihedral14LJEnergy(PrimitiveWithInfer):
@@ -1230,11 +1230,11 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
 
     def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
-                    lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
+                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
         cls_name = self.name
         n = self.atom_numbers
         m = self.dihedral_14_numbers
-        q = LJ_type_A_shape[0]
+        q = lj_type_a_shape[0]
         validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
         validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1242,21 +1242,21 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
         validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
         validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
         validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
-        validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
+        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
 
         validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
         validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
         validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
         validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
         validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
-        validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
+        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
         validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
         validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
         validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
         return [self.dihedral_14_numbers,]
 
     def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
-                    lj_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
+                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
         validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1264,10 +1264,10 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
 
-        return LJ_type_A_type
+        return lj_type_a_type
 
 
 class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
@@ -1326,11 +1326,11 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
 
     def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
-                    lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
+                    lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
         cls_name = self.name
         n = self.atom_numbers
         m = self.dihedral_14_numbers
-        q = LJ_type_A_shape[0]
+        q = lj_type_a_shape[0]
         validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
         validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1339,14 +1339,14 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
         validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
         validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
         validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
-        validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
+        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
 
         validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
         validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
         validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
         validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
         validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
-        validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
+        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
         validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
         validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
         validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
@@ -1354,7 +1354,7 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
         return [self.atom_numbers, 3]
 
     def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
-                    lj_scale_factor_type, cf_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
+                    lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
         validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1363,10 +1363,10 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
 
-        return LJ_type_A_type
+        return lj_type_a_type
 
 
 class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
@@ -1423,11 +1423,11 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
 
     def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
-                    lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
+                    lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
         cls_name = self.name
         n = self.atom_numbers
         m = self.dihedral_14_numbers
-        q = LJ_type_A_shape[0]
+        q = lj_type_a_shape[0]
         validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
         validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1436,14 +1436,14 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
         validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
         validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
         validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
-        validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
+        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
 
         validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
         validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
         validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
         validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
         validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
-        validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
+        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
         validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
         validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
         validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
@@ -1451,7 +1451,7 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
         return uint_crd_f_shape, charge_shape
 
     def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
-                    lj_scale_factor_type, cf_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
+                    lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
         validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1460,8 +1460,8 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
 
         return charge_dtype, charge_dtype
 
@@ -1513,10 +1513,10 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
 
     def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
-                    lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
+                    lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
         cls_name = self.name
         n = self.atom_numbers
-        q = LJ_type_A_shape[0]
+        q = lj_type_a_shape[0]
         validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
         validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1524,14 +1524,14 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
         validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
         validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
         validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
-        validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
+        validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
 
         validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
         validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
         validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
         validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
         validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
-        validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
+        validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
         m = self.dihedral_14_numbers
         validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
         validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
@@ -1539,7 +1539,7 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
         return ljtype_shape
 
     def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
-                    lj_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
+                    lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
         validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1548,10 +1548,10 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32],
                                            self.name)
-        validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
 
-        return LJ_type_A_type
+        return lj_type_a_type
 
 
 class Dihedral14CFEnergy(PrimitiveWithInfer):
@@ -2123,10 +2123,10 @@ class LJEnergy(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
         self.add_prim_attr('cutoff_square', self.cutoff_square)
 
-    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
         cls_name = self.name
         n = self.atom_numbers
-        q = d_LJ_A[0]
+        q = d_lj_a[0]
         validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
         validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
@@ -2134,7 +2134,7 @@ class LJEnergy(PrimitiveWithInfer):
         validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
         validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
         validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
-        validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
+        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
 
         validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
         validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
@@ -2145,18 +2145,18 @@ class LJEnergy(PrimitiveWithInfer):
         validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
         validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
         validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
-        validator.check_int(d_LJ_B[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
+        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
         return charge
 
-    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
         validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
-        validator.check_tensor_dtype_valid('d_LJ_A', d_LJ_A, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('d_LJ_B', d_LJ_B, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
         return charge
 
 
@@ -2209,10 +2209,10 @@ class LJForce(PrimitiveWithInfer):
         self.add_prim_attr('atom_numbers', self.atom_numbers)
         self.add_prim_attr('cutoff_square', self.cutoff_square)
 
-    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
         cls_name = self.name
         n = self.atom_numbers
-        q = d_LJ_A[0]
+        q = d_lj_a[0]
         validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
         validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
@@ -2220,7 +2220,7 @@ class LJForce(PrimitiveWithInfer):
         validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
         validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
         validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
-        validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
+        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
 
         validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
         validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
@@ -2231,18 +2231,18 @@ class LJForce(PrimitiveWithInfer):
         validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
         validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
         validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
-        validator.check_int(d_LJ_B[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
+        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
         return uint_crd
 
-    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
        validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
-        validator.check_tensor_dtype_valid('d_LJ_A', d_LJ_A, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('d_LJ_B', d_LJ_B, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
         return charge
 
 
@@ -2293,10 +2293,10 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
         self.add_prim_attr('cutoff', self.cutoff)
         self.add_prim_attr('pme_beta', self.pme_beta)
 
-    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+    def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
         cls_name = self.name
         n = self.atom_numbers
-        q = d_LJ_A[0]
+        q = d_lj_a[0]
         validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
         validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
         validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
@@ -2304,7 +2304,7 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
         validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
         validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
         validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
-        validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
+        validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
 
         validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
         validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
@@ -2315,18 +2315,18 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
         validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
         validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
         validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
-        validator.check_int(d_LJ_B[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
+        validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
         return uint_crd
 
-    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
+    def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
         validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
         validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
         validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
         validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
-        validator.check_tensor_dtype_valid('d_LJ_A', d_LJ_A, [mstype.float32], self.name)
-        validator.check_tensor_dtype_valid('d_LJ_B', d_LJ_B, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
         return charge
 
 