
!30903 Fix Ckpt Shape Compile Error

Merge pull request !30903 from huangxinjing/fix_compile_shape_error into r1.7
i-robot committed 4 years ago
commit cb309af253
4 changed files with 42 additions and 37 deletions:

  1. mindspore/ccsrc/pybind_api/ir/param_info_py.cc (+32, -33)
  2. mindspore/ccsrc/utils/parallel_context.cc (+5, -4)
  3. mindspore/core/ir/param_info.h (+4, -0)
  4. mindspore/python/mindspore/common/parameter.py (+1, -0)

mindspore/ccsrc/pybind_api/ir/param_info_py.cc (+32, -33)

@@ -20,37 +20,36 @@
 namespace mindspore {
 namespace py = pybind11;
 
-REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
-                         (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
-                           .def(py::init())
-                           .def("clone", &ParamInfo::Clone)
-                           .def_property("name", &ParamInfo::name, &ParamInfo::set_name)
-                           .def_property("requires_grad", &ParamInfo::requires_grad, &ParamInfo::set_requires_grad)
-                           .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
-                           .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
-                                         &ParamInfo::set_layerwise_parallel)
-                           .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
-                                         &ParamInfo::set_parallel_optimizer)
-                           .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
-                           .def_property("parallel_optimizer_comm_recompute",
-                                         &ParamInfo::parallel_optimizer_comm_recompute,
-                                         &ParamInfo::set_parallel_optimizer_comm_recompute)
-                           .def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
-                           .def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
-                           .def_property("requires_aggr", &ParamInfo::requires_aggr, &ParamInfo::set_requires_aggr)
-                           .def(py::pickle(
-                             [](const ParamInfo &p) {  // __getstate__
-                               return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
-                             },
-                             [](const py::tuple &t) {  // __setstate__
-                               if (t.size() != 6) {
-                                 std::runtime_error("Invalid state for ParamInfo!");
-                               }
-                               ParamInfoPtr p = std::make_shared<ParamInfo>();
-                               p->set_name(t[1].cast<std::string>());
-                               p->set_requires_grad(t[2].cast<bool>());
-                               p->set_layerwise_parallel(t[3].cast<bool>());
-                               return p;
-                             }));
-                       }));
+REGISTER_PYBIND_DEFINE(
+  ParamInfo, ([](const py::module *m) {
+    (void)py::class_<ParamInfo, ParamInfoPtr>(*m, "ParamInfo")
+      .def(py::init())
+      .def("clone", &ParamInfo::Clone)
+      .def_property("name", &ParamInfo::name, &ParamInfo::set_name)
+      .def_property("requires_grad", &ParamInfo::requires_grad, &ParamInfo::set_requires_grad)
+      .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
+      .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel, &ParamInfo::set_layerwise_parallel)
+      .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer, &ParamInfo::set_parallel_optimizer)
+      .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
+      .def_property("parallel_optimizer_comm_recompute", &ParamInfo::parallel_optimizer_comm_recompute,
+                    &ParamInfo::set_parallel_optimizer_comm_recompute)
+      .def_property("parameter_shape", &ParamInfo::parameter_shape, &ParamInfo::set_parameter_shape)
+      .def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
+      .def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
+      .def_property("requires_aggr", &ParamInfo::requires_aggr, &ParamInfo::set_requires_aggr)
+      .def(py::pickle(
+        [](const ParamInfo &p) {  // __getstate__
+          return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
+        },
+        [](const py::tuple &t) {  // __setstate__
+          if (t.size() != 6) {
+            std::runtime_error("Invalid state for ParamInfo!");
+          }
+          ParamInfoPtr p = std::make_shared<ParamInfo>();
+          p->set_name(t[1].cast<std::string>());
+          p->set_requires_grad(t[2].cast<bool>());
+          p->set_layerwise_parallel(t[3].cast<bool>());
+          return p;
+        }));
+  }));
 }  // namespace mindspore
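
This hunk is mostly a clang-format reflow of the existing bindings; the functional change is the new parameter_shape property, wired to ParamInfo::parameter_shape and ParamInfo::set_parameter_shape. A minimal sketch of how the property behaves from Python, assuming a build that includes this change (pybind11 converts the underlying std::vector<int64_t> to and from a Python list):

    import numpy as np
    from mindspore import Parameter, Tensor

    p = Parameter(Tensor(np.ones((4, 8), np.float32)), name="w")
    p.param_info.parameter_shape = [4, 8]  # setter, bound to ParamInfo::set_parameter_shape
    print(p.param_info.parameter_shape)    # [4, 8], read back via ParamInfo::parameter_shape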

mindspore/ccsrc/utils/parallel_context.cc (+5, -4)

@@ -270,12 +270,13 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
   if (init_param_shape_) {
     return;
   }
-  auto iter = param_shapes.find(param_node->name());
-  if (iter == param_shapes.end()) {
-    MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
+  auto param_info = param_node->param_info();
+  if (!param_info) return;
+  auto shape = param_info->parameter_shape();
+  if (shape.empty()) {
+    MS_LOG(WARNING) << "The parameter " << param_node->name() << "'s parameter_shape in param_info is empty";
     return;
   }
-  auto shape = iter->second;
   std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
   ptr->set_shape(base_shape);
   MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
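
This is the hunk that actually changes behavior: instead of looking the shape up in the param_shapes map (a lookup that could miss and leave the shape unrestored), the restore path now reads the shape persisted in the parameter's own param_info. A Python paraphrase of the new control flow, purely illustrative since the real implementation is the C++ above:

    def restore_shape(param_node, ptr, init_param_shape):
        """Illustrative mirror of ParallelParameterContextRestoreShape after this change."""
        if init_param_shape:
            return
        param_info = param_node.param_info
        if param_info is None:   # corresponds to `if (!param_info) return;`
            return
        shape = param_info.parameter_shape
        if not shape:            # corresponds to the shape.empty() warning branch
            return
        ptr.set_shape(shape)     # the C++ wraps the vector in abstract::Shape first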


mindspore/core/ir/param_info.h (+4, -0)

@@ -88,6 +88,9 @@ class ParamInfo {
     parallel_optimizer_comm_recompute_ = parallel_optimizer_comm_recompute;
   }
 
+  std::vector<int64_t> parameter_shape() const { return parameter_shape_; }
+  void set_parameter_shape(std::vector<int64_t> tensor_shape) { parameter_shape_ = tensor_shape; }
+
   bool cache_enable() const { return cache_enable_; }
   void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }

@@ -116,6 +119,7 @@ class ParamInfo {
   std::vector<int64_t> cache_shape_;
   ParameterPtr parameter_{nullptr};
   bool requires_aggr_{true};
+  std::vector<int64_t> parameter_shape_;
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_
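
The header change supplies the storage: a parameter_shape_ member with an inline getter/setter pair. The vector is default-constructed, so a ParamInfo whose shape was never recorded reports an empty list, which is exactly the condition the restore path above warns about. A small sketch of that default, assuming ParamInfo is importable from mindspore._c_expression (the module REGISTER_PYBIND_DEFINE registers into; treat the import path as an assumption):

    from mindspore._c_expression import ParamInfo  # assumed import path

    info = ParamInfo()
    assert info.parameter_shape == []   # default: no shape recorded yet
    info.parameter_shape = [128, 1024]  # stored as std::vector<int64_t>
    assert info.parameter_shape == [128, 1024]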

mindspore/python/mindspore/common/parameter.py (+1, -0)

@@ -208,6 +208,7 @@ class Parameter(Tensor_):
         else:
             raise TypeError(f"The type of the argument 'default_input' must be in ['Tensor', 'int', 'float',"
                             f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
+        self.param_info.parameter_shape = self.shape
 
     def __deepcopy__(self, memodict):
         new_obj = Parameter(self)
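
With this one-line addition, every Parameter records its shape into param_info at construction time, which is what lets the C++ restore path above work without the param_shapes map. For example (shape and name are illustrative):

    import numpy as np
    from mindspore import Parameter, Tensor

    w = Parameter(Tensor(np.zeros((2, 3), np.float32)), name="w")
    print(w.param_info.parameter_shape)  # expected: [2, 3]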

