| @@ -0,0 +1,95 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/cpu/quantum/evolution_cpu_kernel.h" | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include "utils/ms_utils.h" | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void EvolutionCPUKernel::InitPQCStructure(const CNodePtr &kernel_node) { | |||||
| n_qubits_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, mindquantum::kNQubits); | |||||
| param_names_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::NamesType>(kernel_node, mindquantum::kParamNames); | |||||
| gate_names_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::NamesType>(kernel_node, mindquantum::kGateNames); | |||||
| gate_matrix_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::ComplexMatrixsType>(kernel_node, mindquantum::kGateMatrix); | |||||
| gate_obj_qubits_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::Indexess>(kernel_node, mindquantum::kGateObjQubits); | |||||
| gate_ctrl_qubits_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::Indexess>(kernel_node, mindquantum::kGateCtrlQubits); | |||||
| gate_params_names_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::ParasNameType>(kernel_node, mindquantum::kGateParamsNames); | |||||
| gate_coeff_ = AnfAlgo::GetNodeAttr<mindquantum::transformer::CoeffsType>(kernel_node, mindquantum::kGateCoeff); | |||||
| gate_requires_grad_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::RequiresType>(kernel_node, mindquantum::kGateRequiresGrad); | |||||
| hams_pauli_coeff_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::PaulisCoeffsType>(kernel_node, mindquantum::kHamsPauliCoeff); | |||||
| hams_pauli_word_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::PaulisWordsType>(kernel_node, mindquantum::kHamsPauliWord); | |||||
| hams_pauli_qubit_ = | |||||
| AnfAlgo::GetNodeAttr<mindquantum::transformer::PaulisQubitsType>(kernel_node, mindquantum::kHamsPauliQubit); | |||||
| } | |||||
| void EvolutionCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| std::vector<size_t> param_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||||
| std::vector<size_t> result_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||||
| if (param_shape.size() != 1 || result_shape.size() != 2) { | |||||
| MS_LOG(EXCEPTION) << "evolution invalid input size"; | |||||
| } | |||||
| state_len_ = result_shape[0]; | |||||
| InitPQCStructure(kernel_node); | |||||
| auto circs = mindquantum::transformer::CircuitTransfor(gate_names_, gate_matrix_, gate_obj_qubits_, gate_ctrl_qubits_, | |||||
| gate_params_names_, gate_coeff_, gate_requires_grad_); | |||||
| circ_ = circs[0]; | |||||
| hams_ = mindquantum::transformer::HamiltoniansTransfor(hams_pauli_coeff_, hams_pauli_word_, hams_pauli_qubit_); | |||||
| if (hams_.size() > 1) { | |||||
| MS_LOG(EXCEPTION) << "evolution only work for single hamiltonian or no hamiltonian."; | |||||
| } | |||||
| } | |||||
| bool EvolutionCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "evolution error input output size!"; | |||||
| } | |||||
| auto param_data = reinterpret_cast<float *>(inputs[0]->addr); | |||||
| auto output = reinterpret_cast<float *>(outputs[0]->addr); | |||||
| MS_EXCEPTION_IF_NULL(param_data); | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| sim_ = mindquantum::PQCSimulator(1, n_qubits_); | |||||
| mindquantum::ParameterResolver pr; | |||||
| for (size_t i = 0; i < param_names_.size(); i++) { | |||||
| pr.SetData(param_names_.at(i), param_data[i]); | |||||
| } | |||||
| sim_.Evolution(circ_, pr); | |||||
| if (hams_.size() == 1) { | |||||
| sim_.ApplyHamiltonian(hams_[0]); | |||||
| } | |||||
| if (state_len_ != sim_.vec_.size()) { | |||||
| MS_LOG(EXCEPTION) << "simulation error number of quantum qubit!"; | |||||
| } | |||||
| size_t poi = 0; | |||||
| for (auto &v : sim_.vec_) { | |||||
| output[poi++] = v.real(); | |||||
| output[poi++] = v.imag(); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EVOLUTION_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EVOLUTION_CPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| #include "mindquantum/pqc_simulator.h" | |||||
| #include "mindquantum/transformer.h" | |||||
| #include "mindquantum/circuit.h" | |||||
| #include "mindquantum/parameter_resolver.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class EvolutionCPUKernel : public CPUKernel { | |||||
| public: | |||||
| EvolutionCPUKernel() = default; | |||||
| ~EvolutionCPUKernel() override = default; | |||||
| void InitKernel(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| void InitPQCStructure(const CNodePtr &kernel_node); | |||||
| private: | |||||
| int64_t n_qubits_; | |||||
| size_t state_len_; | |||||
| mindquantum::PQCSimulator sim_; | |||||
| mindquantum::BasicCircuit circ_; | |||||
| mindquantum::transformer::Hamiltonians hams_; | |||||
| mindquantum::transformer::NamesType param_names_; | |||||
| // quantum circuit | |||||
| mindquantum::transformer::NamesType gate_names_; | |||||
| mindquantum::transformer::ComplexMatrixsType gate_matrix_; | |||||
| mindquantum::transformer::Indexess gate_obj_qubits_; | |||||
| mindquantum::transformer::Indexess gate_ctrl_qubits_; | |||||
| mindquantum::transformer::ParasNameType gate_params_names_; | |||||
| mindquantum::transformer::CoeffsType gate_coeff_; | |||||
| mindquantum::transformer::RequiresType gate_requires_grad_; | |||||
| // hamiltonian | |||||
| mindquantum::transformer::PaulisCoeffsType hams_pauli_coeff_; | |||||
| mindquantum::transformer::PaulisWordsType hams_pauli_word_; | |||||
| mindquantum::transformer::PaulisQubitsType hams_pauli_qubit_; | |||||
| }; | |||||
| MS_REG_CPU_KERNEL(Evolution, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| EvolutionCPUKernel); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EVOLUTION_CPU_KERNEL_H_ | |||||
| @@ -39,6 +39,7 @@ ComplexType ComplexInnerProductWithControl(const Simulator::StateVector &, const | |||||
| std::size_t); | std::size_t); | ||||
| const char kNThreads[] = "n_threads"; | const char kNThreads[] = "n_threads"; | ||||
| const char kNQubits[] = "n_qubits"; | const char kNQubits[] = "n_qubits"; | ||||
| const char kParamNames[] = "param_names"; | |||||
| const char kEncoderParamsNames[] = "encoder_params_names"; | const char kEncoderParamsNames[] = "encoder_params_names"; | ||||
| const char kAnsatzParamsNames[] = "ansatz_params_names"; | const char kAnsatzParamsNames[] = "ansatz_params_names"; | ||||
| const char kGateNames[] = "gate_names"; | const char kGateNames[] = "gate_names"; | ||||
| @@ -98,7 +98,7 @@ from .sparse_ops import SparseToDense | |||||
| from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, | from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, | ||||
| SubAndFilter, | SubAndFilter, | ||||
| MapUniform, DynamicAssign, PadAndShift) | MapUniform, DynamicAssign, PadAndShift) | ||||
| from .quantum_ops import PQC | |||||
| from .quantum_ops import PQC, Evolution | |||||
| from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAtomEnergy, BondForceWithAtomVirial, | from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAtomEnergy, BondForceWithAtomVirial, | ||||
| DihedralForce, DihedralEnergy, DihedralAtomEnergy, DihedralForceWithAtomEnergy, | DihedralForce, DihedralEnergy, DihedralAtomEnergy, DihedralForceWithAtomEnergy, | ||||
| AngleForce, AngleEnergy, AngleAtomEnergy, AngleForceWithAtomEnergy) | AngleForce, AngleEnergy, AngleAtomEnergy, AngleForceWithAtomEnergy) | ||||
| @@ -424,6 +424,7 @@ __all__ = [ | |||||
| "Range", | "Range", | ||||
| "IndexAdd", | "IndexAdd", | ||||
| "PQC", | "PQC", | ||||
| "Evolution", | |||||
| "BondForce", | "BondForce", | ||||
| "BondEnergy", | "BondEnergy", | ||||
| "BondAtomEnergy", | "BondAtomEnergy", | ||||
| @@ -85,3 +85,54 @@ equal to 1, but got {}.".format(len(ansatz_data))) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, | validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, | ||||
| self.name) | self.name) | ||||
| return encoder_data, encoder_data, encoder_data | return encoder_data, encoder_data, encoder_data | ||||
| class Evolution(PrimitiveWithInfer): | |||||
| r""" | |||||
| Inputs of this operation is generated by MindQuantum framework. | |||||
| Inputs: | |||||
| - **n_qubits** (int) - The qubit number of quantum simulator. | |||||
| - **param_names** (List[str]) - The parameters names. | |||||
| - **gate_names** (List[str]) - The name of each gate. | |||||
| - **gate_matrix** (List[List[List[List[float]]]]) - Real part and image | |||||
| part of the matrix of quantum gate. | |||||
| - **gate_obj_qubits** (List[List[int]]) - Object qubits of each gate. | |||||
| - **gate_ctrl_qubits** (List[List[int]]) - Control qubits of each gate. | |||||
| - **gate_params_names** (List[List[str]]) - Parameter names of each gate. | |||||
| - **gate_coeff** (List[List[float]]) - Coefficient of eqch parameter of each gate. | |||||
| - **gate_requires_grad** (List[List[bool]]) - Whether to calculate gradient | |||||
| of parameters of gates. | |||||
| - **hams_pauli_coeff** (List[List[float]]) - Coefficient of pauli words. | |||||
| - **hams_pauli_word** (List[List[List[str]]]) - Pauli words. | |||||
| - **hams_pauli_qubit** (List[List[List[int]]]) - The qubit that pauli matrix act on. | |||||
| Outputs: | |||||
| - **Quantum state** (Tensor) - The quantum state after evolution. | |||||
| Supported Platforms: | |||||
| ``CPU`` | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, n_qubits, param_names, gate_names, gate_matrix, | |||||
| gate_obj_qubits, gate_ctrl_qubits, gate_params_names, | |||||
| gate_coeff, gate_requires_grad, hams_pauli_coeff, | |||||
| hams_pauli_word, hams_pauli_qubit): | |||||
| """Initialize Evolutino""" | |||||
| self.init_prim_io_names(inputs=['param_data'], outputs=['state']) | |||||
| self.n_qubits = n_qubits | |||||
| def check_shape_size(self, param_data): | |||||
| if len(param_data) != 1: | |||||
| raise ValueError("PQC input param_data should have dimension size \ | |||||
| equal to 1, but got {}.".format(len(param_data))) | |||||
| def infer_shape(self, param_data): | |||||
| self.check_shape_size(param_data) | |||||
| return [1 << self.n_qubits, 2] | |||||
| def infer_dtype(self, param_data): | |||||
| args = {'param_data': param_data} | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, | |||||
| self.name) | |||||
| return param_data | |||||
| @@ -2806,6 +2806,30 @@ test_case_quantum_ops = [ | |||||
| 'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]).astype(np.float32)), | 'desc_inputs': [Tensor(np.array([[1.0, 2.0, 3.0]]).astype(np.float32)), | ||||
| Tensor(np.array([2.0, 3.0, 4.0]).astype(np.float32))], | Tensor(np.array([2.0, 3.0, 4.0]).astype(np.float32))], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('Evolution', { | |||||
| 'block': P.Evolution(n_qubits=3, | |||||
| param_names=['a'], | |||||
| gate_names=['npg', 'npg', 'npg', 'RY'], | |||||
| gate_matrix=[[[['0.7071067811865475', '0.7071067811865475'], | |||||
| ['0.7071067811865475', '-0.7071067811865475']], | |||||
| [['0.0', '0.0'], ['0.0', '0.0']]], | |||||
| [[['0.7071067811865475', '0.7071067811865475'], | |||||
| ['0.7071067811865475', '-0.7071067811865475']], | |||||
| [['0.0', '0.0'], ['0.0', '0.0']]], | |||||
| [[['0.0', '1.0'], ['1.0', '0.0']], | |||||
| [['0.0', '0.0'], ['0.0', '0.0']]], | |||||
| [[['0.0', '0.0'], ['0.0', '0.0']], | |||||
| [['0.0', '0.0'], ['0.0', '0.0']]]], | |||||
| gate_obj_qubits=[[0], [1], [0], [0]], | |||||
| gate_ctrl_qubits=[[], [], [1], []], | |||||
| gate_params_names=[[], [], [], ['a']], | |||||
| gate_coeff=[[], [], [], [1.0]], | |||||
| gate_requires_grad=[[], [], [], [True]], | |||||
| hams_pauli_coeff=[[1.0]], | |||||
| hams_pauli_word=[[['Z']]], | |||||
| hams_pauli_qubit=[[[0]]]), | |||||
| 'desc_inputs': [Tensor(np.array([0.5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| ] | ] | ||||
| test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, | test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, | ||||