| @@ -20,6 +20,45 @@ | |||
| #include "pybind_api/api_register.h" | |||
| namespace mindspore { | |||
| py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict ¶ms_init) { | |||
| py::dict hyper_params; | |||
| for (const auto ¶m : func_graph->parameters()) { | |||
| auto param_node = param->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| py::str param_name = py::str(param_node->name()); | |||
| if (param_node->has_default()) { | |||
| const char kModelName[] = "mindspore"; | |||
| const char kClassName[] = "Parameter"; | |||
| const py::module &mod = py::module::import(kModelName); | |||
| const py::object &fn = mod.attr(kClassName); | |||
| const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(old_value); | |||
| py::object new_param; | |||
| if (params_init.contains(param_name)) { | |||
| const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(new_value); | |||
| if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) { | |||
| MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. " | |||
| "The parameter '" | |||
| << param_name.cast<std::string>() << "' has shape " << old_value->shape() | |||
| << " and dtype " << TypeIdLabel(old_value->data_type()) | |||
| << ", but got the update Tensor with shape " << new_value->shape() << " and dtype " | |||
| << TypeIdLabel(new_value->data_type()) << "."; | |||
| } | |||
| new_param = fn(*new_value); | |||
| } else { | |||
| new_param = fn(*old_value); | |||
| } | |||
| auto new_default_param = new_param.cast<tensor::TensorPtr>(); | |||
| new_default_param->set_param_info(old_value->param_info()); | |||
| param_node->set_default_param(new_default_param); | |||
| hyper_params[param_name] = new_param; | |||
| } | |||
| } | |||
| return hyper_params; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | |||
| // Define python "MetaFuncGraph_" class | |||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | |||
| @@ -28,8 +67,11 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | |||
| (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | |||
| .def(py::init()) | |||
| .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") | |||
| .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph") | |||
| .def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"), | |||
| "Update FuncGraph hyper parameters, and return the updated parameters."); | |||
| .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); | |||
| })); | |||
| REGISTER_PYBIND_DEFINE(_c_expression, ([](pybind11::module *const m) { | |||
| (void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams, | |||
| py::arg("func_graph"), py::arg("params_init"), | |||
| "Update FuncGraph hyper parameters, and return the updated parameters."); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -556,38 +556,6 @@ size_t FuncGraph::GetDefaultValueCount() { | |||
| return parameter_default_value_.size() - LongToSize(null_count); | |||
| } | |||
| std::map<std::string, ValuePtr> FuncGraph::UpdateHyperParams( | |||
| const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_init) { | |||
| std::map<std::string, ValuePtr> hyper_params; | |||
| for (const auto ¶ : parameters_) { | |||
| auto param_node = para->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| const std::string ¶m_name = param_node->name(); | |||
| if (param_node->has_default()) { | |||
| if (params_init.find(param_name) != params_init.end()) { | |||
| const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>(); | |||
| const auto &new_value = params_init.at(param_name); | |||
| MS_EXCEPTION_IF_NULL(old_value); | |||
| MS_EXCEPTION_IF_NULL(new_value); | |||
| if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) { | |||
| MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. " | |||
| "The parameter '" | |||
| << param_name << "' has shape " << old_value->shape() << " and dtype " | |||
| << TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape " | |||
| << new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << "."; | |||
| } | |||
| auto new_default_param = std::make_shared<tensor::Tensor>(*new_value); | |||
| new_default_param->set_param_info(old_value->param_info()); | |||
| param_node->set_default_param(new_default_param); | |||
| } | |||
| hyper_params[param_name] = param_node->default_param(); | |||
| } | |||
| } | |||
| return hyper_params; | |||
| } | |||
| AnfNodePtr FuncGraph::GetVariableArgParameter() { | |||
| if (!has_vararg_) { | |||
| return nullptr; | |||
| @@ -215,8 +215,6 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi | |||
| void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list); | |||
| void ClearDefaultValues(); | |||
| size_t GetDefaultValueCount(); | |||
| std::map<std::string, ValuePtr> UpdateHyperParams( | |||
| const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_init); | |||
| std::map<std::string, AnfNodePtr> ¶meter_default_value() { return parameter_default_value_; } | |||
| void set_has_vararg(bool has_) { has_vararg_ = has_; } | |||
| bool has_vararg() const { return has_vararg_; } | |||
| @@ -26,7 +26,7 @@ from mindspore import log as logger | |||
| from mindspore.common.parameter import PARAMETER_NAME_DEFAULT | |||
| from mindspore.context import ParallelMode | |||
| from .. import context | |||
| from .._c_expression import init_pipeline, Cell_, FuncGraph, MixedPrecisionType | |||
| from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType | |||
| from .._checkparam import Validator | |||
| from ..common import dtype as mstype | |||
| from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor | |||
| @@ -1703,10 +1703,8 @@ class GraphCell(Cell): | |||
| raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, " | |||
| f"but got the key type: {type(name)}, and the value type: {type(value)}") | |||
| params_dict = self.graph.update_hyper_params(params_init) | |||
| for name, value in params_dict.items(): | |||
| param = Parameter(value) | |||
| param.param_info = value.param_info | |||
| params_dict = update_func_graph_hyper_params(self.graph, params_init) | |||
| for name, param in params_dict.items(): | |||
| self._params[name] = param | |||
| def construct(self, *inputs): | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test get and init GraphCell parameters""" | |||
| import os | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import context | |||
| from mindspore import export, load, save_checkpoint, load_checkpoint | |||
| from mindspore import nn | |||
class TestNet(nn.Cell):
    """Cell with one trainable weight that is incremented by 1.0 on every call."""

    def __init__(self):
        super().__init__()
        self.flag = False
        self.weight = Parameter(np_param, requires_grad=True)
        self.dense = nn.Dense(3, 4)

    def construct(self, x, y):
        # __init__ leaves `flag` False, so the element-wise product branch is the one exercised.
        if not self.flag:
            out = x * y * self.weight
        else:
            out = self.dense(x * self.weight)
        # The weight update happens after the output has been computed from the old value.
        self.weight += 1.0
        return out
| np_a = np.ones((2, 3), np.float32) + 2 | |||
| np_b = np.ones((2, 3), np.float32) + 3 | |||
| np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32) | |||
| input_a = Tensor(np_a) | |||
| input_b = Tensor(np_b) | |||
def load_mindir_and_update_params(mindir_name, ckpt_name):
    """Export TestNet, run the loaded graph once, checkpoint it and reload it from that checkpoint.

    Args:
        mindir_name: MINDIR file name; its last 7 characters are dropped to form the export prefix.
        ckpt_name: checkpoint file that is written and then read back.

    Returns:
        A GraphCell loaded from `mindir_name` with `params_init` taken from the checkpoint.
    """
    export(TestNet(), input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

    first_cell = nn.GraphCell(graph=load(mindir_name))
    first_out = first_cell(input_a, input_b)
    save_checkpoint(first_cell, ckpt_name)
    # After one forward pass the output reflects the original weight and the weight itself is one higher.
    assert np.array_equal(first_out.asnumpy(), np_a * np_b * np_param)
    assert np.array_equal(first_cell.trainable_params()[0].asnumpy(), np_param + 1.0)

    params_init = load_checkpoint(ckpt_name)
    return nn.GraphCell(graph=load(mindir_name), params_init=params_init)
def get_and_init_graph_cell_parameters():
    """Check that a GraphCell initialised from a checkpoint computes with the checkpointed weight.

    The MINDIR and checkpoint file names are prefixed with the current context mode. Both files are
    removed afterwards, including when a call or an assertion fails.
    """
    mode = context.get_context('mode')
    mindir_name = f"{mode}_test_graph_cell_net.mindir"
    ckpt_name = f"{mode}_test_graph_cell_net.ckpt"
    try:
        load_net = load_mindir_and_update_params(mindir_name, ckpt_name)
        ret = load_net(input_a, input_b)

        assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
        assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)
    finally:
        # Clean up even on failure so a failed run does not leave generated files behind.
        for generated in (mindir_name, ckpt_name):
            if os.path.isfile(generated):
                os.remove(generated)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters_in_graph_mode():
    """
    Description: In graph mode, load a MINDIR file as a GraphCell and initialise its parameters from a checkpoint.
    Expectation: The GraphCell runs with the updated parameters.
    """
    graph_mode = context.GRAPH_MODE
    context.set_context(mode=graph_mode)
    get_and_init_graph_cell_parameters()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters_in_pynative_mode():
    """
    Description: In pynative mode, load a MINDIR file as a GraphCell and initialise its parameters from a checkpoint.
    Expectation: The GraphCell runs with the updated parameters.
    """
    pynative_mode = context.PYNATIVE_MODE
    context.set_context(mode=pynative_mode)
    get_and_init_graph_cell_parameters()
| @@ -13,18 +13,17 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test get and init GraphCell parameters""" | |||
| """test init GraphCell parameters with illegal data""" | |||
| import os | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import nn | |||
| from mindspore import context | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import export, load, save_checkpoint, load_checkpoint | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| from mindspore import context | |||
| from mindspore import export, load | |||
| from mindspore import nn | |||
| class Net(nn.Cell): | |||
| @@ -50,44 +49,17 @@ input_a = Tensor(np_a) | |||
| input_b = Tensor(np_b) | |||
def load_mindir_and_update_params():
    """Export Net to "net_0.mindir", run it once, checkpoint it and reload it from "net_0.ckpt".

    Returns:
        A GraphCell loaded from the MINDIR file with `params_init` taken from the checkpoint.
    """
    mindir_name = "net_0.mindir"
    ckpt_name = "net_0.ckpt"

    # The last 7 characters of the file name are dropped to form the export prefix.
    export(Net(), input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

    first_cell = nn.GraphCell(graph=load(mindir_name))
    first_out = first_cell(input_a, input_b)
    assert np.array_equal(first_out.asnumpy(), np_a * np_b * np_param)

    save_checkpoint(first_cell, ckpt_name)
    params_init = load_checkpoint(ckpt_name)
    return nn.GraphCell(graph=load(mindir_name), params_init=params_init)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters():
    """
    Description: Load a MINDIR file as a GraphCell and initialise its parameters from a checkpoint.
    Expectation: The GraphCell output is computed with the updated parameters.
    """
    reloaded_net = load_mindir_and_update_params()
    actual = reloaded_net(input_a, input_b).asnumpy()
    expected = np_a * np_b * (np_param + 1.0)
    assert np.array_equal(actual, expected)
def remove_generated_file(file_name):
    """Delete `file_name` when it is an existing regular file; otherwise do nothing."""
    if not os.path.isfile(file_name):
        return
    os.remove(file_name)
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_init_graph_cell_parameters_with_wrong_type(): | |||
| """ | |||
| Description: load mind ir and update parameters with wrong type. | |||
| Expectation: raise a ValueError indicating the params type error. | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Net() | |||
| mindir_name = "net_1.mindir" | |||
| export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') | |||
| @@ -99,16 +71,15 @@ def test_init_graph_cell_parameters_with_wrong_type(): | |||
| load_net(input_a, input_b) | |||
| assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value) | |||
| remove_generated_file(mindir_name) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_init_graph_cell_parameters_with_wrong_shape(): | |||
| """ | |||
| Description: load mind ir and update parameters with wrong tensor shape. | |||
| Expectation: raise a ValueError indicating the tensor shape error. | |||
| """ | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| net = Net() | |||
| mindir_name = "net_2.mindir" | |||
| export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') | |||
| @@ -120,16 +91,15 @@ def test_init_graph_cell_parameters_with_wrong_shape(): | |||
| load_net(input_a, input_b) | |||
| assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) | |||
| remove_generated_file(mindir_name) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_init_graph_cell_parameters_with_wrong_dtype(): | |||
| """ | |||
| Description: load mind ir and update parameters with wrong tensor dtype. | |||
| Expectation: raise a ValueError indicating the tensor dtype error. | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Net() | |||
| mindir_name = "net_3.mindir" | |||
| export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') | |||
| @@ -141,3 +111,4 @@ def test_init_graph_cell_parameters_with_wrong_dtype(): | |||
| load_net(input_a, input_b) | |||
| assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) | |||
| remove_generated_file(mindir_name) | |||