浏览代码

fix bug in parameter set & fix code style in pynative_executa.cc

tags/v1.0.0
Wei Luning 5 年前
父节点
当前提交
cdbd16de0c
共有 6 个文件被更改,包括 20 次插入8 次删除
  1. +12
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +1
    -1
      mindspore/ccsrc/pybind_api/ir/tensor_py.cc
  3. +0
    -1
      mindspore/common/api.py
  4. +1
    -1
      mindspore/core/ir/tensor.h
  5. +5
    -2
      mindspore/nn/cell.py
  6. +1
    -1
      tests/ut/python/ops/test_math_ops.py

+ 12
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc 查看文件

@@ -288,6 +288,7 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
} }


bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) { bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
MS_EXCEPTION_IF_NULL(dtypes);
auto signature = prim->signatures(); auto signature = prim->signatures();
bool has_sig_dtype = false; bool has_sig_dtype = false;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes), (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
@@ -733,20 +734,29 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {


AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) { abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
CNodePtr cnode = nullptr; CNodePtr cnode = nullptr;
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;

auto prim = op_exec_info->py_primitive; auto prim = op_exec_info->py_primitive;
const auto &signature = prim->signatures();

inputs.push_back(NewValueNode(prim)); inputs.push_back(NewValueNode(prim));


size_t size = op_exec_info->op_inputs.size(); size_t size = op_exec_info->op_inputs.size();
auto sig_size = signature.size();
// ignore signature for cast op // ignore signature for cast op
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
bool is_cast_op = (op_exec_info->op_name == "Cast"); bool is_cast_op = (op_exec_info->op_name == "Cast");
if (!is_cast_op) { if (!is_cast_op) {
const auto &signature = prim->signatures();
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i]; auto obj = op_exec_info->op_inputs[i];
auto sig = SignatureEnumRW::kRWDefault; auto sig = SignatureEnumRW::kRWDefault;
if (signature.size() > 0) {
if (sig_size > 0) {
sig = signature[i].rw; sig = signature[i].rw;
} }
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " " MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " "


+ 1
- 1
mindspore/ccsrc/pybind_api/ir/tensor_py.cc 查看文件

@@ -455,7 +455,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32) >>> data.set_dtype(mindspore.int32)
mindspore.int32 mindspore.int32
)mydelimiter") )mydelimiter")
.def("set_cast_dtype", &Tensor::set_cast_dtype)
.def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr)
.def("__str__", &Tensor::ToString) .def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr) .def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle( .def(py::pickle(


+ 0
- 1
mindspore/common/api.py 查看文件

@@ -292,7 +292,6 @@ class _PynativeExecutor:
def __init__(self): def __init__(self):
self._executor = PynativeExecutor_.get_instance() self._executor = PynativeExecutor_.get_instance()


#TODO(kpy):add a type arg
def new_graph(self, obj, *args, **kwargs): def new_graph(self, obj, *args, **kwargs):
self._executor.new_graph(obj, *args, *(kwargs.values())) self._executor.new_graph(obj, *args, *(kwargs.values()))




+ 1
- 1
mindspore/core/ir/tensor.h 查看文件

@@ -269,7 +269,7 @@ class Tensor : public MetaTensor {


std::string id() const { return id_; } std::string id() const { return id_; }
TypePtr cast_dtype() { return cast_dtype_; } TypePtr cast_dtype() { return cast_dtype_; }
void set_cast_dtype(TypePtr dtype) { cast_dtype_ = dtype; }
void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; }


void SetNeedWait(bool need_wait) { void SetNeedWait(bool need_wait) {
if (event_ != nullptr) { if (event_ != nullptr) {


+ 5
- 2
mindspore/nn/cell.py 查看文件

@@ -582,10 +582,13 @@ class Cell(Cell_):
param (Parameter): The parameter to cast. param (Parameter): The parameter to cast.
""" """
if hasattr(self, "_mindspore_flags"): if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
param.set_cast_dtype(mstype.float16)
if self._mindspore_flags.get('fp32'): if self._mindspore_flags.get('fp32'):
param.set_cast_dtype(mstype.float32) param.set_cast_dtype(mstype.float32)
elif self._mindspore_flags.get('fp16'):
param.set_cast_dtype(mstype.float16)
else:
# retest dtype
param.set_cast_dtype()
return param return param


def insert_child_to_cell(self, child_name, child): def insert_child_to_cell(self, child_name, child):


+ 1
- 1
tests/ut/python/ops/test_math_ops.py 查看文件

@@ -464,7 +464,7 @@ raise_set = [
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}), 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
'desc_inputs': [0]}), 'desc_inputs': [0]}),
('AssignAdd_Error', { ('AssignAdd_Error', {
'block': (P.AssignAdd(), {'exception': IndexError}),
'block': (P.AssignAdd(), {'exception': ValueError}),
'desc_inputs': [[1]]}), 'desc_inputs': [[1]]}),
] ]




正在加载...
取消
保存