Browse Source

!3436 fix mix precision operator issue

Merge pull request !3436 from wangqiuliang/fix-mix-precision-r0.6
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
c5e6cfebe7
6 changed files with 29 additions and 3 deletions
  1. +11
    -0
      mindspore/_extends/builtin_operations.py
  2. +1
    -1
      mindspore/ccsrc/pipeline/pynative/base.h
  3. +8
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  4. +1
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h
  5. +1
    -0
      mindspore/nn/wrap/loss_scale.py
  6. +7
    -0
      tests/ut/python/ops/test_control_ops.py

+ 11
- 0
mindspore/_extends/builtin_operations.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""builtin_operations"""
import numpy as np
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype

@@ -171,3 +172,13 @@ def tuple_to_array(x):
def stop_gradient(x):
    """Fallback implementation of `stop_gradient`.

    In this interpreted (VM) path there is no autograd tape to cut, so the
    operation is a pure pass-through: the input is handed back untouched.
    """
    result = x
    return result


def mixed_precision_cast(dst_type, x):
    """Implement `mixed_precision_cast`.

    Cast `x` to `dst_type`, element-wise when `x` is a tuple of tensors.

    Args:
        dst_type: Target mindspore dtype to cast to (e.g. float16/float32).
        x: A tensor, or a tuple of tensors, to be cast.

    Returns:
        The cast tensor, or a tuple of cast tensors mirroring the input
        structure.
    """
    if isinstance(x, tuple):
        # Preserve tuple structure; cast each element independently.
        return tuple(F.cast(item, dst_type) for item in x)
    return F.cast(x, dst_type)

+ 1
- 1
mindspore/ccsrc/pipeline/pynative/base.h View File

@@ -59,7 +59,7 @@ struct OpExecInfo {
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);

const std::set<std::string> ignore_infer_prim = {"make_ref"};
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
} // namespace pynative
} // namespace mindspore



+ 8
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;

const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};

namespace mindspore {
namespace pynative {
@@ -766,6 +766,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) {
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
resource_ = cell_resource_map_[cell_id];
}
MS_LOG(DEBUG) << "Newgraph already compiled";
return;
}
@@ -774,6 +777,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg

if (top_g_ == nullptr) {
top_g_ = curr_g_ = g;
resource_ = std::make_shared<pipeline::Resource>();
cell_resource_map_[cell_id] = resource_;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
Pushp();
@@ -1075,6 +1080,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
MS_LOG(INFO) << "Clear res";
(void)graph_map_.erase(flag);
(void)cell_graph_map_.erase(flag);
(void)cell_resource_map_.erase(flag);
Clean();
// Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance();
@@ -1086,6 +1092,7 @@ void PynativeExecutor::Clear(const std::string &flag) {

MS_LOG(INFO) << "Clear";
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
graph_info_map_.clear();
std::stack<FuncGraphPtr>().swap(graph_p_);
@@ -1095,7 +1102,6 @@ void PynativeExecutor::Clean() {
MS_LOG(INFO) << "Clean all res";
Clear();
grad_flag_ = false;
df_builder_ = nullptr;
ad::CleanRes();
pipeline::ReclaimOptimizer();
}


+ 1
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -115,6 +115,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool grad_flag_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::stack<FuncGraphPtr> graph_p_;
FuncGraphPtr top_g_;


+ 1
- 0
mindspore/nn/wrap/loss_scale.py View File

@@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer


+ 7
- 0
tests/ut/python/ops/test_control_ops.py View File

@@ -20,6 +20,7 @@ import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
@@ -639,3 +640,9 @@ def test_large_for_loop_with_continue_break():
t = Tensor(np.ones([2, 3], dtype=np.float32))
net = Net()
net(t)


def test_mixed_precision_cast():
    """Check that `F.mixed_precision_cast` converts a float32 tensor to float16."""
    input_tensor = Tensor(np.ones([2, 3], dtype=np.float32))
    casted = F.mixed_precision_cast(mstype.float16, input_tensor)
    assert casted.dtype == mstype.float16

Loading…
Cancel
Save