
fix mixed precision operator issue

tags/v0.7.0-beta
kingfo 5 years ago
commit 73ea9b7855
7 changed files with 40 additions and 13 deletions
  1. mindspore/_extends/builtin_operations.py (+10, -0)
  2. mindspore/ccsrc/pipeline/pynative/base.h (+1, -1)
  3. mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+8, -2)
  4. mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+1, -0)
  5. mindspore/nn/cell.py (+12, -10)
  6. mindspore/nn/wrap/loss_scale.py (+1, -0)
  7. tests/ut/python/ops/test_control_ops.py (+7, -0)

mindspore/_extends/builtin_operations.py (+10, -0)

@@ -14,6 +14,7 @@
 # ============================================================================
 """builtin_operations"""
 import numpy as np
+from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype


@@ -171,3 +172,12 @@ def tuple_to_array(x):
 def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
+
+def mixed_precision_cast(dst_type, x):
+    """Implement `mixed_precision_cast`."""
+    if isinstance(x, tuple):
+        res = list()
+        for item in x:
+            res.append(F.cast(item, dst_type))
+        return tuple(res)
+    return F.cast(x, dst_type)
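
A minimal usage sketch of the new fallback, mirroring the unit test added below; the tuple call is an assumption based on the isinstance(x, tuple) branch above, not something the commit tests directly:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

x = Tensor(np.ones([2, 3], dtype=np.float32))

# A single tensor is cast directly to the destination type.
y = F.mixed_precision_cast(mstype.float16, x)
assert y.dtype == mstype.float16

# A tuple of tensors is cast element-wise and returned as a tuple
# (assumed from the tuple branch in builtin_operations above).
y1, y2 = F.mixed_precision_cast(mstype.float16, (x, x))
assert y1.dtype == mstype.float16 and y2.dtype == mstype.float16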

mindspore/ccsrc/pipeline/pynative/base.h (+1, -1)

@@ -61,7 +61,7 @@ struct OpExecInfo {
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);

-const std::set<std::string> ignore_infer_prim = {"make_ref"};
+const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 }  // namespace pynative
 }  // namespace mindspore




mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+8, -2)

@@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;


 const char SINGLE_OP_GRAPH[] = "single_op_graph";
 // primitive unable to infer value for constant input in PyNative mode
-const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
+const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};

 namespace mindspore {
 namespace pynative {
@@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   if (cell_graph_map_.count(cell_id) != 0) {
+    if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
+      resource_ = cell_resource_map_[cell_id];
+    }
     MS_LOG(DEBUG) << "Newgraph already compiled";
     return;
   }
@@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg


   if (top_g_ == nullptr) {
     top_g_ = curr_g_ = g;
+    resource_ = std::make_shared<pipeline::Resource>();
+    cell_resource_map_[cell_id] = resource_;
     df_builder_ = std::make_shared<FuncGraph>();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     Pushp();
@@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     MS_LOG(DEBUG) << "Clear res";
     (void)graph_map_.erase(flag);
     (void)cell_graph_map_.erase(flag);
+    (void)cell_resource_map_.erase(flag);
     Clean();
     // Maybe exit in the pynative runing op, so need reset pynative flag.
     auto ms_context = MsContext::GetInstance();
@@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {


   MS_LOG(DEBUG) << "Clear";
   top_g_ = nullptr;
+  df_builder_ = nullptr;
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
@@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
   Clear();
   grad_flag_ = false;
   op_forward_map_.clear();
-  df_builder_ = nullptr;
   ad::CleanRes();
   pipeline::ReclaimOptimizer();
 }


mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+1, -0)

@@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool grad_flag_;
   std::unordered_map<std::string, FuncGraphPtr> graph_map_;
   std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
+  std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;


mindspore/nn/cell.py (+12, -10)

@@ -240,12 +240,13 @@ class Cell:
         else:
             _pynative_exec.set_grad_flag(False)
         cast_inputs = list()
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float16))
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float32))
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float16))
+            if self._mindspore_flags.get('fp32'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float32))
         if cast_inputs:
             cast_inputs = tuple(cast_inputs)
         else:
@@ -496,10 +497,11 @@ class Cell:
         Args:
             param (Parameter): The parameter to cast.
         """
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            return cast(param, mstype.float16)
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            return cast(param, mstype.float32)
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                return cast(param, mstype.float16)
+            if self._mindspore_flags.get('fp32'):
+                return cast(param, mstype.float32)
         return param

     def insert_child_to_cell(self, child_name, child):
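
A brief sketch of how the restructured flag checks are exercised: the 'fp16' flag is typically set through Cell.to_float, and cells without _mindspore_flags now skip both branches after a single hasattr test. nn.Dense is only an arbitrary example cell, assumed here for illustration:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype

# to_float(mstype.float16) sets the 'fp16' flag read in __call__ and
# cast_param above, so float inputs and parameters are cast on entry.
net = nn.Dense(3, 4).to_float(mstype.float16)
out = net(Tensor(np.ones([2, 3], dtype=np.float32)))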


mindspore/nn/wrap/loss_scale.py (+1, -0)

@@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
+        self.network.set_grad()
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
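
For context, a hedged sketch of the wrapper this one-line change affects: set_grad() is now called inside __init__, so the wrapped network is marked for gradient construction by the cell itself rather than by the caller. TinyNet and the optimizer settings are hypothetical stand-ins:

import mindspore.nn as nn

class TinyNet(nn.Cell):
    """Hypothetical stand-in network, used only for illustration."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.dense = nn.Dense(3, 2)

    def construct(self, x):
        return self.dense(x)

loss_net = nn.WithLossCell(TinyNet(), nn.SoftmaxCrossEntropyWithLogits())
opt = nn.Momentum(loss_net.trainable_params(), learning_rate=0.01, momentum=0.9)

# loss_net.set_grad()  # previously the caller's responsibility in PyNative mode
train_net = nn.TrainOneStepWithLossScaleCell(loss_net, opt)  # now calls set_grad() itself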


tests/ut/python/ops/test_control_ops.py (+7, -0)

@@ -20,6 +20,7 @@ import mindspore as ms
 from mindspore import Tensor
 from mindspore import context
 from mindspore import nn
+from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break():
     t = Tensor(np.ones([2, 3], dtype=np.float32))
     net = Net()
     net(t)
+
+
+def test_mixed_precision_cast():
+    x = Tensor(np.ones([2, 3], dtype=np.float32))
+    z = F.mixed_precision_cast(mstype.float16, x)
+    assert z.dtype == mstype.float16
