From: @jiahongqian Reviewed-by: @wang_zi_dong,@ljl0711 Signed-off-by: @ljl0711pull/15331/MERGE
| @@ -0,0 +1,123 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" | |||
| __device__ __host__ float fc(float Rij) { | |||
| const float PI = 3.141592654; | |||
| const float Rc = 1000.0; | |||
| return 0.5 * cosf(PI / Rc * Rij) + 0.5; | |||
| } | |||
| __global__ void Record_Box_Map_Times(int atom_numbers, const float *crd, const float *old_crd, float *box, | |||
| int *box_map_times) { | |||
| float half_box[3] = {0.5 * box[0], 0.5 * box[1], 0.5 * box[2]}; | |||
| int i = blockDim.x * blockIdx.x + threadIdx.x; | |||
| if (i < atom_numbers) { | |||
| if (crd[3 * i + 0] - old_crd[3 * i + 0] > half_box[0]) { | |||
| box_map_times[3 * i + 0] = box_map_times[3 * i + 0] - 1; | |||
| } else if (crd[3 * i + 0] - old_crd[3 * i + 0] < -half_box[0]) { | |||
| box_map_times[3 * i + 0] = box_map_times[3 * i + 0] + 1; | |||
| } | |||
| if (crd[3 * i + 1] - old_crd[3 * i + 1] > half_box[1]) { | |||
| box_map_times[3 * i + 1] = box_map_times[3 * i + 1] - 1; | |||
| } else if (crd[3 * i + 1] - old_crd[3 * i + 1] < -half_box[1]) { | |||
| box_map_times[3 * i + 1] = box_map_times[3 * i + 1] + 1; | |||
| } | |||
| if (crd[3 * i + 2] - old_crd[3 * i + 2] > half_box[2]) { | |||
| box_map_times[3 * i + 2] = box_map_times[3 * i + 2] - 1; | |||
| } else if (crd[3 * i + 2] - old_crd[3 * i + 2] < -half_box[2]) { | |||
| box_map_times[3 * i + 2] = box_map_times[3 * i + 2] + 1; | |||
| } | |||
| } | |||
| } | |||
| __global__ void gen_nowarp_crd(int atom_numbers, const float *crd, float *box, int *box_map_times, float *nowarp_crd) { | |||
| int i = blockDim.x * blockIdx.x + threadIdx.x; | |||
| if (i < atom_numbers) { | |||
| nowarp_crd[3 * i + 0] = static_cast<float>(box_map_times[3 * i + 0]) * box[0] + crd[3 * i + 0]; | |||
| nowarp_crd[3 * i + 1] = static_cast<float>(box_map_times[3 * i + 1]) * box[1] + crd[3 * i + 1]; | |||
| nowarp_crd[3 * i + 2] = static_cast<float>(box_map_times[3 * i + 2]) * box[2] + crd[3 * i + 2]; | |||
| } | |||
| } | |||
| __global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) { | |||
| const float Rs = 0.5, Eta = 0.5; | |||
| int i = blockDim.x * blockIdx.x + threadIdx.x; | |||
| if (i >= start_serial && i < end_serial) { | |||
| float rij; | |||
| float g_radial_lin = 0.; | |||
| for (int j = start_serial; j < end_serial; j = j + 1) { | |||
| if (j != i) { | |||
| // rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j])); | |||
| rij = sqrtf(normfloat(crd, crd, i, j)); | |||
| g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij); | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| g_radial[i] = g_radial_lin; | |||
| } | |||
| } | |||
| __global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) { | |||
| const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0; | |||
| int i = blockDim.x * blockIdx.x + threadIdx.x; | |||
| if (i >= start_serial && i < end_serial) { | |||
| float rij, rik, rjk, theta_jik; | |||
| float g_angular_lin = 0.; | |||
| for (int j = start_serial; j < end_serial; j = j + 1) { | |||
| if (j != i) { | |||
| rij = sqrtf(normfloat(crd, crd, i, j)); | |||
| for (int k = j + 1; k < end_serial; k = k + 1) { | |||
| if (k != i) { | |||
| rik = sqrtf(normfloat(crd, crd, i, k)); | |||
| rjk = sqrtf(normfloat(crd, crd, j, k)); | |||
| theta_jik = | |||
| acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999)); | |||
| g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) * | |||
| expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik); | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin; | |||
| } | |||
| } | |||
| void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f, | |||
| const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial, | |||
| float *g_angular, cudaStream_t stream) { | |||
| Reset_List<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(3 * atom_numbers, box_map_times, | |||
| 0); | |||
| Record_Box_Map_Times<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>( | |||
| atom_numbers, crd_f, old_crd, box, box_map_times); | |||
| gen_nowarp_crd<<<ceilf(static_cast<float>(3. * atom_numbers) / 128), 128, 0, stream>>>(atom_numbers, crd_f, box, | |||
| box_map_times, nowarp_crd); | |||
| G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_radial); | |||
| G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, nowarp_crd, g_angular); | |||
| return; | |||
| } | |||
| void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f, | |||
| const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial, | |||
| float *g_angular, cudaStream_t stream); | |||
| @@ -14,12 +14,13 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular, | |||
| cudaStream_t stream); | |||
| void AtomCrdToCV(int atom_numbers, int start_serial, int end_serial, int number, const float *crd_f, | |||
| const float *old_crd, float *nowarp_crd, int *box_map_times, float *box, float *g_radial, | |||
| float *g_angular, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_TRANSFER_IMPL_H_ | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_COMMON_ATOMCRDTOCV_IMPL_H_ | |||
| @@ -1,83 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sponge/common_sponge.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" | |||
| __device__ __host__ float fc(float Rij) { | |||
| const float PI = 3.141592654; | |||
| const float Rc = 1000.0; | |||
| return 0.5 * cosf(PI / Rc * Rij) + 0.5; | |||
| } | |||
| __global__ void G_Radial(const int start_serial, const int end_serial, const float *crd, float *g_radial) { | |||
| const float Rs = 0.5, Eta = 0.5; | |||
| int i = blockDim.x * blockIdx.x + threadIdx.x; | |||
| if (i >= start_serial && i < end_serial) { | |||
| float rij; | |||
| float g_radial_lin = 0.; | |||
| for (int j = start_serial; j < end_serial; j = j + 1) { | |||
| if (j != i) { | |||
| // rij = sqrtf((crd[3*i+0] - crd[j]) * (crd[i] - crd[j])); | |||
| rij = sqrtf(normfloat(crd, crd, i, j)); | |||
| g_radial_lin = g_radial_lin + expf(-Eta * (rij - Rs) * (rij - Rs)) * fc(rij); | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| g_radial[i] = g_radial_lin; | |||
| } | |||
| } | |||
| __global__ void G_Angular(const int start_serial, const int end_serial, const float *crd, float *g_angular) { | |||
| const float Rs = 0.5, Thetas = 3.14, Eta = 0.5, Zeta = 2.0; | |||
| int i = blockDim.x * blockIdx.x + threadIdx.x; | |||
| if (i >= start_serial && i < end_serial) { | |||
| float rij, rik, rjk, theta_jik; | |||
| float g_angular_lin = 0.; | |||
| for (int j = start_serial; j < end_serial; j = j + 1) { | |||
| if (j != i) { | |||
| rij = sqrtf(normfloat(crd, crd, i, j)); | |||
| for (int k = j + 1; k < end_serial; k = k + 1) { | |||
| if (k != i) { | |||
| rik = sqrtf(normfloat(crd, crd, i, k)); | |||
| rjk = sqrtf(normfloat(crd, crd, j, k)); | |||
| theta_jik = | |||
| acosf(fmaxf(fminf((rij * rij + rik * rik - rjk * rjk) / (2. * rij * rik), 0.999999), -0.999999)); | |||
| g_angular_lin = g_angular_lin + powf(1. + cosf(theta_jik - Thetas), Zeta) * | |||
| expf(-Eta * powf(0.5 * (rij + rik) - Rs, 2.)) * fc(rij) * fc(rik); | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| g_angular[i] = powf(2., 1. - Zeta) * g_angular_lin; | |||
| } | |||
| } | |||
| void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular, | |||
| cudaStream_t stream) { | |||
| G_Radial<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_radial); | |||
| G_Angular<<<1, number, 0, stream>>>(start_serial, end_serial, crd_f, g_angular); | |||
| return; | |||
| } | |||
| void Transfer(int start_serial, int end_serial, int number, const float *crd_f, float *g_radial, float *g_angular, | |||
| cudaStream_t stream); | |||
| @@ -14,13 +14,19 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/sponge/common/transfer_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/sponge/common/atomcrdtocv_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| TransferCrd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| TransferGpuKernel, float, int) | |||
| MS_REG_GPU_KERNEL_TWO(TransferCrd, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| AtomCrdToCVGpuKernel, float, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -14,10 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/transfer_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sponge/common/atomcrdtocv_impl.cuh" | |||
| #include <cuda_runtime_api.h> | |||
| #include <map> | |||
| @@ -31,19 +31,24 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T, typename T1> | |||
| class TransferGpuKernel : public GpuKernel { | |||
| class AtomCrdToCVGpuKernel : public GpuKernel { | |||
| public: | |||
| TransferGpuKernel() : ele_crd(1) {} | |||
| ~TransferGpuKernel() override = default; | |||
| AtomCrdToCVGpuKernel() : ele_crd(1) {} | |||
| ~AtomCrdToCVGpuKernel() override = default; | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| start_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "start_serial")); | |||
| end_serial = static_cast<int>(GetAttr<int64_t>(kernel_node, "end_serial")); | |||
| number = static_cast<int>(GetAttr<int64_t>(kernel_node, "number")); | |||
| atom_numbers = static_cast<int>(GetAttr<int64_t>(kernel_node, "atom_numbers")); | |||
| auto shape_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto shape_old_crd = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto shape_box = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < shape_crd.size(); i++) ele_crd *= shape_crd[i]; | |||
| for (size_t i = 0; i < shape_old_crd.size(); i++) ele_old_crd *= shape_old_crd[i]; | |||
| for (size_t i = 0; i < shape_box.size(); i++) ele_box *= shape_box[i]; | |||
| InitSizeLists(); | |||
| return true; | |||
| @@ -53,27 +58,40 @@ class TransferGpuKernel : public GpuKernel { | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| auto crd = GetDeviceAddress<const T>(inputs, 0); | |||
| auto old_crd = GetDeviceAddress<const T>(inputs, 1); | |||
| auto box = GetDeviceAddress<T>(inputs, 2); | |||
| auto g_radial = GetDeviceAddress<T>(outputs, 0); | |||
| auto g_angular = GetDeviceAddress<T>(outputs, 1); | |||
| Transfer(start_serial, end_serial, number, crd, g_radial, g_angular, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| auto nowarp_crd = GetDeviceAddress<T>(outputs, 2); | |||
| auto box_map_times = GetDeviceAddress<T1>(outputs, 3); | |||
| AtomCrdToCV(atom_numbers, start_serial, end_serial, number, crd, old_crd, nowarp_crd, box_map_times, box, g_radial, | |||
| g_angular, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(ele_crd * sizeof(T)); | |||
| input_size_list_.push_back(ele_old_crd * sizeof(T)); | |||
| input_size_list_.push_back(ele_box * sizeof(T)); | |||
| output_size_list_.push_back(number * sizeof(T)); | |||
| output_size_list_.push_back(number * sizeof(T)); | |||
| output_size_list_.push_back(3 * atom_numbers * sizeof(T)); | |||
| output_size_list_.push_back(3 * atom_numbers * sizeof(T1)); | |||
| } | |||
| private: | |||
| size_t ele_crd = 1; | |||
| size_t ele_old_crd = 1; | |||
| size_t ele_box = 1; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| @@ -81,7 +99,8 @@ class TransferGpuKernel : public GpuKernel { | |||
| int end_serial; | |||
| int start_serial; | |||
| int number; | |||
| int atom_numbers; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_TRANSFER_KERNEL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONG_COMMON_ATOMCRDTOCV_KERNEL_H_ | |||
| @@ -2852,33 +2852,54 @@ class TransferCrd(PrimitiveWithInfer): | |||
| Inputs: | |||
| - **crd** (Tensor, float32) - [N, 3], the coordinate of each atom. | |||
| N is the number of atoms.. | |||
| - **old_crd** (Tensor, float32) - [N, 3], the last coordinate of each atom. | |||
| N is the number of atoms. | |||
| - **box** (Tensor, float32) - [3,], the length of 3 dimensions of the simulation box. | |||
| Outputs: | |||
| - **output** (uint32) | |||
| - **radial** (Tensor, float32) - [number,], the array of radial transferred from coordinates. | |||
| - **angular** (Tensor, float32) - [number,], the array of angular transferred from coordinates. | |||
| - **nowarp_crd** (Tensor, float32) - [N, 3], the modified coordinate of each atom for | |||
| computing radial and angular. | |||
| - **box_map_times** (Tensor, int32) - [N, 3], the box map times for radial and angular. | |||
| Supported Platforms: | |||
| ``GPU`` | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, start_serial, end_serial, number): | |||
| def __init__(self, start_serial, end_serial, number, atom_numbers): | |||
| validator.check_value_type('start_serial', start_serial, (int), self.name) | |||
| validator.check_value_type('end_serial', end_serial, (int), self.name) | |||
| validator.check_value_type('number', number, (int), self.name) | |||
| validator.check_value_type('atom_numbers', atom_numbers, (int), self.name) | |||
| self.start_serial = start_serial | |||
| self.end_serial = end_serial | |||
| self.number = number | |||
| self.atom_numbers = atom_numbers | |||
| self.add_prim_attr('start_serial', self.start_serial) | |||
| self.add_prim_attr('end_serial', self.end_serial) | |||
| self.add_prim_attr('number', self.number) | |||
| self.add_prim_attr('atom_numbers', self.atom_numbers) | |||
| self.init_prim_io_names( | |||
| inputs=['crd'], | |||
| outputs=['radial', 'angular']) | |||
| inputs=['crd', 'old_crd', 'box'], | |||
| outputs=['radial', 'angular', 'nowarp_crd', 'box_map_times']) | |||
| def infer_shape(self, crd_shape): | |||
| def infer_shape(self, crd_shape, old_crd_shape, box_shape): | |||
| N = self.atom_numbers | |||
| validator.check_int(len(crd_shape), 2, Rel.EQ, "crd_dim", self.name) | |||
| validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[0]", self.name) | |||
| return [self.number,], [self.number,] | |||
| def infer_dtype(self, crd_dtype): | |||
| validator.check_int(crd_shape[0], N, Rel.EQ, "crd_shape[0]", self.name) | |||
| validator.check_int(crd_shape[1], 3, Rel.EQ, "crd_shape[1]", self.name) | |||
| validator.check_int(len(old_crd_shape), 2, Rel.EQ, "old_crd_dim", self.name) | |||
| validator.check_int(old_crd_shape[0], N, Rel.EQ, "old_crd_shape[0]", self.name) | |||
| validator.check_int(old_crd_shape[1], 3, Rel.EQ, "old_crd_shape[1]", self.name) | |||
| validator.check_int(len(box_shape), 1, Rel.EQ, "box_dim", self.name) | |||
| validator.check_int(box_shape[0], 3, Rel.EQ, "box_shape[0]", self.name) | |||
| return [self.number,], [self.number,], [N, 3], [N, 3] | |||
| def infer_dtype(self, crd_dtype, old_crd_dtype, box_dtype): | |||
| validator.check_tensor_dtype_valid('crd', crd_dtype, [mstype.float32], self.name) | |||
| return mstype.float32, mstype.float32 | |||
| validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name) | |||
| return mstype.float32, mstype.float32, mstype.float32, mstype.int32 | |||
| @@ -17,73 +17,57 @@ import argparse | |||
| import time | |||
| from src.simulation import Simulation | |||
| from src.mdnn import Mdnn, TransCrdToCV | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore import load_checkpoint | |||
| parser = argparse.ArgumentParser(description='Sponge Controller') | |||
| parser.add_argument('--i', type=str, default=None, help='input file') | |||
| parser.add_argument('--amber_parm', type=str, default=None, help='paramter file in AMBER type') | |||
| parser.add_argument('--c', type=str, default=None, help='initial coordinates file') | |||
| parser = argparse.ArgumentParser(description='SPONGE Controller') | |||
| parser.add_argument('--i', type=str, default=None, help='Input file') | |||
| parser.add_argument('--amber_parm', type=str, default=None, help='Paramter file in AMBER type') | |||
| parser.add_argument('--c', type=str, default=None, help='Initial coordinates file') | |||
| parser.add_argument('--r', type=str, default="restrt", help='') | |||
| parser.add_argument('--x', type=str, default="mdcrd", help='') | |||
| parser.add_argument('--o', type=str, default="mdout", help="") | |||
| parser.add_argument('--o', type=str, default="mdout", help='Output file') | |||
| parser.add_argument('--box', type=str, default="mdbox", help='') | |||
| parser.add_argument('--device_id', type=int, default=0, help='') | |||
| parser.add_argument('--device_id', type=int, default=0, help='GPU device id') | |||
| parser.add_argument('--u', type=bool, default=False, help='If use mdnn to update the atom charge') | |||
| parser.add_argument('--checkpoint', type=str, default="", help='Checkpoint file') | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False) | |||
| if __name__ == "__main__": | |||
| simulation = Simulation(args_opt) | |||
| if args_opt.u and args_opt.checkpoint: | |||
| net = Mdnn() | |||
| load_checkpoint(args_opt.checkpoint, net=net) | |||
| transcrd = TransCrdToCV(simulation) | |||
| start = time.time() | |||
| compiler_time = 0 | |||
| save_path = args_opt.o | |||
| file = open(save_path, 'w') | |||
| simulation.Main_Initial() | |||
| for steps in range(simulation.md_info.step_limit): | |||
| print_step = steps % simulation.ntwx | |||
| if steps == simulation.md_info.step_limit - 1: | |||
| print_step = 0 | |||
| temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \ | |||
| nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene, _ = simulation(Tensor(steps), Tensor(print_step)) | |||
| if steps == 0: | |||
| compiler_time = time.time() | |||
| if steps % simulation.ntwx == 0 or steps == simulation.md_info.step_limit - 1: | |||
| if steps == 0: | |||
| print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " | |||
| "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_") | |||
| file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " | |||
| "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n") | |||
| simulation.Main_Print(steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, | |||
| sigma_of_dihedral_ene, nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene) | |||
| temperature = temperature.asnumpy() | |||
| total_potential_energy = total_potential_energy.asnumpy() | |||
| print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), | |||
| end=" ") | |||
| if simulation.bond.bond_numbers > 0: | |||
| sigma_of_bond_ene = sigma_of_bond_ene.asnumpy() | |||
| print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ") | |||
| if simulation.angle.angle_numbers > 0: | |||
| sigma_of_angle_ene = sigma_of_angle_ene.asnumpy() | |||
| print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ") | |||
| if simulation.dihedral.dihedral_numbers > 0: | |||
| sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy() | |||
| print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ") | |||
| if simulation.nb14.nb14_numbers > 0: | |||
| nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy() | |||
| nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy() | |||
| print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ") | |||
| LJ_energy_sum = LJ_energy_sum.asnumpy() | |||
| ee_ene = ee_ene.asnumpy() | |||
| print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ") | |||
| print("{:>12.3f}".format(float(ee_ene))) | |||
| if file is not None: | |||
| file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}" | |||
| " {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy), | |||
| float(sigma_of_bond_ene), float(sigma_of_angle_ene), | |||
| float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum), | |||
| float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene))) | |||
| if args_opt.u and args_opt.checkpoint and steps % (4 * simulation.ntwx) == 0: | |||
| print("Update charge!") | |||
| inputs = transcrd(Tensor(simulation.crd), Tensor(simulation.last_crd)) | |||
| t_charge = net(inputs) | |||
| simulation.charge = transcrd.updatecharge(t_charge) | |||
| end = time.time() | |||
| file.close() | |||
| print("Main time(s):", end - start) | |||
| print("Main time(s) without compiler:", end - compiler_time) | |||
| simulation.Main_Destroy() | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """mdnn class""" | |||
| import numpy as np | |||
| from mindspore import nn, Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.parameter import Parameter | |||
| import mindspore.common.dtype as mstype | |||
| class Mdnn(nn.Cell): | |||
| """Mdnn""" | |||
| def __init__(self, dim=258, dr=0.5): | |||
| super(Mdnn, self).__init__() | |||
| self.dim = dim | |||
| self.dr = dr # dropout_ratio | |||
| self.fc1 = nn.Dense(dim, 512) | |||
| self.fc2 = nn.Dense(512, 512) | |||
| self.fc3 = nn.Dense(512, 512) | |||
| self.fc4 = nn.Dense(512, 129) | |||
| self.tanh = nn.Tanh() | |||
| def construct(self, x): | |||
| """construct""" | |||
| x = self.tanh(self.fc1(x)) | |||
| x = self.tanh(self.fc2(x)) | |||
| x = self.tanh(self.fc3(x)) | |||
| x = self.fc4(x) | |||
| return x | |||
| class TransCrdToCV(nn.Cell): | |||
| """TransCrdToCV""" | |||
| def __init__(self, simulation): | |||
| super(TransCrdToCV, self).__init__() | |||
| self.atom_numbers = simulation.atom_numbers | |||
| self.transfercrd = P.TransferCrd(0, 129, 129, self.atom_numbers) | |||
| self.box = Tensor(simulation.box_length) | |||
| self.radial = Parameter(Tensor(np.zeros([129,]), mstype.float32)) | |||
| self.angular = Parameter(Tensor(np.zeros([129,]), mstype.float32)) | |||
| self.output = Parameter(Tensor(np.zeros([1, 258]), mstype.float32)) | |||
| self.charge = simulation.charge | |||
| def updatecharge(self, t_charge): | |||
| """update charge in simulation""" | |||
| self.charge[:129] = t_charge[0] * 18.2223 | |||
| return self.charge | |||
| def construct(self, crd, last_crd): | |||
| """construct""" | |||
| self.radial, self.angular, _, _ = self.transfercrd(crd, last_crd, self.box) | |||
| self.output = P.Concat()((self.radial, self.angular)) | |||
| self.output = P.ExpandDims()(self.output, 0) | |||
| return self.output | |||
| @@ -34,6 +34,7 @@ from src.particle_mesh_ewald import Particle_Mesh_Ewald | |||
| class controller: | |||
| '''controller''' | |||
| def __init__(self, args_opt): | |||
| self.input_file = args_opt.i | |||
| self.initial_coordinates_file = args_opt.c | |||
| @@ -67,6 +68,7 @@ class controller: | |||
| class Simulation(nn.Cell): | |||
| '''simulation''' | |||
| def __init__(self, args_opt): | |||
| super(Simulation, self).__init__() | |||
| self.control = controller(args_opt) | |||
| @@ -119,6 +121,7 @@ class Simulation(nn.Cell): | |||
| self.exp_gamma = self.liujian_info.exp_gamma | |||
| self.init_Tensor() | |||
| self.op_define() | |||
| self.update = False | |||
| def init_Tensor(self): | |||
| '''init tensor''' | |||
| @@ -129,9 +132,12 @@ class Simulation(nn.Cell): | |||
| self.uint_dr_to_dr_cof = Parameter( | |||
| Tensor(np.asarray(self.md_info.uint_dr_to_dr_cof, np.float32), mstype.float32), requires_grad=False) | |||
| self.box_length = Tensor(self.md_info.box_length, mstype.float32) | |||
| self.charge = Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32) | |||
| self.charge = Parameter(Tensor(np.asarray(self.md_info.h_charge, dtype=np.float32), mstype.float32), | |||
| requires_grad=False) | |||
| self.old_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32), | |||
| requires_grad=False) | |||
| self.last_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.float32), mstype.float32), | |||
| requires_grad=False) | |||
| self.uint_crd = Parameter(Tensor(np.zeros([self.atom_numbers, 3], dtype=np.uint32), mstype.uint32), | |||
| requires_grad=False) | |||
| self.mass_inverse = Tensor(self.md_info.h_mass_inverse, mstype.float32) | |||
| @@ -341,8 +347,65 @@ class Simulation(nn.Cell): | |||
| acc = F.depend(self.acc, crd) | |||
| return vel, crd, acc | |||
| def Main_Print(self, *args): | |||
| """compute the temperature""" | |||
| steps, temperature, total_potential_energy, sigma_of_bond_ene, sigma_of_angle_ene, sigma_of_dihedral_ene, \ | |||
| nb14_lj_energy_sum, nb14_cf_energy_sum, LJ_energy_sum, ee_ene = list(args) | |||
| if steps == 0: | |||
| print("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " | |||
| "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_") | |||
| temperature = temperature.asnumpy() | |||
| total_potential_energy = total_potential_energy.asnumpy() | |||
| print("{:>7.0f} {:>7.3f} {:>11.3f}".format(steps, float(temperature), float(total_potential_energy)), | |||
| end=" ") | |||
| if self.bond.bond_numbers > 0: | |||
| sigma_of_bond_ene = sigma_of_bond_ene.asnumpy() | |||
| print("{:>10.3f}".format(float(sigma_of_bond_ene)), end=" ") | |||
| if self.angle.angle_numbers > 0: | |||
| sigma_of_angle_ene = sigma_of_angle_ene.asnumpy() | |||
| print("{:>11.3f}".format(float(sigma_of_angle_ene)), end=" ") | |||
| if self.dihedral.dihedral_numbers > 0: | |||
| sigma_of_dihedral_ene = sigma_of_dihedral_ene.asnumpy() | |||
| print("{:>14.3f}".format(float(sigma_of_dihedral_ene)), end=" ") | |||
| if self.nb14.nb14_numbers > 0: | |||
| nb14_lj_energy_sum = nb14_lj_energy_sum.asnumpy() | |||
| nb14_cf_energy_sum = nb14_cf_energy_sum.asnumpy() | |||
| print("{:>10.3f} {:>10.3f}".format(float(nb14_lj_energy_sum), float(nb14_cf_energy_sum)), end=" ") | |||
| LJ_energy_sum = LJ_energy_sum.asnumpy() | |||
| ee_ene = ee_ene.asnumpy() | |||
| print("{:>7.3f}".format(float(LJ_energy_sum)), end=" ") | |||
| print("{:>12.3f}".format(float(ee_ene))) | |||
| if self.file is not None: | |||
| self.file.write("{:>7.0f} {:>7.3f} {:>11.3f} {:>10.3f} {:>11.3f} {:>14.3f} {:>10.3f} {:>10.3f} {:>7.3f}" | |||
| " {:>12.3f}\n".format(steps, float(temperature), float(total_potential_energy), | |||
| float(sigma_of_bond_ene), float(sigma_of_angle_ene), | |||
| float(sigma_of_dihedral_ene), float(nb14_lj_energy_sum), | |||
| float(nb14_cf_energy_sum), float(LJ_energy_sum), float(ee_ene))) | |||
| if self.datfile is not None: | |||
| self.datfile.write(self.crd.asnumpy()) | |||
| def Main_Initial(self): | |||
| """main initial""" | |||
| if self.control.mdout: | |||
| self.file = open(self.control.mdout, 'w') | |||
| self.file.write("_steps_ _TEMP_ _TOT_POT_ENE_ _BOND_ENE_ " | |||
| "_ANGLE_ENE_ _DIHEDRAL_ENE_ _14LJ_ENE_ _14CF_ENE_ _LJ_ENE_ _CF_PME_ENE_\n") | |||
| if self.control.mdcrd: | |||
| self.datfile = open(self.control.mdcrd, 'wb') | |||
| def Main_Destroy(self): | |||
| """main destroy""" | |||
| if self.file is not None: | |||
| self.file.close() | |||
| print("Save .out file successfully!") | |||
| if self.datfile is not None: | |||
| self.datfile.close() | |||
| print("Save .dat file successfully!") | |||
| def construct(self, step, print_step): | |||
| '''construct''' | |||
| self.last_crd = self.crd | |||
| if step == 0: | |||
| res = self.neighbor_list_update_init(self.atom_numbers_in_grid_bucket, self.bucket, self.crd, | |||
| self.box_length, self.grid_N, self.grid_length_inverse, | |||
| @@ -0,0 +1,111 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train""" | |||
| import argparse | |||
| import numpy as np | |||
| from src.mdnn import Mdnn | |||
| from mindspore import nn, Model, context | |||
| from mindspore import dataset as ds | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor | |||
| from mindspore.train.callback import Callback | |||
| import mindspore.common.initializer as weight_init | |||
| parser = argparse.ArgumentParser(description='Mdnn Controller') | |||
| parser.add_argument('--i', type=str, default=None, help='Input radial and angular dat file') | |||
| parser.add_argument('--charge', type=str, default=None, help='Input charge dat file') | |||
| parser.add_argument('--device_id', type=int, default=0, help='GPU device id') | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False) | |||
| class StepLossAccInfo(Callback): | |||
| """custom callback function""" | |||
| def __init__(self, models, eval_dataset, steploss): | |||
| """init model""" | |||
| self.model = models | |||
| self.eval_dataset = eval_dataset | |||
| self.steps_loss = steploss | |||
| def step_end(self, run_context): | |||
| """step end""" | |||
| cb_params = run_context.original_args() | |||
| cur_epoch = cb_params.cur_epoch_num | |||
| cur_step = (cur_epoch - 1) * 1875 + cb_params.cur_step_num | |||
| self.steps_loss["loss_value"].append(str(cb_params.net_outputs)) | |||
| self.steps_loss["step"].append(str(cur_step)) | |||
| def get_data(inputdata, outputdata): | |||
| """get data function""" | |||
| for _, data in enumerate(zip(inputdata, outputdata)): | |||
| yield data | |||
| def create_dataset(inputdata, outputdata, batchsize=32, repeat_size=1): | |||
| """create dataset function""" | |||
| input_data = ds.GeneratorDataset(list(get_data(inputdata, outputdata)), column_names=['data', 'label']) | |||
| input_data = input_data.batch(batchsize) | |||
| input_data = input_data.repeat(repeat_size) | |||
| return input_data | |||
| def init_weight(nnet): | |||
| for _, cell in nnet.cells_and_names(): | |||
| if isinstance(cell, nn.Conv2d): | |||
| cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(), | |||
| cell.weight.shape, | |||
| cell.weight.dtype)) | |||
| if isinstance(cell, nn.Dense): | |||
| cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(), | |||
| cell.weight.shape, | |||
| cell.weight.dtype)) | |||
| if __name__ == '__main__': | |||
| # read input files | |||
| inputs = args_opt.i | |||
| outputs = args_opt.charge | |||
| radial_angular = np.fromfile(inputs, dtype=np.float32) | |||
| radial_angular = radial_angular.reshape((-1, 258)).astype(np.float32) | |||
| charge = np.fromfile(outputs, dtype=np.float32) | |||
| charge = charge.reshape((-1, 129)).astype(np.float32) | |||
| # define the model | |||
| net = Mdnn() | |||
| lr = 0.0001 | |||
| decay_rate = 0.8 | |||
| epoch_size = 1000 | |||
| batch_size = 500 | |||
| total_step = epoch_size * batch_size | |||
| step_per_epoch = 100 | |||
| decay_epoch = epoch_size | |||
| lr_rate = nn.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| net_loss = nn.loss.MSELoss(reduction='mean') | |||
| net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_rate) | |||
| model = Model(net, net_loss, net_opt) | |||
| ds_train = create_dataset(radial_angular, charge, batchsize=batch_size) | |||
| model_params = net.trainable_params() | |||
| net.set_train() | |||
| init_weight(net) | |||
| # config files | |||
| path = './params/' | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10) | |||
| ckpoint_cb = ModelCheckpoint(prefix="mdnn_best", directory=path, config=config_ck) | |||
| steps_loss = {"step": [], "loss_value": []} | |||
| step_loss_acc_info = StepLossAccInfo(model, ds_train, steps_loss) | |||
| # train the model | |||
| model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(100)]) | |||