Browse Source

sponge update neighbourlistupdate

tags/v1.5.0-rc1
q00596439 4 years ago
parent
commit
782b373f61
7 changed files with 328 additions and 171 deletions
  1. +109
    -132
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cu
  2. +11
    -13
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cuh
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_kernel.cc
  4. +16
    -21
      mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_kernel.h
  5. +3
    -2
      mindspore/ops/operations/__init__.py
  6. +186
    -0
      mindspore/ops/operations/sponge_ops.py
  7. +1
    -1
      mindspore/ops/operations/sponge_update_ops.py

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_new_impl.cu → mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cu View File

@@ -13,31 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_new_impl.cuh"
#include <stdio.h>
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cuh"
#include <vector>

// Unary minus for VECTOR: component-wise negation.
__device__ __host__ VECTOR operator-(const VECTOR &vecb) {
  VECTOR negated;
  negated.x = -vecb.x;
  negated.y = -vecb.y;
  negated.z = -vecb.z;
  return negated;
}

// Minimum-image displacement vec_a - vec_b under periodic boundary
// conditions: each component is shifted by an integer multiple of the box
// length so the result lies within half a box length of zero.
__device__ __host__ VECTOR Get_Periodic_Displacement(const VECTOR vec_a, const VECTOR vec_b, const VECTOR box_length) {
  VECTOR dr;
  dr.x = vec_a.x - vec_b.x;
  dr.y = vec_a.y - vec_b.y;
  dr.z = vec_a.z - vec_b.z;  // bug fix: was assigned to dr.x, leaving dr.z uninitialized

  // 0.5f (not 0.5) keeps the arithmetic in single precision on the device.
  dr.x = dr.x - floorf(dr.x / box_length.x + 0.5f) * box_length.x;
  dr.y = dr.y - floorf(dr.y / box_length.y + 0.5f) * box_length.y;
  dr.z = dr.z - floorf(dr.z / box_length.z + 0.5f) * box_length.z;
  return dr;
}

__global__ void Copy_List(const int element_numbers, const int *origin_list, int *list) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
@@ -55,20 +32,12 @@ __global__ void Crd_To_Uint_Crd(const int atom_numbers, float *scale_factor, con
UNSIGNED_INT_VECTOR *uint_crd) {
int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
if (atom_i < atom_numbers) {
INT_VECTOR tempi;
VECTOR temp = crd[atom_i];

temp.x *= scale_factor[0];
temp.y *= scale_factor[1];
temp.z *= scale_factor[2];

tempi.int_x = temp.x;
tempi.int_y = temp.y;
tempi.int_z = temp.z;

uint_crd[atom_i].uint_x = (tempi.int_x << 2);
uint_crd[atom_i].uint_y = (tempi.int_y << 2);
uint_crd[atom_i].uint_z = (tempi.int_z << 2);
uint_crd[atom_i].uint_x = crd[atom_i].x * scale_factor[0];
uint_crd[atom_i].uint_y = crd[atom_i].y * scale_factor[1];
uint_crd[atom_i].uint_z = crd[atom_i].z * scale_factor[2];
uint_crd[atom_i].uint_x = uint_crd[atom_i].uint_x << 1;
uint_crd[atom_i].uint_y = uint_crd[atom_i].uint_y << 1;
uint_crd[atom_i].uint_z = uint_crd[atom_i].uint_z << 1;
}
}

@@ -108,6 +77,7 @@ __global__ void Crd_Periodic_Map(const int atom_numbers, VECTOR *crd, const floa
} else {
crd[atom_i].y = crd[atom_i].y + box_length[1];
}

if (crd[atom_i].z >= 0) {
if (crd[atom_i].z < box_length[2]) {
} else {
@@ -225,21 +195,6 @@ __global__ void Is_need_refresh_neighbor_list_cuda(const int atom_numbers, const
}
}

// One thread per atom: set need_refresh_flag[0] to 1 when any atom has
// drifted (squared periodic displacement since the last snapshot) beyond
// half_skin_square. atomicExch avoids races between flag writers.
__global__ void Is_need_refresh_neighbor_list_cuda(const int atom_numbers, const VECTOR *crd, const VECTOR *old_crd,
                                                   const VECTOR *box_length, const float half_skin_square,
                                                   int *need_refresh_flag) {
  const int atom_i = blockDim.x * blockIdx.x + threadIdx.x;
  if (atom_i >= atom_numbers) {
    return;
  }
  VECTOR dr = Get_Periodic_Displacement(crd[atom_i], old_crd[atom_i], box_length[0]);
  const float dr2 = dr.x * dr.x + dr.y * dr.y + dr.z * dr.z;
  if (dr2 > half_skin_square) {
    atomicExch(&need_refresh_flag[0], 1);
  }
}

__global__ void Delete_Excluded_Atoms_Serial_In_Neighbor_List(const int atom_numbers, NEIGHBOR_LIST *nl,
const int *excluded_list_start, const int *excluded_list,
const int *excluded_atom_numbers) {
@@ -287,18 +242,27 @@ void Refresh_Neighbor_List(int *refresh_sign, const int thread, const int atom_n
int *excluded_list_start, int *excluded_list, int *excluded_numbers,
float cutoff_skin_square, int grid_numbers, float *grid_length_inverse, int *grid_N, int Nxy,
cudaStream_t stream) {
std::vector<int> h_refresh_sign(1);
cudaMemcpyAsync(h_refresh_sign.data(), refresh_sign, sizeof(int), cudaMemcpyDeviceToHost, stream);
if (h_refresh_sign[0] == 1) {
if (refresh_sign[0] == 1) {
VECTOR trans_vec = {-skin, -skin, -skin};
Clear_Grid_Bucket<<<ceilf(static_cast<float>(grid_numbers) / thread), thread, 0, stream>>>(
grid_numbers, atom_numbers_in_grid_bucket, bucket);

Vector_Translation<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0, stream>>>(atom_numbers, crd,
trans_vec);

Crd_Periodic_Map<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0, stream>>>(atom_numbers, crd,
box_length);

Find_Atom_In_Grid_Serial<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0, stream>>>(
atom_numbers, grid_length_inverse, crd, grid_N, Nxy, atom_in_grid_serial);

trans_vec.x = -trans_vec.x;
trans_vec.y = -trans_vec.y;
trans_vec.z = -trans_vec.z;

Vector_Translation<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0, stream>>>(atom_numbers, crd,
trans_vec);

Copy_List<<<ceilf(static_cast<float>(3. * atom_numbers) / thread), thread, 0, stream>>>(
3 * atom_numbers, reinterpret_cast<float *>(crd), reinterpret_cast<float *>(old_crd));

@@ -315,10 +279,40 @@ void Refresh_Neighbor_List(int *refresh_sign, const int thread, const int atom_n
Delete_Excluded_Atoms_Serial_In_Neighbor_List<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0,
stream>>>(atom_numbers, d_nl, excluded_list_start, excluded_list,
excluded_numbers);
h_refresh_sign[0] = 0;
refresh_sign[0] = 0;
}
}

// Build the neighbor list from scratch (first invocation): clear and refill
// the grid buckets, wrap and quantize coordinates, gather neighbor candidates
// within cutoff_skin_square, then strip excluded pairs.
// All kernels are enqueued on `stream`; no host-side synchronization here.
// NOTE(review): `refresh_sign` is accepted but never read in this body —
// presumably kept for signature parity with Refresh_Neighbor_List; confirm.
void Refresh_Neighbor_List_First_Time(int *refresh_sign, const int thread, const int atom_numbers, VECTOR *crd,
VECTOR *old_crd, UNSIGNED_INT_VECTOR *uint_crd, float *crd_to_uint_crd_cof,
float *uint_dr_to_dr_cof, int *atom_in_grid_serial, const float skin,
float *box_length, const GRID_POINTER *gpointer, GRID_BUCKET *bucket,
int *atom_numbers_in_grid_bucket, NEIGHBOR_LIST *d_nl, int *excluded_list_start,
int *excluded_list, int *excluded_numbers, float cutoff_skin_square,
int grid_numbers, float *grid_length_inverse, int *grid_N, int Nxy,
cudaStream_t stream) {
VECTOR trans_vec = {skin, skin, skin};
// Empty every grid bucket before re-binning atoms.
Clear_Grid_Bucket<<<ceilf(static_cast<float>(grid_numbers) / 32), 32, 0, stream>>>(
grid_numbers, atom_numbers_in_grid_bucket, bucket);
// Wrap coordinates back into the primary simulation box.
Crd_Periodic_Map<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, crd, box_length);
// Assign each atom its grid-cell serial index.
Find_Atom_In_Grid_Serial<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, grid_length_inverse, crd, grid_N, Nxy, atom_in_grid_serial);
// Shift all coordinates by +skin in each dimension before snapshotting.
Vector_Translation<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, crd, trans_vec);
// Snapshot crd into old_crd (3 floats per atom) for later drift checks.
Copy_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 32), 32, 0, stream>>>(
3 * atom_numbers, reinterpret_cast<float *>(crd), reinterpret_cast<float *>(old_crd));
// Insert each atom into the bucket of its grid cell.
Put_Atom_In_Grid_Bucket<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, atom_in_grid_serial, bucket, atom_numbers_in_grid_bucket);
// Quantize float coordinates into unsigned-int space for distance tests.
Crd_To_Uint_Crd<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, crd_to_uint_crd_cof,
crd, uint_crd);

// Collect neighbors within cutoff_skin_square using the neighboring-grid
// stencil described by gpointer.
Find_atom_neighbors<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0, stream>>>(
atom_numbers, uint_crd, uint_dr_to_dr_cof, atom_in_grid_serial, gpointer, bucket, atom_numbers_in_grid_bucket, d_nl,
cutoff_skin_square);
// Remove pairs listed in the exclusion table from each atom's list.
Delete_Excluded_Atoms_Serial_In_Neighbor_List<<<ceilf(static_cast<float>(atom_numbers) / thread), thread, 0,
stream>>>(atom_numbers, d_nl, excluded_list_start, excluded_list,
excluded_numbers);
}

__global__ void construct_neighbor_list_kernel(int atom_numbers, int max_neighbor_numbers, int *nl_atom_numbers,
int *nl_atom_serial, NEIGHBOR_LIST *nl) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < atom_numbers; i += gridDim.x * blockDim.x) {
@@ -333,39 +327,15 @@ void Construct_Neighbor_List(int atom_numbers, int max_neighbor_numbers, int *nl
atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl);
}

__global__ void copy_neighbor_list_atom_number(int atom_numbers, int max_neighbor_numbers, NEIGHBOR_LIST *nl,
int *nl_atom_numbers, int *nl_atom_serial) {
int i, j;
for (i = blockIdx.x * blockDim.x + threadIdx.x; i < atom_numbers; i += gridDim.x * blockDim.x) {
__global__ void copy_neighbor_list_atom_number(int atom_numbers, NEIGHBOR_LIST *nl, int *nl_atom_numbers) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < atom_numbers; i += gridDim.x * blockDim.x) {
nl_atom_numbers[i] = nl[i].atom_numbers;
for (j = blockIdx.y * blockDim.y + threadIdx.y; j < max_neighbor_numbers; j += gridDim.y * blockDim.y) {
if (j < nl_atom_numbers[i]) {
nl_atom_serial[i * max_neighbor_numbers + j] = nl[i].atom_serial[j];
} else {
nl_atom_serial[i * max_neighbor_numbers + j] = 0;
}
}
}
}

// Fill the first element_numbers entries of an int list with replace_element.
__global__ void Reset_List(const int element_numbers, int *list, const int replace_element) {
  const int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= element_numbers) {
    return;
  }
  list[idx] = replace_element;
}

// Fill the first element_numbers entries of a float list with replace_element.
__global__ void Reset_List(const int element_numbers, float *list, const float replace_element) {
  const int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= element_numbers) {
    return;
  }
  list[idx] = replace_element;
}

void CopyNeighborListAtomNumber(int atom_numbers, int max_neighbor_numbers, NEIGHBOR_LIST *nl, int *nl_atom_numbers,
int *nl_atom_serial, cudaStream_t stream) {
copy_neighbor_list_atom_number<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, max_neighbor_numbers, nl, nl_atom_numbers, nl_atom_serial);
void CopyNeighborListAtomNumber(int atom_numbers, NEIGHBOR_LIST *nl, int *nl_atom_numbers, cudaStream_t stream) {
copy_neighbor_list_atom_number<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, nl,
nl_atom_numbers);
}

void Refresh_Neighbor_List_No_Check(int grid_numbers, int atom_numbers, float skin, int Nxy, float cutoff_skin_square,
@@ -375,13 +345,22 @@ void Refresh_Neighbor_List_No_Check(int grid_numbers, int atom_numbers, float sk
UNSIGNED_INT_VECTOR *uint_crd, float *uint_dr_to_dr_cof, GRID_POINTER *gpointer,
NEIGHBOR_LIST *d_nl, int *excluded_list_start, int *excluded_list,
int *excluded_numbers, cudaStream_t stream) {
VECTOR trans_vec = {-skin, -skin, -skin};

Clear_Grid_Bucket<<<ceilf(static_cast<float>(grid_numbers) / 32), 32, 0, stream>>>(
grid_numbers, atom_numbers_in_grid_bucket, bucket);

Vector_Translation<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, crd, trans_vec);

Crd_Periodic_Map<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, crd, box_length);

Find_Atom_In_Grid_Serial<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
atom_numbers, grid_length_inverse, crd, grid_N, Nxy, atom_in_grid_serial);
trans_vec.x = -trans_vec.x;
trans_vec.y = -trans_vec.y;
trans_vec.z = -trans_vec.z;
Vector_Translation<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(atom_numbers, crd, trans_vec);

cudaMemcpyAsync(old_crd, crd, sizeof(VECTOR) * atom_numbers, cudaMemcpyDeviceToDevice, stream);

Put_Atom_In_Grid_Bucket<<<ceilf(static_cast<float>(atom_numbers) / 32), 32, 0, stream>>>(
@@ -405,53 +384,51 @@ __global__ void Mul_half(float *src, float *dst) {
}
}

// Scale the first 3 entries of src by one quarter into dst.
// Launched as <<<1, 3>>>: one thread per vector component; the index guard
// keeps any extra threads from writing out of bounds.
__global__ void Mul_quarter(float *src, float *dst) {
  int index = threadIdx.x;
  if (index < 3) {
    dst[index] = src[index] * 0.25f;  // 0.25f keeps the multiply in single precision
  }
}

int refresh_count = 0;

void Neighbor_List_Update_New(int grid_numbers, int atom_numbers, int *d_refresh_count, int refresh_interval,
int not_first_time, float skin, int Nxy, float cutoff_square,
float cutoff_with_skin_square, int *grid_N, float *box_length,
int *atom_numbers_in_grid_bucket, float *grid_length_inverse, int *atom_in_grid_serial,
GRID_BUCKET *bucket, float *crd, float *old_crd, float *crd_to_uint_crd_cof,
float *half_crd_to_uint_crd_cof, unsigned int *uint_crd, float *uint_dr_to_dr_cof,
GRID_POINTER *gpointer, NEIGHBOR_LIST *d_nl, int *excluded_list_start, int *excluded_list,
int *excluded_numbers, float half_skin_square, int *is_need_refresh_neighbor_list,
int forced_update, int forced_check, cudaStream_t stream) {
if (forced_update) {
Mul_quarter<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List_No_Check(
grid_numbers, atom_numbers, skin, Nxy, cutoff_square, grid_N, box_length, atom_numbers_in_grid_bucket,
grid_length_inverse, atom_in_grid_serial, bucket, reinterpret_cast<VECTOR *>(crd),
reinterpret_cast<VECTOR *>(old_crd), half_crd_to_uint_crd_cof, reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
uint_dr_to_dr_cof, gpointer, d_nl, excluded_list_start, excluded_list, excluded_numbers, stream);

} else if (refresh_interval > 0 && !forced_check) {
if (refresh_count % refresh_interval == 0) {
Mul_quarter<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List_No_Check(grid_numbers, atom_numbers, skin, Nxy, cutoff_square, grid_N, box_length,
atom_numbers_in_grid_bucket, grid_length_inverse, atom_in_grid_serial, bucket,
reinterpret_cast<VECTOR *>(crd), reinterpret_cast<VECTOR *>(old_crd),
half_crd_to_uint_crd_cof, reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
uint_dr_to_dr_cof, gpointer, d_nl, excluded_list_start, excluded_list,
excluded_numbers, stream);
void Neighbor_List_Update(int grid_numbers, int atom_numbers, int *d_refresh_count, int refresh_interval,
int not_first_time, float skin, int Nxy, float cutoff_square, float cutoff_with_skin_square,
int *grid_N, float *box_length, int *atom_numbers_in_grid_bucket, float *grid_length_inverse,
int *atom_in_grid_serial, GRID_BUCKET *bucket, float *crd, float *old_crd,
float *crd_to_uint_crd_cof, float *half_crd_to_uint_crd_cof, unsigned int *uint_crd,
float *uint_dr_to_dr_cof, GRID_POINTER *gpointer, NEIGHBOR_LIST *d_nl,
int *excluded_list_start, int *excluded_list, int *excluded_numbers, float half_skin_square,
int *is_need_refresh_neighbor_list, cudaStream_t stream) {
if (not_first_time) {
if (refresh_interval > 0) {
std::vector<int> refresh_count_list(1);
cudaMemcpyAsync(refresh_count_list.data(), d_refresh_count, sizeof(int), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
int refresh_count = refresh_count_list[0];

if (refresh_count % refresh_interval == 0) {
Mul_half<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List_No_Check(grid_numbers, atom_numbers, skin, Nxy, cutoff_square, grid_N, box_length,
atom_numbers_in_grid_bucket, grid_length_inverse, atom_in_grid_serial, bucket,
reinterpret_cast<VECTOR *>(crd), reinterpret_cast<VECTOR *>(old_crd),
half_crd_to_uint_crd_cof, reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
uint_dr_to_dr_cof, gpointer, d_nl, excluded_list_start, excluded_list,
excluded_numbers, stream);
}
refresh_count += 1;
cudaMemcpyAsync(d_refresh_count, &refresh_count, sizeof(int), cudaMemcpyHostToDevice, stream);
} else {
Is_need_refresh_neighbor_list_cuda<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, reinterpret_cast<VECTOR *>(crd), reinterpret_cast<VECTOR *>(old_crd), half_skin_square,
is_need_refresh_neighbor_list);
Mul_half<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List(is_need_refresh_neighbor_list, 32, atom_numbers, reinterpret_cast<VECTOR *>(crd),
reinterpret_cast<VECTOR *>(old_crd), reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
half_crd_to_uint_crd_cof, uint_dr_to_dr_cof, atom_in_grid_serial, skin, box_length,
gpointer, bucket, atom_numbers_in_grid_bucket, d_nl, excluded_list_start, excluded_list,
excluded_numbers, cutoff_with_skin_square, grid_numbers, grid_length_inverse, grid_N, Nxy,
stream);
}
refresh_count += 1;
} else {
Is_need_refresh_neighbor_list_cuda<<<ceilf(static_cast<float>(atom_numbers) / 128), 128, 0, stream>>>(
atom_numbers, reinterpret_cast<VECTOR *>(crd), reinterpret_cast<VECTOR *>(old_crd),
reinterpret_cast<VECTOR *>(box_length), half_skin_square, is_need_refresh_neighbor_list);
Mul_quarter<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List(is_need_refresh_neighbor_list, 32, atom_numbers, reinterpret_cast<VECTOR *>(crd),
reinterpret_cast<VECTOR *>(old_crd), reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd),
half_crd_to_uint_crd_cof, uint_dr_to_dr_cof, atom_in_grid_serial, skin, box_length, gpointer,
bucket, atom_numbers_in_grid_bucket, d_nl, excluded_list_start, excluded_list,
excluded_numbers, cutoff_with_skin_square, grid_numbers, grid_length_inverse, grid_N, Nxy,
stream);
Mul_half<<<1, 3, 0, stream>>>(crd_to_uint_crd_cof, half_crd_to_uint_crd_cof);
Refresh_Neighbor_List_First_Time(
is_need_refresh_neighbor_list, 32, atom_numbers, reinterpret_cast<VECTOR *>(crd),
reinterpret_cast<VECTOR *>(old_crd), reinterpret_cast<UNSIGNED_INT_VECTOR *>(uint_crd), half_crd_to_uint_crd_cof,
uint_dr_to_dr_cof, atom_in_grid_serial, skin, box_length, gpointer, bucket, atom_numbers_in_grid_bucket, d_nl,
excluded_list_start, excluded_list, excluded_numbers, cutoff_with_skin_square, grid_numbers, grid_length_inverse,
grid_N, Nxy, stream);
}
}

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_new_impl.cuh → mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cuh View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NEIGHBOR_LIST_NEW_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NEIGHBOR_LIST_NEW_IMPL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NEIGHBOR_LIST_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_NEIGHBOR_LIST_IMPL_H_

struct VECTOR {
float x;
@@ -46,17 +46,15 @@ struct GRID_POINTER {
void Construct_Neighbor_List(int grid_numbers, int max_neighbor_numbers, int *nl_atom_numbers, int *nl_atom_serial,
NEIGHBOR_LIST *nl, cudaStream_t stream);

void CopyNeighborListAtomNumber(int atom_numbers, int max_neighbor_numbers, NEIGHBOR_LIST *nl, int *nl_atom_numbers,
int *nl_atom_serial, cudaStream_t stream);
void CopyNeighborListAtomNumber(int atom_numbers, NEIGHBOR_LIST *nl, int *nl_atom_numbers, cudaStream_t stream);

void Neighbor_List_Update_New(int grid_numbers, int atom_numbers, int *d_refresh_count, int refresh_interval,
int not_first_time, float skin, int Nxy, float cutoff_square,
float cutoff_with_skin_square, int *grid_N, float *box_length,
int *atom_numbers_in_grid_bucket, float *grid_length_inverse, int *atom_in_grid_serial,
GRID_BUCKET *bucket, float *crd, float *old_crd, float *crd_to_uint_crd_cof,
float *half_crd_to_uint_crd_cof, unsigned int *uint_crd, float *uint_dr_to_dr_cof,
GRID_POINTER *gpointer, NEIGHBOR_LIST *d_nl, int *excluded_list_start, int *excluded_list,
int *excluded_numbers, float half_skin_square, int *is_need_refresh_neighbor_list,
int forced_update, int forced_check, cudaStream_t stream);
void Neighbor_List_Update(int grid_numbers, int atom_numbers, int* d_refresh_count, int refresh_interval,
int not_first_time, float skin, int Nxy, float cutoff_square, float cutoff_with_skin_square,
int *grid_N, float *box_length, int *atom_numbers_in_grid_bucket, float *grid_length_inverse,
int *atom_in_grid_serial, GRID_BUCKET *bucket, float *crd, float *old_crd,
float *crd_to_uint_crd_cof, float *half_crd_to_uint_crd_cof, unsigned int *uint_crd,
float *uint_dr_to_dr_cof, GRID_POINTER *gpointer, NEIGHBOR_LIST *d_nl,
int *excluded_list_start, int *excluded_list, int *excluded_numbers, float half_skin_square,
int *is_need_refresh_neighbor_list, cudaStream_t stream);

#endif

mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_new_kernel.cc → mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_kernel.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_new_kernel.h"
#include "backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_kernel.h"

namespace mindspore {
namespace kernel {
@@ -40,6 +40,6 @@ MS_REG_GPU_KERNEL_TWO(NeighborListUpdate,
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
NeighborListUpdateNewGpuKernel, int, float)
NeighborListUpdateGpuKernel, int, float)
} // namespace kernel
} // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_new_kernel.h → mindspore/ccsrc/backend/kernel_compiler/gpu/sponge/neighbor_list/neighbor_list_update_kernel.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_NEIGHBOR_LIST_UPDATE_NEW_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_NEIGHBOR_LIST_UPDATE_NEW_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_NEIGHBOR_LIST_UPDATE_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_NEIGHBOR_LIST_UPDATE_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
@@ -24,21 +24,21 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_new_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/neighbor_list/neighbor_list_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T, typename T1>
class NeighborListUpdateNewGpuKernel : public GpuKernel {
class NeighborListUpdateGpuKernel : public GpuKernel {
public:
NeighborListUpdateNewGpuKernel() : skin(2.0), cutoff(9.0), max_atom_in_grid_numbers(64), max_neighbor_numbers(800) {}
~NeighborListUpdateNewGpuKernel() override = default;
NeighborListUpdateGpuKernel() : skin(2.0), cutoff(10.0), max_atom_in_grid_numbers(64), max_neighbor_numbers(800) {}
~NeighborListUpdateGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
grid_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "grid_numbers"));
atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers"));
refresh_interval = static_cast<int>(GetAttr<int64_t>(kernel_node, "refresh_interval"));
not_first_time = static_cast<int>(GetAttr<int64_t>(kernel_node, "not_first_time"));
nxy = static_cast<int>(GetAttr<int64_t>(kernel_node, "nxy"));
Nxy = static_cast<int>(GetAttr<int64_t>(kernel_node, "Nxy"));
excluded_atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "excluded_atom_numbers"));

cutoff_square = static_cast<float>(GetAttr<float>(kernel_node, "cutoff_square"));
@@ -46,8 +46,6 @@ class NeighborListUpdateNewGpuKernel : public GpuKernel {
cutoff_with_skin = static_cast<float>(GetAttr<float>(kernel_node, "cutoff_with_skin"));
half_cutoff_with_skin = static_cast<float>(GetAttr<float>(kernel_node, "half_cutoff_with_skin"));
cutoff_with_skin_square = static_cast<float>(GetAttr<float>(kernel_node, "cutoff_with_skin_square"));
forced_update = static_cast<int>(GetAttr<int64_t>(kernel_node, "forced_update"));
forced_check = static_cast<int>(GetAttr<int64_t>(kernel_node, "forced_check"));
h_bucket.resize(grid_numbers);
h_gpointer.resize(grid_numbers);
InitSizeLists();
@@ -64,7 +62,7 @@ class NeighborListUpdateNewGpuKernel : public GpuKernel {
auto bucket = GetDeviceAddress<int>(inputs, 1);
auto crd = GetDeviceAddress<float>(inputs, 2);
auto box_length = GetDeviceAddress<float>(inputs, 3);
auto grid_n = GetDeviceAddress<int>(inputs, 4);
auto grid_N = GetDeviceAddress<int>(inputs, 4);
auto grid_length_inverse = GetDeviceAddress<float>(inputs, 5);
auto atom_in_grid_serial = GetDeviceAddress<int>(inputs, 6);
auto old_crd = GetDeviceAddress<float>(inputs, 7);
@@ -101,14 +99,13 @@ class NeighborListUpdateNewGpuKernel : public GpuKernel {
Construct_Neighbor_List(atom_numbers, max_neighbor_numbers, nl_atom_numbers, nl_atom_serial, nl,
reinterpret_cast<cudaStream_t>(stream_ptr));

Neighbor_List_Update_New(grid_numbers, atom_numbers, d_refresh_count, refresh_interval, not_first_time, skin, nxy,
cutoff_square, cutoff_with_skin_square, grid_n, box_length, atom_numbers_in_grid_bucket,
grid_length_inverse, atom_in_grid_serial, d_bucket, crd, old_crd, crd_to_uint_crd_cof,
half_crd_to_uint_crd_cof, uint_crd, uint_dr_to_dr_cof, d_gpointer, nl, excluded_list_start,
excluded_list, excluded_numbers, half_skin_square, need_refresh_flag, forced_update,
forced_check, reinterpret_cast<cudaStream_t>(stream_ptr));
CopyNeighborListAtomNumber(atom_numbers, max_neighbor_numbers, nl, nl_atom_numbers, nl_atom_serial,
reinterpret_cast<cudaStream_t>(stream_ptr));
Neighbor_List_Update(grid_numbers, atom_numbers, d_refresh_count, refresh_interval, not_first_time, skin, Nxy,
cutoff_square, cutoff_with_skin_square, grid_N, box_length, atom_numbers_in_grid_bucket,
grid_length_inverse, atom_in_grid_serial, d_bucket, crd, old_crd, crd_to_uint_crd_cof,
half_crd_to_uint_crd_cof, uint_crd, uint_dr_to_dr_cof, d_gpointer, nl, excluded_list_start,
excluded_list, excluded_numbers, half_skin_square, need_refresh_flag,
reinterpret_cast<cudaStream_t>(stream_ptr));
CopyNeighborListAtomNumber(atom_numbers, nl, nl_atom_numbers, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

@@ -154,7 +151,7 @@ class NeighborListUpdateNewGpuKernel : public GpuKernel {
int atom_numbers;
int grid_numbers;
int refresh_interval;
int nxy;
int Nxy;
int max_atom_in_grid_numbers;
int max_neighbor_numbers;
int excluded_atom_numbers;
@@ -163,8 +160,6 @@ class NeighborListUpdateNewGpuKernel : public GpuKernel {
float cutoff_with_skin;
float half_cutoff_with_skin;
float cutoff_with_skin_square;
int forced_update;
int forced_check;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;

+ 3
- 2
mindspore/ops/operations/__init__.py View File

@@ -110,14 +110,14 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto
Dihedral14LJForceWithDirectCF, Dihedral14LJEnergy, Dihedral14LJCFForceWithAtomEnergy,
Dihedral14LJAtomEnergy, Dihedral14CFEnergy, Dihedral14CFAtomEnergy,
GetCenterOfGeometry, MDTemperature, MDIterationLeapFrogLiujian,
CrdToUintCrd, MDIterationSetupRandState, TransferCrd, FFT3D, IFFT3D)
CrdToUintCrd, MDIterationSetupRandState, TransferCrd, FFT3D, IFFT3D, NeighborListUpdate)
from .sponge_update_ops import (v0coordinaterefresh, v1coordinaterefresh, v2coordinaterefresh, v3coordinaterefresh,
v0forceredistribute, v1forceredistribute, v2forceredistribute, v3forceredistribute,
restrainenergy, restrainforcewithatomenergyandvirial, constrainforcecyclewithvirial,
refreshuintcrd, lastcrdtodr, refreshcrdvel, calculatenowrapcrd, refreshboxmaptimes,
totalc6get, copyfrctosystemgrad, CrdToUintCrdQuarter,
MDIterationLeapFrogLiujianWithMaxVel, GetCenterOfMass, MapCenterOfMass,
NeighborListUpdate, MDIterationLeapFrog,
NeighborListUpdateNew, MDIterationLeapFrog,
MDIterationLeapFrogWithMaxVel, MDIterationGradientDescent,
BondForceWithAtomEnergyAndVirial, ConstrainForceCycle)
from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
@@ -529,6 +529,7 @@ __all__ = [
"BufferAppend",
"BufferGetItem",
"BufferSample",
"NeighborListUpdateNew",
]

__all__.sort()

+ 186
- 0
mindspore/ops/operations/sponge_ops.py View File

@@ -3045,3 +3045,189 @@ class IFFT3D(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('input_real', input_real_dtype, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_imag', input_imag_dtype, mstype.number_type, self.name)
return input_real_dtype

class NeighborListUpdate(PrimitiveWithInfer):
"""
Update (or construct if first time) the Verlet neighbor list for the
calculation of short-ranged force. Assume the number of atoms is N,
the number of grids divided is G, the maximum number of atoms in one
grid is M, the maximum number of atoms in single atom's neighbor list
is L, and the number of total atom in excluded list is E.

Args:
grid_numbers(int32): the total number of grids divided.
not_first_time(int32): whether to construct the neighbor
list first time or not.
Nxy(int32): the total number of grids divided in xy plane.
excluded_atom_numbers(int32): the total atom numbers in the excluded list.
cutoff(float32): the cutoff distance for short-range force calculation.
skin(float32): the overflow value of cutoff to maintain a neighbor list.
cutoff_square(float32): the square value of cutoff.
half_skin_square(float32): skin*skin/4, the maximum squared
distance an atom is allowed to move between two updates.
cutoff_with_skin(float32): cutoff + skin, indicates the
radius of the neighbor list for each atom.
half_cutoff_with_skin(float32): cutoff_with_skin/2.
cutoff_with_skin_square(float32): the square value of cutoff_with_skin.
refresh_interval(int32): the number of iteration steps between two updates of neighbor list.
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid.

Inputs:
- **atom_numbers_in_grid_bucket** (Tensor, int32) - [G,], the number of atoms in each grid bucket.
- **bucket** (Tensor, int32) - [G, M], the atom indices in each grid bucket.
- **crd** (Tensor, float32) - [N,], the coordinates of each atom.
- **box_length** (Tensor, float32) - [3,], the length of 3 dimensions of the simulation box.
- **grid_N** (Tensor, int32) - [3,], the number of grids divided of 3 dimensions of the simulation box.
- **grid_length_inverse** (float32) - the inverse value of grid length.
- **atom_in_grid_serial** (Tensor, int32) - [N,], the grid index for each atom.
- **old_crd** (Tensor, float32) - [N, 3], the coordinates before update of each atom.
- **crd_to_uint_crd_cof** (Tensor, float32) - [3,], the scale factor
between the unsigned int value and the real space coordinates.
- **uint_crd** (Tensor, uint32) - [N, 3], the unsigned int coordinates value of each atom.
- **gpointer** (Tensor, int32) - [G, 125], the 125 nearest neighbor grids (including self) of each grid.
G is the number of nearest neighbor grids.
- **nl_atom_numbers** (Tensor, int32) - [N,], the number of atoms in neighbor list of each atom.
- **nl_atom_serial** (Tensor, int32) - [N, L], the indices of atoms in neighbor list of each atom.
- **uint_dr_to_dr_cof** (Tensor, float32) - [3,], the scale factor between
the real space coordinates and the unsigned int value.
- **excluded_list_start** (Tensor, int32) - [N,], the start excluded index in excluded list for each atom.
- **excluded_numbers** (Tensor, int32) - [N,], the number of atom excluded in excluded list for each atom.
- **excluded_list** (Tensor, int32) - [E,], the contiguous join of excluded list of each atom.
- **need_refresh_flag** (Tensor, int32) - [N,], whether the neighbor list of each atom need update or not.
- **refresh_count** (Tensor, int32) - [1,], count how many iteration steps have passed since last update.

Outputs:
- **res** (float32)

Supported Platforms:
``GPU``
"""

@prim_attr_register
def __init__(self, grid_numbers, atom_numbers, not_first_time, Nxy, excluded_atom_numbers,
cutoff_square, half_skin_square, cutoff_with_skin, half_cutoff_with_skin, cutoff_with_skin_square,
refresh_interval=20, cutoff=10.0, skin=2.0, max_atom_in_grid_numbers=64, max_neighbor_numbers=800):
self.grid_numbers = grid_numbers
self.atom_numbers = atom_numbers
self.refresh_interval = refresh_interval
self.not_first_time = not_first_time
self.cutoff = cutoff
self.skin = skin
self.max_atom_in_grid_numbers = max_atom_in_grid_numbers
self.Nxy = Nxy
self.excluded_atom_numbers = excluded_atom_numbers
self.cutoff_square = cutoff_square
self.half_skin_square = half_skin_square
self.cutoff_with_skin = cutoff_with_skin
self.half_cutoff_with_skin = half_cutoff_with_skin
self.cutoff_with_skin_square = cutoff_with_skin_square
self.max_neighbor_numbers = max_neighbor_numbers
self.init_prim_io_names(
inputs=['atom_numbers_in_grid_bucket', 'bucket', 'crd', 'box_length', 'grid_N', 'grid_length_inverse',
'atom_in_grid_serial', 'old_crd', 'crd_to_uint_crd_cof', 'uint_crd', 'gpointer', 'nl_atom_numbers',
'nl_atom_serial', 'uint_dr_to_dr_cof', 'excluded_list_start', 'excluded_list', 'excluded_numbers',
'need_refresh_flag', 'refresh_count'], outputs=['res'])

self.add_prim_attr('grid_numbers', self.grid_numbers)
self.add_prim_attr('atom_numbers', self.atom_numbers)
self.add_prim_attr('refresh_interval', self.refresh_interval)
self.add_prim_attr('not_first_time', self.not_first_time)
self.add_prim_attr('cutoff', self.cutoff)
self.add_prim_attr('skin', self.skin)
self.add_prim_attr('max_atom_in_grid_numbers', self.max_atom_in_grid_numbers)
self.add_prim_attr('Nxy', self.Nxy)
self.add_prim_attr('excluded_atom_numbers', self.excluded_atom_numbers)
self.add_prim_attr('cutoff_square', self.cutoff_square)
self.add_prim_attr('half_skin_square', self.half_skin_square)
self.add_prim_attr('cutoff_with_skin', self.cutoff_with_skin)
self.add_prim_attr('half_cutoff_with_skin', self.half_cutoff_with_skin)
self.add_prim_attr('cutoff_with_skin_square', self.cutoff_with_skin_square)

def infer_shape(self, atom_numbers_in_grid_bucket_shape, bucket_shape, crd_shape, box_length_shape, grid_N_shape,
grid_length_inverse_shape, atom_in_grid_serial_shape, old_crd_shape, crd_to_uint_crd_cof_shape,
uint_crd_shape, gpointer_shape, nl_atom_numbers_shape, nl_atom_serial_shape,
uint_dr_to_dr_cof_shape, excluded_list_start_shape, excluded_list_shape, excluded_numbers_shape,
need_refresh_flag_shape, refresh_count_shape):
assert len(atom_numbers_in_grid_bucket_shape) == 1
assert len(bucket_shape) == 2
assert len(crd_shape) == 2
assert len(box_length_shape) == 1
assert len(grid_N_shape) == 1
assert len(grid_length_inverse_shape) == 1
assert len(atom_in_grid_serial_shape) == 1
assert len(old_crd_shape) == 2
assert len(crd_to_uint_crd_cof_shape) == 1
assert len(uint_crd_shape) == 2
assert len(gpointer_shape) == 2
assert len(nl_atom_numbers_shape) == 1
assert len(nl_atom_serial_shape) == 2
assert len(uint_dr_to_dr_cof_shape) == 1
assert len(excluded_list_start_shape) == 1
assert len(excluded_list_shape) == 1
assert len(excluded_numbers_shape) == 1
assert len(need_refresh_flag_shape) == 1

validator.check_int(atom_numbers_in_grid_bucket_shape[0], self.grid_numbers, Rel.EQ,
"atom_numbers_in_grid_bucket", self.name)
validator.check_int(bucket_shape[0], self.grid_numbers, Rel.EQ, "bucket", self.name)
validator.check_int(bucket_shape[1], self.max_atom_in_grid_numbers, Rel.EQ, "bucket", self.name)
validator.check_int(crd_shape[0], self.atom_numbers, Rel.EQ, "crd", self.name)
validator.check_int(crd_shape[1], 3, Rel.EQ, "crd", self.name)
validator.check_int(box_length_shape[0], 3, Rel.EQ, "box_length", self.name)
validator.check_int(grid_N_shape[0], 3, Rel.EQ, "grid_N", self.name)
validator.check_int(grid_length_inverse_shape[0], 3, Rel.EQ, "grid_length_inverse", self.name)
validator.check_int(atom_in_grid_serial_shape[0], self.atom_numbers, Rel.EQ, "atom_in_grid_serial",
self.name)
validator.check_int(old_crd_shape[0], self.atom_numbers, Rel.EQ, "old_crd", self.name)
validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd", self.name)
validator.check_int(crd_to_uint_crd_cof_shape[0], 3, Rel.EQ, "crd_to_uint_crd_cof", self.name)
validator.check_int(uint_crd_shape[0], self.atom_numbers, Rel.EQ, "uint_crd", self.name)
validator.check_int(uint_crd_shape[1], 3, Rel.EQ, "uint_crd", self.name)
validator.check_int(gpointer_shape[0], self.grid_numbers, Rel.EQ, "gpointer", self.name)
validator.check_int(gpointer_shape[1], 125, Rel.EQ, "gpointer", self.name)
validator.check_int(nl_atom_numbers_shape[0], self.atom_numbers, Rel.EQ, "nl_atom_numbers", self.name)
validator.check_int(nl_atom_serial_shape[0], self.atom_numbers, Rel.EQ, "nl_atom_serial", self.name)
validator.check_int(nl_atom_serial_shape[1], self.max_neighbor_numbers, Rel.EQ, "nl_atom_serial",
self.name)
validator.check_int(uint_dr_to_dr_cof_shape[0], 3, Rel.EQ, "uint_dr_to_dr_cof", self.name)
validator.check_int(excluded_list_start_shape[0], self.atom_numbers, Rel.EQ, "excluded_list_start",
self.name)
validator.check_int(excluded_list_shape[0], self.excluded_atom_numbers, Rel.EQ, "excluded_list",
self.name)
validator.check_int(excluded_numbers_shape[0], self.atom_numbers, Rel.EQ, "excluded_numbers", self.name)
validator.check_int(need_refresh_flag_shape[0], 1, Rel.EQ, "need_refresh_flag", self.name)

return [1,]

def infer_dtype(self, atom_numbers_in_grid_bucket_dtype, bucket_dtype, crd_dtype, box_length_dtype, grid_N_dtype,
grid_length_inverse_dtype, atom_in_grid_serial_dtype, old_crd_dtype, crd_to_uint_crd_cof_dtype,
uint_crd_dtype, gpointer_dtype, nl_atom_numbers_dtype, nl_atom_serial_dtype,
uint_dr_to_dr_cof_dtype, excluded_list_start_dtype, excluded_list_dtype, excluded_numbers_dtype,
need_refresh_flag_dtype, refresh_count_dtype):
validator.check_tensor_dtype_valid('atom_numbers_in_grid_bucket', atom_numbers_in_grid_bucket_dtype,
[mstype.int32], self.name)
validator.check_tensor_dtype_valid('bucket', bucket_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('box_length', box_length_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('grid_N', grid_N_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('grid_length_inverse', grid_length_inverse_dtype, [mstype.float32],
self.name)
validator.check_tensor_dtype_valid('atom_in_grid_serial', atom_in_grid_serial_dtype, [mstype.int32],
self.name)
validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('crd_to_uint_crd_cof', crd_to_uint_crd_cof_dtype, [mstype.float32],
self.name)
validator.check_tensor_dtype_valid('uint_crd', uint_crd_dtype, [mstype.uint32], self.name)
validator.check_tensor_dtype_valid('gpointer', gpointer_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('nl_atom_numbers', nl_atom_numbers_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('nl_atom_serial', nl_atom_serial_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('uint_dr_to_dr_cof', uint_dr_to_dr_cof_dtype, [mstype.float32],
self.name)
validator.check_tensor_dtype_valid('excluded_list_start', excluded_list_start_dtype, [mstype.int32],
self.name)
validator.check_tensor_dtype_valid('excluded_list', excluded_list_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('excluded_numbers', excluded_numbers_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('need_refresh_flag', need_refresh_flag_dtype, [mstype.int32],
self.name)

return mstype.float32

+ 1
- 1
mindspore/ops/operations/sponge_update_ops.py View File

@@ -998,7 +998,7 @@ class MapCenterOfMass(PrimitiveWithInfer):
return mstype.float32


class NeighborListUpdate(PrimitiveWithInfer):
class NeighborListUpdateNew(PrimitiveWithInfer):
"""
Update (or construct if first time) the Verlet neighbor list for the
calculation of short-ranged force. Assume the number of atoms is n,


Loading…
Cancel
Save