|
|
|
@@ -89,12 +89,18 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; |
|
|
|
class Optimizer : public std::enable_shared_from_this<Optimizer> { |
|
|
|
public: |
|
|
|
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) |
|
|
|
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {} |
|
|
|
: name_(name), |
|
|
|
resource_(resource_ptr), |
|
|
|
run_only_once_(false), |
|
|
|
is_watch_renormalize_(false), |
|
|
|
is_enable_(true), |
|
|
|
is_untyped_generated_(false) {} |
|
|
|
virtual ~Optimizer() = default; |
|
|
|
|
|
|
|
void Init(const OptPassGroupMap &passes, bool run_only_once) { |
|
|
|
run_only_once_ = run_only_once; |
|
|
|
is_watch_renormalize_ = false; |
|
|
|
is_untyped_generated_ = false; |
|
|
|
is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); |
|
|
|
|
|
|
|
for (auto &iter : passes) { |
|
|
|
@@ -154,14 +160,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { |
|
|
|
// So generate the args_spec from parameters. |
|
|
|
abstract::AbstractBasePtrList maybe_new_args_spec; |
|
|
|
if (is_watch_renormalize_) { |
|
|
|
if (untyped_nodes_.size() > 0) { |
|
|
|
if (is_untyped_generated_) { |
|
|
|
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), |
|
|
|
std::back_inserter(maybe_new_args_spec), |
|
|
|
[](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); |
|
|
|
func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); |
|
|
|
clear_untyped_nodes(); |
|
|
|
clear_is_untyped_generated(); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty."; |
|
|
|
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False."; |
|
|
|
} |
|
|
|
} else { |
|
|
|
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), |
|
|
|
@@ -206,13 +212,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { |
|
|
|
|
|
|
|
const std::string name() const { return name_; } |
|
|
|
|
|
|
|
void add_node_to_renormalize(AnfNodePtr anode) { |
|
|
|
if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) { |
|
|
|
untyped_nodes_.push_back(anode); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void clear_untyped_nodes() { untyped_nodes_.clear(); } |
|
|
|
void set_is_untyped_generated() { is_untyped_generated_ = true; } |
|
|
|
void clear_is_untyped_generated() { is_untyped_generated_ = false; } |
|
|
|
|
|
|
|
void enable_watch_renormalize() { is_watch_renormalize_ = true; } |
|
|
|
void disable_watch_renormalize() { is_watch_renormalize_ = false; } |
|
|
|
@@ -232,9 +233,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { |
|
|
|
std::vector<OptPass> passes_; |
|
|
|
std::vector<std::string> pass_names_; |
|
|
|
bool run_only_once_; |
|
|
|
std::vector<AnfNodePtr> untyped_nodes_; |
|
|
|
bool is_watch_renormalize_; |
|
|
|
bool is_enable_; |
|
|
|
bool is_untyped_generated_; |
|
|
|
}; |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |
|
|
|
|