|
|
@@ -1141,6 +1141,7 @@ class Dihedral14LJForce(PrimitiveWithInfer): |
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): |
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): |
|
|
cls_name = self.name |
|
|
cls_name = self.name |
|
|
N = self.atom_numbers |
|
|
N = self.atom_numbers |
|
|
|
|
|
M = self.dihedral_14_numbers |
|
|
Q = LJ_type_A_shape[0] |
|
|
Q = LJ_type_A_shape[0] |
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) |
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) |
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) |
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) |
|
|
@@ -1157,6 +1158,9 @@ class Dihedral14LJForce(PrimitiveWithInfer): |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name) |
|
|
return uint_crd_f_shape |
|
|
return uint_crd_f_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -1229,6 +1233,7 @@ class Dihedral14LJEnergy(PrimitiveWithInfer): |
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): |
|
|
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): |
|
|
cls_name = self.name |
|
|
cls_name = self.name |
|
|
N = self.atom_numbers |
|
|
N = self.atom_numbers |
|
|
|
|
|
M = self.dihedral_14_numbers |
|
|
Q = LJ_type_A_shape[0] |
|
|
Q = LJ_type_A_shape[0] |
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) |
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) |
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) |
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) |
|
|
@@ -1245,6 +1250,9 @@ class Dihedral14LJEnergy(PrimitiveWithInfer): |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B", cls_name) |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name) |
|
|
return [self.dihedral_14_numbers,] |
|
|
return [self.dihedral_14_numbers,] |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -1336,9 +1344,13 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer): |
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) |
|
|
validator.check_int(uint_crd_f_shape[0], N, Rel.EQ, "uint_crd_f_shape[0]", cls_name) |
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) |
|
|
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name) |
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) |
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], M, Rel.EQ, "charge_shape", cls_name) |
|
|
|
|
|
|
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name) |
|
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name) |
|
|
return [self.atom_numbers, 3] |
|
|
return [self.atom_numbers, 3] |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -1414,6 +1426,7 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): |
|
|
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape): |
|
|
cls_name = self.name |
|
|
cls_name = self.name |
|
|
N = self.atom_numbers |
|
|
N = self.atom_numbers |
|
|
|
|
|
M = self.dihedral_14_numbers |
|
|
Q = LJ_type_A_shape[0] |
|
|
Q = LJ_type_A_shape[0] |
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) |
|
|
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name) |
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) |
|
|
validator.check_int(len(LJtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name) |
|
|
@@ -1431,6 +1444,10 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer): |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name) |
|
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name) |
|
|
return uint_crd_f_shape, charge_shape |
|
|
return uint_crd_f_shape, charge_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -1515,6 +1532,10 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer): |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) |
|
|
validator.check_int(LJ_type_B_shape[0], Q, Rel.EQ, "LJ_type_B_shape", cls_name) |
|
|
|
|
|
M = self.dihedral_14_numbers |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(lj_scale_factor_shape[0], M, Rel.EQ, "lj_scale_factor_shape", cls_name) |
|
|
return LJtype_shape |
|
|
return LJtype_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -1596,6 +1617,10 @@ class Dihedral14CFEnergy(PrimitiveWithInfer): |
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) |
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
|
|
|
M = self.dihedral_14_numbers |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name) |
|
|
return [self.dihedral_14_numbers,] |
|
|
return [self.dihedral_14_numbers,] |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -1671,6 +1696,10 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer): |
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) |
|
|
validator.check_int(LJtype_shape[0], N, Rel.EQ, "LJtype_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(charge_shape[0], N, Rel.EQ, "charge_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name) |
|
|
|
|
|
M = self.dihedral_14_numbers |
|
|
|
|
|
validator.check_int(a_14_shape[0], M, Rel.EQ, "a_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(b_14_shape[0], M, Rel.EQ, "b_14_shape", cls_name) |
|
|
|
|
|
validator.check_int(cf_scale_factor_shape[0], M, Rel.EQ, "cf_scale_factor_shape", cls_name) |
|
|
return LJtype_shape |
|
|
return LJtype_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
def infer_dtype(self, uint_crd_f_dtype, LJtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type, |
|
|
@@ -2674,14 +2703,15 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer): |
|
|
scheme for efficient configurational sampling for classical/quantum canonical |
|
|
scheme for efficient configurational sampling for classical/quantum canonical |
|
|
ensembles via molecular dynamics. DOI: 10.1063/1.4991621. |
|
|
ensembles via molecular dynamics. DOI: 10.1063/1.4991621. |
|
|
|
|
|
|
|
|
Inputs: |
|
|
|
|
|
- **atom_numbers** (int32) - the number of atoms N. |
|
|
|
|
|
- **dt** (float32) - time step for finite difference. |
|
|
|
|
|
- **half_dt** (float32) - half of time step for finite difference. |
|
|
|
|
|
- **exp_gamma** (float32) - parameter in Liu's dynamic, equals |
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
atom_numbers(int32): the number of atoms N. |
|
|
|
|
|
dt(float32): time step for finite difference. |
|
|
|
|
|
half_dt(float32): half of time step for finite difference. |
|
|
|
|
|
exp_gamma(float32): parameter in Liu's dynamic, equals |
|
|
exp(-gamma_ln * dt), where gamma_ln is the friction factor in Langevin


exp(-gamma_ln * dt), where gamma_ln is the friction factor in Langevin
|
|
dynamics. |
|
|
dynamics. |
|
|
|
|
|
|
|
|
|
|
|
Inputs: |
|
|
- **inverse_mass** (Tensor, float32) - [N,], the inverse value of |
|
|
- **inverse_mass** (Tensor, float32) - [N,], the inverse value of |
|
|
mass of each atom. |
|
|
mass of each atom. |
|
|
- **sqrt_mass_inverse** (Tensor, float32) - [N,], the inverse square root value |
|
|
- **sqrt_mass_inverse** (Tensor, float32) - [N,], the inverse square root value |
|
|
@@ -2699,7 +2729,6 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
Supported Platforms: |
|
|
Supported Platforms: |
|
|
``GPU`` |
|
|
``GPU`` |
|
|
Examples: |
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
@prim_attr_register |
|
|
@@ -2735,14 +2764,14 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer): |
|
|
validator.check_tensor_dtype_valid('rand_frc', rand_frc, [mstype.float32], self.name) |
|
|
validator.check_tensor_dtype_valid('rand_frc', rand_frc, [mstype.float32], self.name) |
|
|
return mstype.float32 |
|
|
return mstype.float32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrdToUintCrd(PrimitiveWithInfer): |
|
|
class CrdToUintCrd(PrimitiveWithInfer): |
|
|
""" |
|
|
""" |
|
|
Convert FP32 coordinate to Uint32 coordinate. |
|
|
Convert FP32 coordinate to Uint32 coordinate. |
|
|
|
|
|
|
|
|
Inputs: |
|
|
|
|
|
- **atom_numbers** (int32) - the number of atoms N. |
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
atom_numbers(int32): the number of atoms N. |
|
|
|
|
|
|
|
|
|
|
|
Inputs: |
|
|
- **crd_to_uint_crd_cof** (Tensor, float32) - [3,], the conversion coefficient from float coordinate to unsigned int coordinate.


- **crd_to_uint_crd_cof** (Tensor, float32) - [3,], the conversion coefficient from float coordinate to unsigned int coordinate.
|
|
- **crd** (Tensor, float32) - [N, 3], the coordinate of each atom. |
|
|
- **crd** (Tensor, float32) - [N, 3], the coordinate of each atom. |
|
|
|
|
|
|
|
|
@@ -2751,7 +2780,6 @@ class CrdToUintCrd(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
Supported Platforms: |
|
|
Supported Platforms: |
|
|
``GPU`` |
|
|
``GPU`` |
|
|
Examples: |
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
@prim_attr_register |
|
|
@@ -2763,7 +2791,7 @@ class CrdToUintCrd(PrimitiveWithInfer): |
|
|
outputs=['output']) |
|
|
outputs=['output']) |
|
|
|
|
|
|
|
|
def infer_shape(self, crd_to_uint_crd_cof, crd): |
|
|
def infer_shape(self, crd_to_uint_crd_cof, crd): |
|
|
validator.check_int(crd_to_uint_crd_cof[0], 3, Rel.EQ, "crd_to_uint_crd_cof", self.name) |
|
|
|
|
|
|
|
|
validator.check_int(crd_to_uint_crd_cof[0], 3, Rel.EQ, "crd_to_uint_crd_cof_shape", self.name) |
|
|
validator.check_int(crd[0], self.atom_numbers, Rel.EQ, "crd[0]", self.name) |
|
|
validator.check_int(crd[0], self.atom_numbers, Rel.EQ, "crd[0]", self.name) |
|
|
validator.check_int(crd[1], 3, Rel.EQ, "crd[1]", self.name) |
|
|
validator.check_int(crd[1], 3, Rel.EQ, "crd[1]", self.name) |
|
|
return crd |
|
|
return crd |
|
|
@@ -2773,21 +2801,19 @@ class CrdToUintCrd(PrimitiveWithInfer): |
|
|
validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name) |
|
|
validator.check_tensor_dtype_valid('crd', crd, [mstype.float32], self.name) |
|
|
return mstype.uint32 |
|
|
return mstype.uint32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MDIterationSetupRandState(PrimitiveWithInfer): |
|
|
class MDIterationSetupRandState(PrimitiveWithInfer): |
|
|
""" |
|
|
"""


Set up the random state for the MD simulation.


Set up the random state for the MD simulation.
|
|
|
|
|
|
|
|
Inputs: |
|
|
|
|
|
- **atom_numbers** (int32) - the number of atoms N. |
|
|
|
|
|
- **seed** (int32) - random seed. |
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
atom_numbers(int32): the number of atoms N. |
|
|
|
|
|
seed(int32): random seed. |
|
|
|
|
|
|
|
|
Outputs: |
|
|
Outputs: |
|
|
- **output** (uint32) random state. |
|
|
- **output** (uint32) random state. |
|
|
|
|
|
|
|
|
Supported Platforms: |
|
|
Supported Platforms: |
|
|
``GPU`` |
|
|
``GPU`` |
|
|
Examples: |
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
@prim_attr_register |
|
|
@@ -2808,3 +2834,45 @@ class MDIterationSetupRandState(PrimitiveWithInfer): |
|
|
|
|
|
|
|
|
def infer_dtype(self): |
|
|
def infer_dtype(self): |
|
|
return mstype.float32 |
|
|
return mstype.float32 |
|
|
|
|
|
|
|
|
|
|
|
class TransferCrd(PrimitiveWithInfer):
    """
    Transfer the coordinates to angular and radial.

    Args:
        start_serial(int32): the index start position.
        end_serial(int32): the index end position.
        number(int32): the length of angular and radial.

    Inputs:
        - **crd** (Tensor, float32) - [N, 3], the coordinate of each atom.
          N is the number of atoms.

    Outputs:
        - **radial** (Tensor, float32) - [number,], the radial component.
        - **angular** (Tensor, float32) - [number,], the angular component.

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self, start_serial, end_serial, number):
        # Keep the index window and output length on the instance, and
        # register them as primitive attributes so the backend kernel
        # can read them at launch time.
        self.start_serial = start_serial
        self.end_serial = end_serial
        self.number = number
        self.add_prim_attr('start_serial', self.start_serial)
        self.add_prim_attr('end_serial', self.end_serial)
        self.add_prim_attr('number', self.number)
        self.init_prim_io_names(
            inputs=['crd'],
            outputs=['radial', 'angular'])

    def infer_shape(self, crd_shape):
        # crd must be a rank-2 [N, 3] coordinate tensor.
        validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name)
        # Fixed label: this check validates crd_shape[1], not crd_shape[0].
        validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name)
        # Both outputs are flat vectors of length `number`.
        return [self.number,], [self.number,]

    def infer_dtype(self, crd_dtype):
        validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
        return mstype.float32, mstype.float32