Browse Source

05311 sponge_ops update_thor_gradient thor_ops.py im2col

tags/v1.3.0
zhangxinfeng3 4 years ago
parent
commit
35e95b7d19
4 changed files with 89 additions and 88 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h
  2. +25
    -25
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
  3. +1
    -2
      mindspore/ops/operations/_thor_ops.py
  4. +60
    -60
      mindspore/ops/operations/sponge_ops.py

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h View File

@@ -187,7 +187,9 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
auto matrix_g_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);

split_dim = LongToSize(GetAttr<int64_t>(kernel_node, "split_dim"));

if (split_dim == 0) {
MS_LOG(ERROR) << "Divide by zero, split_dim can not be zero.";
}
gradient_size.batch_h = gradient_shape[0] / split_dim;
gradient_size.batch_w = gradient_shape[1] / split_dim;
if (gradient_size.batch_h * split_dim != gradient_shape[0]) {


+ 25
- 25
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h View File

@@ -35,13 +35,13 @@ class Im2ColGpuFwdKernel : public GpuKernel {
input_desc_(nullptr),
output_desc_(nullptr),
filter_desc_(nullptr),
conv_desc_(nullptr),
padded_desc_(nullptr),
conv_desc_n(nullptr),
padded_desc_n(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
old_height_(0),
old_width_(0),
pad_height_(0),
pad_width_(0),
pad_width_n(0),
pad_top_(0),
pad_left_(0),
n_(0),
@@ -67,14 +67,14 @@ class Im2ColGpuFwdKernel : public GpuKernel {
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
T *padded_addr = GetDeviceAddress<T>(workspace, 0);
CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
old_width_ + pad_width_n, pad_top_, pad_left_, pad_value_, padded_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnIm2Col(cudnn_handle_, padded_desc_, padded_addr, filter_desc_, conv_desc_, output_addr),
kernel_node_, cudnnIm2Col(cudnn_handle_, padded_desc_n, padded_addr, filter_desc_, conv_desc_n, output_addr),
"cudnnIm2ColForward failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnIm2Col(cudnn_handle_, input_desc_, input_addr, filter_desc_, conv_desc_, output_addr),
kernel_node_, cudnnIm2Col(cudnn_handle_, input_desc_, input_addr, filter_desc_, conv_desc_n, output_addr),
"cudnnIm2ColForward failed");
}

@@ -100,10 +100,10 @@ class Im2ColGpuFwdKernel : public GpuKernel {
}
CheckTensorSize({in_shape, output_shape});
Set4DDesc(in_shape, filter_shape, output_shape);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, 1),
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_n, 1),
"cudnnSetConvGroupCount failed");
pad_height_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "pad"));
pad_width_ = pad_height_;
pad_width_n = pad_height_;
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
SetStrideAndDilation(kernel_node);
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
@@ -111,16 +111,16 @@ class Im2ColGpuFwdKernel : public GpuKernel {
} else {
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
pad_height_ = 0;
pad_width_ = 0;
pad_width_n = 0;
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2],
cudnnSetConvolution2dDescriptor(conv_desc_n, pad_height_, pad_width_n, stride_[2], stride_[3], dilation_[2],
dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed");
}
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_n, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
InitSizeLists();
@@ -128,11 +128,11 @@ class Im2ColGpuFwdKernel : public GpuKernel {
}

void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_),
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_n),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(filter_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_desc_),
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_desc_n),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(output_desc_),
"cudnnDestroyTensorDescriptor failed");
@@ -147,11 +147,11 @@ class Im2ColGpuFwdKernel : public GpuKernel {
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&output_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_desc_),
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_desc_n),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&filter_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_),
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_n),
"cudnnCreateConvolutionDescriptor failed");
}

@@ -164,7 +164,7 @@ class Im2ColGpuFwdKernel : public GpuKernel {
cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast<size_t *>(&output_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast<size_t *>(&padded_size_)),
cudnnGetTensorSizeInBytes(padded_desc_n, reinterpret_cast<size_t *>(&padded_size_)),
"cudnnGetTensorSizeInBytes failed");
}
input_size_list_.push_back(input_size_);
@@ -202,23 +202,23 @@ class Im2ColGpuFwdKernel : public GpuKernel {
old_height_ = SizeToInt(in_shape[2]);
old_width_ = SizeToInt(in_shape[3]);
pad_height_ = pad_list[0] + pad_list[1];
pad_width_ = pad_list[2] + pad_list[3];
pad_width_n = pad_list[2] + pad_list[3];
pad_top_ = pad_list[0];
pad_left_ = pad_list[2];

// if use_pad_ == true, using zero padding in advance, else using the default cudnn pad.
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
if (pad_height_ % 2 == 0 && pad_width_n % 2 == 0) {
use_pad_ = false;
}

CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_,
old_height_ + pad_height_, old_width_ + pad_width_),
cudnnSetTensor4dDescriptor(padded_desc_n, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_,
old_height_ + pad_height_, old_width_ + pad_width_n),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolution2dDescriptor(
conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3],
dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
conv_desc_n, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2],
stride_[3], dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed");
}

@@ -269,8 +269,8 @@ class Im2ColGpuFwdKernel : public GpuKernel {
cudnnTensorDescriptor_t output_desc_;
cudnnFilterDescriptor_t filter_desc_;
cudnnConvolutionFwdAlgo_t conv_algorithm_;
cudnnConvolutionDescriptor_t conv_desc_;
cudnnTensorDescriptor_t padded_desc_;
cudnnConvolutionDescriptor_t conv_desc_n;
cudnnTensorDescriptor_t padded_desc_n;
std::string pad_mode_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
@@ -280,7 +280,7 @@ class Im2ColGpuFwdKernel : public GpuKernel {
int old_height_;
int old_width_;
int pad_height_;
int pad_width_;
int pad_width_n;
int pad_top_;
int pad_left_;
int n_;


+ 1
- 2
mindspore/ops/operations/_thor_ops.py View File

@@ -19,7 +19,6 @@ from ..primitive import prim_attr_register, PrimitiveWithInfer
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
from ..operations.nn_ops import _check_positive_int_or_tuple
from ..._checkparam import Rel

__all__ = ["CusBatchMatMul",
"CusCholeskyTrsm",
@@ -560,7 +559,7 @@ class UpdateThorGradient(PrimitiveWithInfer):
"""

@prim_attr_register
def __init__(self, split_dim=0):
def __init__(self, split_dim=1):
"""Initialize UpdateThorGradient"""
self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
self.split_dim = split_dim


+ 60
- 60
mindspore/ops/operations/sponge_ops.py View File

@@ -1138,11 +1138,11 @@ class Dihedral14LJForce(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)

def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
cls_name = self.name
n = self.atom_numbers
m = self.dihedral_14_numbers
q = LJ_type_A_shape[0]
q = lj_type_a_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1150,21 +1150,21 @@ class Dihedral14LJForce(PrimitiveWithInfer):
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
return uint_crd_f_shape

def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
lj_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1174,9 +1174,9 @@ class Dihedral14LJForce(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)

validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
return LJ_type_B_type
validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)
return lj_type_b_type


class Dihedral14LJEnergy(PrimitiveWithInfer):
@@ -1230,11 +1230,11 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)

def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
cls_name = self.name
n = self.atom_numbers
m = self.dihedral_14_numbers
q = LJ_type_A_shape[0]
q = lj_type_a_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1242,21 +1242,21 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f[1]", cls_name)
validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype", cls_name)
validator.check_int(charge_shape[0], n, Rel.EQ, "charge", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f", cls_name)
validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B", cls_name)
validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
return [self.dihedral_14_numbers,]

def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
lj_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1264,10 +1264,10 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('a_14', a_14_type, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

return LJ_type_A_type
return lj_type_a_type


class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
@@ -1326,11 +1326,11 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)

def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
cls_name = self.name
n = self.atom_numbers
m = self.dihedral_14_numbers
q = LJ_type_A_shape[0]
q = lj_type_a_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1339,14 +1339,14 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
@@ -1354,7 +1354,7 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
return [self.atom_numbers, 3]

def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
lj_scale_factor_type, cf_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1363,10 +1363,10 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

return LJ_type_A_type
return lj_type_a_type


class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
@@ -1423,11 +1423,11 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)

def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
lj_scale_factor_shape, cf_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
lj_scale_factor_shape, cf_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
cls_name = self.name
n = self.atom_numbers
m = self.dihedral_14_numbers
q = LJ_type_A_shape[0]
q = lj_type_a_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1436,14 +1436,14 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(cf_scale_factor_shape), 1, Rel.EQ, "cf_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
validator.check_int(lj_scale_factor_shape[0], m, Rel.EQ, "lj_scale_factor_shape", cls_name)
@@ -1451,7 +1451,7 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
return uint_crd_f_shape, charge_shape

def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
lj_scale_factor_type, cf_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
lj_scale_factor_type, cf_scale_factor_type, lj_type_a_type, lj_type_b_type):
validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1460,8 +1460,8 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('cf_scale_factor', cf_scale_factor_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

return charge_dtype, charge_dtype

@@ -1513,10 +1513,10 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)

def infer_shape(self, uint_crd_f_shape, ljtype_shape, charge_shape, boxlength_f_shape, a_14_shape, b_14_shape,
lj_scale_factor_shape, LJ_type_A_shape, LJ_type_B_shape):
lj_scale_factor_shape, lj_type_a_shape, lj_type_b_shape):
cls_name = self.name
n = self.atom_numbers
q = LJ_type_A_shape[0]
q = lj_type_a_shape[0]
validator.check_int(len(uint_crd_f_shape), 2, Rel.EQ, "uint_crd_f_dim", cls_name)
validator.check_int(len(ljtype_shape), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge_shape), 1, Rel.EQ, "charge_dim", cls_name)
@@ -1524,14 +1524,14 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
validator.check_int(len(a_14_shape), 1, Rel.EQ, "a_14_dim", cls_name)
validator.check_int(len(b_14_shape), 1, Rel.EQ, "b_14_dim", cls_name)
validator.check_int(len(lj_scale_factor_shape), 1, Rel.EQ, "lj_scale_factor_dim", cls_name)
validator.check_int(len(LJ_type_B_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)
validator.check_int(len(lj_type_b_shape), 1, Rel.EQ, "LJ_type_B_dim", cls_name)

validator.check_int(uint_crd_f_shape[0], n, Rel.EQ, "uint_crd_f_shape[0]", cls_name)
validator.check_int(uint_crd_f_shape[1], 3, Rel.EQ, "uint_crd_f_shape[1]", cls_name)
validator.check_int(ljtype_shape[0], n, Rel.EQ, "LJtype_shape", cls_name)
validator.check_int(charge_shape[0], n, Rel.EQ, "charge_shape", cls_name)
validator.check_int(boxlength_f_shape[0], 3, Rel.EQ, "boxlength_f_shape", cls_name)
validator.check_int(LJ_type_B_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
validator.check_int(lj_type_b_shape[0], q, Rel.EQ, "LJ_type_B_shape", cls_name)
m = self.dihedral_14_numbers
validator.check_int(a_14_shape[0], m, Rel.EQ, "a_14_shape", cls_name)
validator.check_int(b_14_shape[0], m, Rel.EQ, "b_14_shape", cls_name)
@@ -1539,7 +1539,7 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
return ljtype_shape

def infer_dtype(self, uint_crd_f_dtype, ljtype_dtype, charge_dtype, boxlength_f_type, a_14_type, b_14_type,
lj_scale_factor_type, LJ_type_A_type, LJ_type_B_type):
lj_scale_factor_type, lj_type_a_type, lj_type_b_type):
validator.check_tensor_dtype_valid('uint_crd_f', uint_crd_f_dtype, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge_dtype, [mstype.float32], self.name)
@@ -1548,10 +1548,10 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('b_14', b_14_type, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('lj_scale_factor', lj_scale_factor_type, [mstype.float32],
self.name)
validator.check_tensor_dtype_valid('LJ_type_A', LJ_type_A_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', LJ_type_B_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_A', lj_type_a_type, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('LJ_type_B', lj_type_b_type, [mstype.float32], self.name)

return LJ_type_A_type
return lj_type_a_type


class Dihedral14CFEnergy(PrimitiveWithInfer):
@@ -2123,10 +2123,10 @@ class LJEnergy(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)
self.add_prim_attr('cutoff_square', self.cutoff_square)

def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
cls_name = self.name
n = self.atom_numbers
q = d_LJ_A[0]
q = d_lj_a[0]
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
@@ -2134,7 +2134,7 @@ class LJEnergy(PrimitiveWithInfer):
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
@@ -2145,18 +2145,18 @@ class LJEnergy(PrimitiveWithInfer):
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_LJ_B[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return charge

def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('d_LJ_A', d_LJ_A, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_B', d_LJ_B, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
return charge


@@ -2209,10 +2209,10 @@ class LJForce(PrimitiveWithInfer):
self.add_prim_attr('atom_numbers', self.atom_numbers)
self.add_prim_attr('cutoff_square', self.cutoff_square)

def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
cls_name = self.name
n = self.atom_numbers
q = d_LJ_A[0]
q = d_lj_a[0]
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
@@ -2220,7 +2220,7 @@ class LJForce(PrimitiveWithInfer):
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
@@ -2231,18 +2231,18 @@ class LJForce(PrimitiveWithInfer):
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_LJ_B[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return uint_crd

def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('d_LJ_A', d_LJ_A, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_B', d_LJ_B, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
return charge


@@ -2293,10 +2293,10 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
self.add_prim_attr('cutoff', self.cutoff)
self.add_prim_attr('pme_beta', self.pme_beta)

def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
def infer_shape(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
cls_name = self.name
n = self.atom_numbers
q = d_LJ_A[0]
q = d_lj_a[0]
validator.check_int(len(uint_crd), 2, Rel.EQ, "uint_crd_dim", cls_name)
validator.check_int(len(ljtype), 1, Rel.EQ, "LJtype_dim", cls_name)
validator.check_int(len(charge), 1, Rel.EQ, "charge_dim", cls_name)
@@ -2304,7 +2304,7 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
validator.check_int(len(nl_numbers), 1, Rel.EQ, "nl_numbers_dim", cls_name)
validator.check_int(len(nl_serial), 2, Rel.EQ, "nl_serial_dim", cls_name)
validator.check_int(len(scaler), 1, Rel.EQ, "scaler_dim", cls_name)
validator.check_int(len(d_LJ_B), 1, Rel.EQ, "d_LJ_B_dim", cls_name)
validator.check_int(len(d_lj_b), 1, Rel.EQ, "d_LJ_B_dim", cls_name)

validator.check_int(uint_crd[0], n, Rel.EQ, "uint_crd_shape[0]", cls_name)
validator.check_int(uint_crd[1], 3, Rel.EQ, "uint_crd_shape[1]", cls_name)
@@ -2315,18 +2315,18 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
validator.check_int(nl_serial[0], n, Rel.EQ, "nl_serial_shape[0]", cls_name)
validator.check_int(nl_serial[1], 800, Rel.LE, "nl_serial_shape[1]", cls_name)
validator.check_int(scaler[0], 3, Rel.EQ, "scaler_shape", cls_name)
validator.check_int(d_LJ_B[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
validator.check_int(d_lj_b[0], q, Rel.EQ, "d_LJ_B_shape[0]", cls_name)
return uint_crd

def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_LJ_A, d_LJ_B):
def infer_dtype(self, uint_crd, ljtype, charge, scaler, nl_numbers, nl_serial, d_lj_a, d_lj_b):
validator.check_tensor_dtype_valid('uint_crd', uint_crd, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('LJtype', ljtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('charge', charge, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('scaler', scaler, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('nl_numbers', nl_numbers, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('nl_serial', nl_serial, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('d_LJ_A', d_LJ_A, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_B', d_LJ_B, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_A', d_lj_a, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('d_LJ_B', d_lj_b, [mstype.float32], self.name)
return charge




Loading…
Cancel
Save