
fix mixed precision operator issue

tags/v0.7.0-beta
kingfo 5 years ago
parent
commit 73ea9b7855
7 changed files with 40 additions and 13 deletions:
  1. mindspore/_extends/builtin_operations.py (+10, -0)
  2. mindspore/ccsrc/pipeline/pynative/base.h (+1, -1)
  3. mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+8, -2)
  4. mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+1, -0)
  5. mindspore/nn/cell.py (+12, -10)
  6. mindspore/nn/wrap/loss_scale.py (+1, -0)
  7. tests/ut/python/ops/test_control_ops.py (+7, -0)

mindspore/_extends/builtin_operations.py (+10, -0)

@@ -14,6 +14,7 @@
 # ============================================================================
 """builtin_operations"""
 import numpy as np
+from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype

@@ -171,3 +172,12 @@ def tuple_to_array(x):
 def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
+
+def mixed_precision_cast(dst_type, x):
+    """Implement `mixed_precision_cast`."""
+    if isinstance(x, tuple):
+        res = list()
+        for item in x:
+            res.append(F.cast(item, dst_type))
+        return tuple(res)
+    return F.cast(x, dst_type)
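This Python fallback appears to be what evaluates `mixed_precision_cast` when the primitive is run in Python rather than as a compiled op (it is also added to the vm/ignore-infer sets below): a single tensor is cast directly, while a tuple is cast element-wise and returned as a tuple. A minimal sketch of calling the fallback directly, assuming an environment where F.cast executes eagerly (e.g. PyNative mode); the variable names are illustrative only:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore._extends.builtin_operations import mixed_precision_cast

x = Tensor(np.ones([2, 3], dtype=np.float32))

# Single tensor: cast straight to the destination dtype.
y = mixed_precision_cast(mstype.float16, x)   # y.dtype == mstype.float16

# Tuple input: every element is cast and a tuple comes back.
pair = mixed_precision_cast(mstype.float16, (x, x))
assert isinstance(pair, tuple) and all(t.dtype == mstype.float16 for t in pair)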

mindspore/ccsrc/pipeline/pynative/base.h (+1, -1)

@@ -61,7 +61,7 @@ struct OpExecInfo {
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
 
-const std::set<std::string> ignore_infer_prim = {"make_ref"};
+const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 } // namespace pynative
 } // namespace mindspore



mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+8, -2)

@@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;
 
 const char SINGLE_OP_GRAPH[] = "single_op_graph";
 // primitive unable to infer value for constant input in PyNative mode
-const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
+const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
 
 namespace mindspore {
 namespace pynative {
@@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   if (cell_graph_map_.count(cell_id) != 0) {
+    if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
+      resource_ = cell_resource_map_[cell_id];
+    }
     MS_LOG(DEBUG) << "Newgraph already compiled";
     return;
   }
@@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
 
   if (top_g_ == nullptr) {
     top_g_ = curr_g_ = g;
+    resource_ = std::make_shared<pipeline::Resource>();
+    cell_resource_map_[cell_id] = resource_;
     df_builder_ = std::make_shared<FuncGraph>();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     Pushp();
@@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
    MS_LOG(DEBUG) << "Clear res";
    (void)graph_map_.erase(flag);
    (void)cell_graph_map_.erase(flag);
+   (void)cell_resource_map_.erase(flag);
    Clean();
    // Maybe exit in the pynative running op, so need to reset the pynative flag.
    auto ms_context = MsContext::GetInstance();
@@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
 
   MS_LOG(DEBUG) << "Clear";
   top_g_ = nullptr;
+  df_builder_ = nullptr;
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
@@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
   Clear();
   grad_flag_ = false;
   op_forward_map_.clear();
-  df_builder_ = nullptr;
   ad::CleanRes();
   pipeline::ReclaimOptimizer();
 }


mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+1, -0)

@@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool grad_flag_;
   std::unordered_map<std::string, FuncGraphPtr> graph_map_;
   std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
+  std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;


mindspore/nn/cell.py (+12, -10)

@@ -240,12 +240,13 @@ class Cell:
         else:
             _pynative_exec.set_grad_flag(False)
         cast_inputs = list()
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float16))
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float32))
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float16))
+            if self._mindspore_flags.get('fp32'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float32))
         if cast_inputs:
             cast_inputs = tuple(cast_inputs)
         else:
@@ -496,10 +497,11 @@
         Args:
             param (Parameter): The parameter to cast.
         """
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            return cast(param, mstype.float16)
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            return cast(param, mstype.float32)
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                return cast(param, mstype.float16)
+            if self._mindspore_flags.get('fp32'):
+                return cast(param, mstype.float32)
         return param
 
     def insert_child_to_cell(self, child_name, child):
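Both cell.py hunks fold the duplicated hasattr(self, "_mindspore_flags") checks into a single guard before the fp16/fp32 branches. These flags are normally set through mixed-precision helpers such as Cell.to_float; a rough usage sketch under that assumption, with PYNATIVE_MODE chosen because this is the PyNative __call__ path (the layer and shapes are placeholders, not taken from this commit):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)

# to_float(float16) sets the fp16 flag that __call__ and cast_param consult.
net = nn.Dense(3, 2).to_float(mstype.float16)

x = Tensor(np.ones([4, 3], dtype=np.float32))
out = net(x)   # inputs (and parameters, via cast_param) are cast to float16 once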


mindspore/nn/wrap/loss_scale.py (+1, -0)

@@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
+        self.network.set_grad()
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
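The added set_grad() call flags the wrapped network for gradient computation inside the wrapper's constructor. A construction-only sketch of where this wrapper sits, with placeholder backbone/loss/optimizer choices that are not part of this commit:

import mindspore.nn as nn
from mindspore.nn import TrainOneStepWithLossScaleCell, WithLossCell

backbone = nn.Dense(16, 10)                       # placeholder network
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
net_with_loss = WithLossCell(backbone, loss_fn)
opt = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)

# After this change the constructor also calls net_with_loss.set_grad(),
# so gradients are enabled on the wrapped cell without extra user code.
train_cell = TrainOneStepWithLossScaleCell(net_with_loss, opt)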


tests/ut/python/ops/test_control_ops.py (+7, -0)

@@ -20,6 +20,7 @@ import mindspore as ms
 from mindspore import Tensor
 from mindspore import context
 from mindspore import nn
+from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break():
     t = Tensor(np.ones([2, 3], dtype=np.float32))
     net = Net()
     net(t)
+
+
+def test_mixed_precision_cast():
+    x = Tensor(np.ones([2, 3], dtype=np.float32))
+    z = F.mixed_precision_cast(mstype.float16, x)
+    assert z.dtype == mstype.float16
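The new test covers the single-tensor path; a tuple-input variant, shown here only to illustrate the tuple branch added in builtin_operations.py (it is not part of this commit, and assumes a tuple argument reaches that fallback unchanged), could look like:

def test_mixed_precision_cast_tuple():
    x = Tensor(np.ones([2, 3], dtype=np.float32))
    y = Tensor(np.zeros([2], dtype=np.float32))
    out = F.mixed_precision_cast(mstype.float16, (x, y))
    assert isinstance(out, tuple)
    assert all(item.dtype == mstype.float16 for item in out)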
