
!11138 Optimize pynative dynamic grad graph

From: @zjun3021
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago · parent commit 4ed3491952
6 changed files with 573 additions and 296 deletions:

  1. mindspore/ccsrc/pipeline/pynative/pynative_execute.cc  (+466 −228)
  2. mindspore/ccsrc/pipeline/pynative/pynative_execute.h   (+67 −36)
  3. mindspore/common/api.py                                (+3 −0)
  4. mindspore/core/utils/ordered_map.h                     (+26 −0)
  5. mindspore/nn/cell.py                                   (+1 −15)
  6. mindspore/ops/composite/base.py                        (+10 −17)

+466 −228  mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
File diff suppressed because it is too large


+67 −36  mindspore/ccsrc/pipeline/pynative/pynative_execute.h

@@ -69,32 +69,47 @@ struct GraphInfo {
   explicit GraphInfo(std::string id) : cell_id(std::move((id))) {}
 };
 
-struct CellInfo {
-  bool is_grad{false};          // Derivative is calculated
-  bool is_custom_bprop{false};  // Custom bprop
-  FuncGraphPtr fg;              // Forward graph
-  std::string cell_id;
-  std::string bprop_cell_id;
-  CellInfo() = default;
-  CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
-      : is_grad(isgrad),
-        is_custom_bprop(custom_bprop),
-        fg(std::move(foward_graph)),
-        cell_id(std::move(cellid)),
-        bprop_cell_id(std::move(bprop_id)) {}
-};
-
-struct TopCellInfo {
-  ResourcePtr resource;
-  FuncGraphPtr df_builder;
-  FuncGraphPtr bg;  // Backward graph
-  std::string cell_id;
-  bool is_dynamic_cell{false};
-
-  TopCellInfo() = default;
-  TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid)
-      : resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {}
-};
+class CellInfo {
+ public:
+  CellInfo() = default;
+  CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
+      : is_custom_bprop(custom_bprop),
+        is_dynamic(has_dynamic),
+        fg(std::move(foward_graph)),
+        cell_id(std::move(cellid)),
+        bprop_cell_id(std::move(bprop_id)) {}
+
+  bool is_grad{false};          // Derivative is calculated
+  bool is_custom_bprop{false};  // Custom bprop
+  bool is_dynamic{false};       // Set by has_dynamic_cell
+  bool is_real_dynamic{false};  // Set by ops order
+  size_t call_times{0};
+  FuncGraphPtr fg{nullptr};     // Forward graph
+  std::string cell_id;
+  std::string bprop_cell_id;
+  std::vector<std::string> cell_ops_info;  // All ops info
+};
+
+class TopCellInfo {
+ public:
+  TopCellInfo() = default;
+  TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
+      : is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {}
+
+  bool is_topest{false};
+  bool do_vm_compiled{false};
+  ResourcePtr resource{nullptr};
+  FuncGraphPtr df_builder{nullptr};
+  FuncGraphPtr bg{nullptr};  // Backward graph
+  std::string cell_id;
+  std::string sens_id;
+  std::string weights_id;
+};
+
+using GraphInfoPtr = std::shared_ptr<GraphInfo>;
+using CellInfoPtr = std::shared_ptr<CellInfo>;
+using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
 
 class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  public:
   static std::shared_ptr<PynativeExecutor> GetInstance() {
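The two records change from plain structs to classes held by shared_ptr (see the new CellInfoPtr/TopCellInfoPtr aliases), the dynamic-shape flags move onto the per-cell record, and sens_id/weights_id apparently replace the old cell_sw_map_ cache removed further down. For intuition, a rough Python mirror of the new bookkeeping — the field comments come from the header, everything else is a sketch, not generated bindings:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class CellInfo:                    # rough mirror of the new C++ CellInfo
        is_grad: bool = False          # derivative is calculated
        is_custom_bprop: bool = False  # custom bprop
        is_dynamic: bool = False       # set by has_dynamic_cell
        is_real_dynamic: bool = False  # set by ops order
        call_times: int = 0
        cell_id: str = ""
        bprop_cell_id: str = ""
        cell_ops_info: List[str] = field(default_factory=list)  # all ops info

    @dataclass
    class TopCellInfo:                 # rough mirror of the new C++ TopCellInfo
        is_topest: bool = False
        do_vm_compiled: bool = False
        cell_id: str = ""
        sens_id: str = ""              # cache id for the sens argument
        weights_id: str = ""           # cache id for the weights argument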
@@ -119,11 +134,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void NewGraph(const py::object &cell, const py::args &args);
   py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
   py::object CheckGraph(const py::object &cell, const py::args &args);
+  py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
   void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
 
   // Get info
-  bool GetIsDynamicCell() const { return dynamic_cell_; }
+  bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); }
   // Call by python
   void Clear(const std::string &flag = "");
   void Clean();
@@ -149,7 +165,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   template <typename T>
   void VectorClear(T *vec, const std::string &cell_id) {
     for (auto it = vec->begin(); it != vec->end();) {
-      if (it->cell_id.find(cell_id) != std::string::npos) {
+      if ((*it)->cell_id.find(cell_id) != std::string::npos) {
         it = vec->erase(it);
       } else {
         it++;
@@ -201,29 +217,39 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
   void SaveTensorsInValueNode(const ResourcePtr &resource);
   void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
-  void CleanPreMemoryInValueNode(const std::string &cell_id);
+  void CleanPreMemoryInValueNode();
 
   // Construct grad graph
   void PushCurrentGraphToStack();
   void PopGraphStack();
+  void PushCurrentCellOpInfoToStack();
+  void PopCurrentCellOpInfoFromStack();
   FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
   ResourcePtr GetResource(const std::string &cell_id = "");
   void AddNestedGradOrder() { ++grad_order_; }
   void SubNestedGradOrder();
-  bool IsNotNestedGrad() const;
+  bool IsNestedGrad() const;
   bool IsTopGraph(const std::string &cell_id);
+  bool IsTopestGraph(const std::string &cell_id);
   bool IsBpropGraph(const std::string &cell_id);
+  bool IsFirstGradStep(const std::string &cell_id);
   bool grad_running() const { return grad_is_running_; }
   void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
   void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
   bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; }
   bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
+  bool CheckDynamicCell(const std::string &cell_id);
+  bool CheckRealDynamicCell(const std::string &cell_id);
   void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                        bool need_cloned = false, bool is_grad = false);
+  void ClearCnodeRes(const AnfNodePtr &node);
+  void UpdateCellDynamic(const std::string &cell_id);
+  bool CheckCellChanged(const std::string &cell_id);
+  void UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled);
   void ClearResidualRes(const std::string &cell_id);
   void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
   void NewGraphInner(const py::object &cell, const py::args &args);
-  void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
+  void MakeNewTopGraph(const string &cell_id, const py::args &args);
   void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
   void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                        const std::string &out_id, const py::args &args);
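Taken together with CellInfo::cell_ops_info and the new cell_op_info_stack_ member below, the CheckDynamicCell/CheckRealDynamicCell pair suggests the heart of the optimization: a cell is only treated as really dynamic when the op sequence it executes actually changes between runs, so a structurally stable cell can keep reusing its grad graph. A guess at that mechanism in Python form — purely illustrative, with all names assumed from the declarations above:

    # Illustrative sketch: a cell counts as "real dynamic" only when the op
    # sequence recorded for this run differs from the previous run's sequence.
    recorded_ops = {}  # cell_id -> list of op-info strings (cf. cell_ops_info)

    def check_real_dynamic_cell(cell_id, current_ops):
        previous = recorded_ops.get(cell_id)
        recorded_ops[cell_id] = list(current_ops)
        if previous is None:
            return False          # first run: nothing to compare against yet
        return previous != current_ops  # op order changed -> really dynamic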
@@ -232,38 +258,44 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
const std::string &cell_id, const py::args &args); const std::string &cell_id, const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
py::object *sens = nullptr); py::object *sens = nullptr);
void ClearDynamicTopRes(const std::string &cell_id);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args); const py::args &args);
std::string GetCellId(const py::object &obj, const py::args &args); std::string GetCellId(const py::object &obj, const py::args &args);
std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
std::string GetTensorCellId(const std::string &cell_id);
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size); void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights, void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size, const std::string &cell_id); size_t arg_size, const std::string &cell_id);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder); abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
void UpdateGraphInfoMap(const std::string &cell_id);
void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
const std::string &cell_id);
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens); const py::object &out, bool has_sens);
void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);


// Hold graph(forward and grad) info // Hold graph(forward and grad) info
std::string GetCellOpInfo();
void ReplaceCellOpInfoByCellId(const std::string &cell_id);
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) { void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
graph_info_map_[g].objects.push_back(obj);
graph_info_map_[g]->objects.push_back(obj);
} }
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false); bool is_param = false);
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) { void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
graph_info_map_[g].params[id] = param;
graph_info_map_[g]->params[id] = param;
} }
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
int64_t index = -1) { int64_t index = -1) {
graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
} }
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) { const std::vector<int64_t> &index) {
graph_info_map_[g].node_map[id] = std::make_pair(node, index);
graph_info_map_[g]->node_map[id] = std::make_pair(node, index);
} }
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node, void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
const std::vector<int64_t> &index_sequence, bool is_param = false); const std::vector<int64_t> &index_sequence, bool is_param = false);
@@ -274,7 +306,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   size_t grad_order_{0};
   std::string top_cell_id_;
   bool grad_flag_{false};
-  bool dynamic_cell_{false};
+  bool has_dynamic_cell_{false};
   bool grad_is_running_{false};
   bool need_replace_forward_{true};
   // The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script,
@@ -288,16 +320,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   FuncGraphPtr curr_g_{nullptr};
   // Records forwrad graph, the bottom is top graph
   std::stack<FuncGraphPtr> graph_stack_;
+  // Records op info of every cell, the bottom is op info of top cell
+  std::stack<std::string> cell_op_info_stack_;
 
   // Use vector for keep order
-  std::vector<CellInfo> cell_graph_list_;
-  std::vector<TopCellInfo> top_cell_list_;
+  std::vector<CellInfoPtr> cell_graph_list_;
+  std::vector<TopCellInfoPtr> top_cell_list_;
   std::unordered_set<std::string> cell_input_args_;
+  std::unordered_map<std::string, bool> cell_dynamic_map_;
   // Record all info for all cells
-  std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
-  // key: cell_id, value: (send_id, weighs_id), cache for sens and weight change
-  std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
+  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
   std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
 
   // Used for runop and replace forward result of grad graph


+3 −0  mindspore/common/api.py

@@ -314,6 +314,9 @@ class _PynativeExecutor:
     def check_graph(self, obj, *args, **kwargs):
         return self._executor.check_graph(obj, *args, *(kwargs.values()))
 
+    def check_run(self, obj, *args, **kwargs):
+        return self._executor.check_run(obj, *args, *(kwargs.values()))
+
     def grad(self, grad, obj, weights, *args, **kwargs):
         self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
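check_run is the Python binding for the new PynativeExecutor::CheckAlreadyRun declared above: it asks the executor whether a forward graph for this object and these arguments has already been recorded. A minimal usage sketch, assuming net is a Cell and x a Tensor defined elsewhere:

    # Hypothetical illustration; `net` (a Cell) and `x` (a Tensor) are assumed.
    from mindspore.common.api import _pynative_exec

    if not _pynative_exec.check_run(net, x):  # no recorded forward for (net, x)
        net.set_grad()
        net(x)  # run the forward once so the executor records it
    # otherwise the recorded forward is reused by the following grad call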




+26 −0  mindspore/core/utils/ordered_map.h

@@ -162,6 +162,14 @@ class OrderedMap {
     return pos == map_data_.end() ? sequential_data_.end() : (pos->second);
   }
 
+  ValueT at(const key_t &key) {
+    auto pos = map_data_.find(key);
+    if (pos == map_data_.end()) {
+      MS_LOG(EXCEPTION) << "Have no key " << key;
+    }
+    return pos->second->second;
+  }
+
   // Remove the last element from the sequential_data_.
   void pop_back() {
     typename map_type::iterator pos = map_data_.find(sequential_data_.back().first);
@@ -192,6 +200,24 @@ class OrderedMap {
     return 1;
   }
 
+  void update(const key_t &old_key, const key_t &new_key) {
+    auto old_it = find(old_key);
+    if (old_it == end()) {
+      return;
+    }
+    auto new_it = find(new_key);
+    if (new_it == end()) {
+      old_it->first = new_key;
+      auto nh = map_data_.extract(old_key);
+      nh.key() = new_key;
+      map_data_.insert(std::move(nh));
+      return;
+    }
+    *old_it = *new_it;
+    (void)erase(old_key);
+    (void)erase(new_key);
+  }
+
  private:
   map_type map_data_;
   sequential_type sequential_data_;
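at() is a checked lookup that raises on a missing key, and update() renames a key while keeping the entry's position in insertion order; when both keys already exist, the code as written copies the new entry into the old slot and then erases both keys. A rough Python analogue of update() over an insertion-ordered dict (Python 3.7+), for intuition only:

    def ordered_update(d, old_key, new_key):
        """Rough analogue of OrderedMap::update for a plain dict."""
        if old_key not in d:
            return
        if new_key not in d:
            # Rename old_key -> new_key, keeping its slot in insertion order.
            items = [(new_key if k == old_key else k, v) for k, v in d.items()]
            d.clear()
            d.update(items)
            return
        del d[old_key]  # both keys present: mirror the C++ as written,
        del d[new_key]  # which drops both entries

    m = {'a': 1, 'b': 2, 'c': 3}
    ordered_update(m, 'b', 'x')
    print(m)  # {'a': 1, 'x': 2, 'c': 3} -- 'x' keeps 'b''s position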


+1 −15  mindspore/nn/cell.py

@@ -68,7 +68,7 @@ class Cell(Cell_):
     """
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
                    '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
-                   '_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
+                   '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
                    '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
                    '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
                    '_auto_parallel_compile_and_run', 'cell_type']
@@ -105,15 +105,10 @@ class Cell(Cell_):
         self._backward_hook = None
         self.enable_hook = False
         self._bprop_debug = False
-        self._already_run = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
         self._support_non_tensor_inputs = False
 
-    @property
-    def already_run(self):
-        return self._already_run
-
     def __getstate__(self):
         base = Cell_.__getstate__(self)
         return base, self.__dict__
@@ -150,10 +145,6 @@ class Cell(Cell_):
         # `<class 'xxxxxxx'>` to `xxxxxxx`
         return str(self.__class__)[8:-2]
 
-    @already_run.setter
-    def already_run(self, value):
-        self._already_run = value
-
     @property
     def create_time(self):
         return self._create_time
@@ -334,12 +325,10 @@ class Cell(Cell_):
         for item in inputs:
             if isinstance(item, numpy.ndarray):
                 raise TypeError("cell inputs should not be numpy array.")
-        origin_grad = []
         if self.requires_grad is True:
             _pynative_exec.set_grad_flag(True)
             _pynative_exec.new_graph(self, *inputs, **kwargs)
             for cell in self.cells():
-                origin_grad.append(cell.requires_grad)
                 cell.set_grad(True)
         else:
             _pynative_exec.set_grad_flag(False)
@@ -363,9 +352,6 @@ class Cell(Cell_):
             output = output.data
         if self.requires_grad is True:
             _pynative_exec.end_graph(self, output, *inputs, **kwargs)
-            for i, cell in enumerate(self.cells()):
-                cell.set_grad(origin_grad[i])
-            self._already_run = True
         return output
 
     def _add_attr(self, name, value):
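With already_run gone, Cell keeps no per-object record of whether its forward has executed; that question now goes to the executor via check_run (see base.py below). The last hunk also stops restoring child cells' requires_grad flags after the forward, so they now stay enabled. A hypothetical check, assuming Net is a Cell subclass with sub-cells and x a matching Tensor:

    # Hypothetical: with this change, child cells keep set_grad(True) after a
    # grad-enabled forward; previously each child's old flag was restored.
    net = Net()   # assumed Cell subclass containing sub-cells
    net.set_grad()
    net(x)        # assumed Tensor input
    print(all(c.requires_grad for c in net.cells()))  # now True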


+10 −17  mindspore/ops/composite/base.py

@@ -319,36 +319,30 @@ class GradOperation(GradOperation_):
         GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param)
         self.grad_fn = None
         self.fn = None
-        self.need_forward = False
 
     def _pynative_forward_run(self, args, kwargs, fn):
         """ Pynative forward run to build grad graph. """
-        new_kwargs = {}
+        new_kwargs = kwargs
         if self.sens_param:
             if not 'sens' in kwargs.keys():
                 args = args[:-1]
-                new_kwargs = kwargs
             else:
-                for key, value in kwargs.items():
-                    if key != 'sens':
-                        new_kwargs[key] = value
+                new_kwargs = kwargs.copy()
+                new_kwargs.pop('sens')
         for arg in args:
             if not isinstance(arg, Tensor):
                 raise TypeError("grad inputs should be tensor in pynative mode")
         if isinstance(fn, FunctionType):
-            _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(fn, *args, **new_kwargs)
-            output = fn(*args, **new_kwargs)
-            _pynative_exec.end_graph(fn, output, *args, **new_kwargs)
+            if not _pynative_exec.check_run(fn, *args, **new_kwargs):
+                _pynative_exec.set_grad_flag(True)
+                _pynative_exec.new_graph(fn, *args, **new_kwargs)
+                output = fn(*args, **new_kwargs)
+                _pynative_exec.end_graph(fn, output, *args, **new_kwargs)
         else:
-            if fn.already_run and not fn.requires_grad:
-                raise ValueError("obj must set_grad.")
-            if not fn.already_run:
-                self.need_forward = True
-            if self.need_forward:
+            # Check if fn have run already
+            if not _pynative_exec.check_run(fn, *args, **new_kwargs):
                 fn.set_grad()
                 fn(*args, **new_kwargs)
-                fn.already_run = False
 
     def __call__(self, fn, weights=None):
         grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
@@ -367,7 +361,6 @@ class GradOperation(GradOperation_):
         def after_grad(*args, **kwargs):
             if _pynative_exec.check_graph(fn, *args, **kwargs):
                 print("Another grad step is running")
-            fn.already_run = False
             self._pynative_forward_run(args, kwargs, fn)
             _pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
             out = _pynative_exec(fn, *args, **kwargs)
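End to end, a cell whose op sequence is stable now pays for at most one forward execution per step even when gradients are requested repeatedly: after_grad no longer resets already_run, and _pynative_forward_run consults check_run instead. A minimal PyNative sketch; the network and data are placeholders, only the GradOperation/check_run behavior comes from this diff:

    import numpy as np
    import mindspore as ms
    from mindspore import context, nn, ops

    context.set_context(mode=context.PYNATIVE_MODE)

    net = nn.Dense(4, 2)  # placeholder network; any Cell works here
    net.set_grad()
    params = ms.ParameterTuple(net.trainable_params())
    grad_op = ops.GradOperation(get_by_list=True)
    x = ms.Tensor(np.random.randn(1, 4).astype(np.float32))

    grads = grad_op(net, params)(x)  # 1st call: forward runs, bprop graph built
    grads = grad_op(net, params)(x)  # 2nd call: check_run finds the recorded
                                     # forward, so a static cell skips re-running it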

