Browse Source

refine code
* refine code in bert model
* add ToAbstract for `FuncGraph`, `MetaFuncGraph`, `Primitive`
* remove partial hard code in spec for poly
* remove `Any` in data convert cache

tags/v0.7.0-beta
Wei Luning 5 years ago
parent
commit
484d7f10c8
14 changed files with 57 additions and 33 deletions
  1. +2
    -2
      mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h
  2. +11
    -10
      mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
  3. +2
    -2
      mindspore/ccsrc/pipeline/jit/parse/data_converter.h
  4. +4
    -2
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
  5. +6
    -0
      mindspore/core/ir/func_graph.cc
  6. +1
    -0
      mindspore/core/ir/func_graph.h
  7. +6
    -0
      mindspore/core/ir/meta_func_graph.cc
  8. +1
    -1
      mindspore/core/ir/meta_func_graph.h
  9. +7
    -0
      mindspore/core/ir/primitive.cc
  10. +1
    -1
      mindspore/core/ir/primitive.h
  11. +1
    -1
      mindspore/train/amp.py
  12. +3
    -2
      tests/perf_test/bert/test_bert_train.py
  13. +10
    -12
      tests/st/networks/models/bert/src/bert_model.py
  14. +2
    -0
      tests/ut/python/ops/test_ops_attr_infer.py

+ 2
- 2
mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h View File

@@ -42,8 +42,8 @@ class GetRefParamEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
return nullptr;
}
};


+ 11
- 10
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc View File

@@ -128,7 +128,8 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
std::vector<std::pair<std::string, ValuePtr>> key_values;
for (auto item : dict_values) {
if (!py::isinstance<py::str>(item.first)) {
MS_LOG(EXCEPTION) << "The key of dict is only support str.";
MS_LOG(ERROR) << "The key of dict is only support str.";
return false;
}
std::string key = py::str(item.first);
ValuePtr out = nullptr;
@@ -158,7 +159,7 @@ void ConvertDataClass(py::object obj, ValuePtr *const data) {
}

bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
MS_LOG(DEBUG) << "Converting primitive object";
MS_LOG(DEBUG) << "Converting primitive object" << use_signature;

// need check the primitive is class type or instance
auto obj_type = data_converter::GetObjType(obj);
@@ -184,6 +185,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature =
} else {
*data = primitive;
}
MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
}
return true;
}
@@ -389,12 +391,12 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
std::string obj_id = results[0] + python_mod_get_parse_method;
std::string obj_key = results[1];
FuncGraphPtr func_graph = nullptr;
Any value = Any();
ValuePtr value = nullptr;
bool is_cache = data_converter::GetObjectValue(obj_id, &value);
if (is_cache) {
if (value.is<FuncGraphPtr>()) {
if (value && value->isa<FuncGraph>()) {
MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
func_graph = value.cast<FuncGraphPtr>();
func_graph = value->cast<FuncGraphPtr>();
return func_graph;
}
}
@@ -415,10 +417,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
return func_graph;
}
namespace data_converter {
static std::unordered_map<std::string, Any> object_map_ = std::unordered_map<std::string, Any>();
static std::unordered_map<std::string, ValuePtr> object_map_;

static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_ =
std::unordered_map<std::string, std::vector<FuncGraphPtr>>();
static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;

void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
object_graphs_map_[obj_key].push_back(data);
@@ -430,8 +431,8 @@ const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs()
return object_graphs_map_;
}

void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; }
bool GetObjectValue(const std::string &obj_key, Any *const data) {
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; }
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) {
if (object_map_.count(obj_key)) {
*data = object_map_[obj_key];
return true;


+ 2
- 2
mindspore/ccsrc/pipeline/jit/parse/data_converter.h View File

@@ -32,8 +32,8 @@ namespace mindspore {
namespace parse {
// data convert for parse
namespace data_converter {
void CacheObjectValue(const std::string &obj_key, const Any &data);
bool GetObjectValue(const std::string &obj_key, Any *const data);
void CacheObjectValue(const std::string &obj_key, const ValuePtr &data);
bool GetObjectValue(const std::string &obj_key, ValuePtr *const data);

void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);



+ 4
- 2
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc View File

@@ -82,6 +82,9 @@ std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecialize
if (iter != specializations_.end()) {
return iter->second;
}
if (context->func_graph()) {
MS_LOG(EXCEPTION) << "Specialize inner error";
}
return nullptr;
}

@@ -539,8 +542,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
if (status == kSpecializeFindUniqueArgvalPoly ||
(func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
func->abstract()->isa<PartialAbstractClosure>()))) {
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
new_inputs[0] = wrapped_node;
}


+ 6
- 0
mindspore/core/ir/func_graph.cc View File

@@ -26,6 +26,7 @@
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/convert_utils_base.h"
#include "abstract/abstract_function.h"

namespace mindspore {
/*
@@ -48,6 +49,11 @@ FuncGraph::FuncGraph()
debug_info_ = std::make_shared<GraphDebugInfo>();
}

// Wraps this graph in a FuncGraphAbstractClosure. The closure is built from a
// shared pointer to this graph and the context returned by
// AnalysisContext::DummyContext().
abstract::AbstractBasePtr FuncGraph::ToAbstract() {
  auto dummy_context = abstract::AnalysisContext::DummyContext();
  auto self = shared_from_base<FuncGraph>();
  return std::make_shared<abstract::FuncGraphAbstractClosure>(self, dummy_context);
}

AnfNodePtr FuncGraph::output() const {
// If return value is set, return should have two inputs.
if (return_ != nullptr && return_->inputs().size() == 2) {


+ 1
- 0
mindspore/core/ir/func_graph.h View File

@@ -149,6 +149,7 @@ class FuncGraph : public FuncGraphBase {

// get the graph's abstract
abstract::AbstractFunctionPtr abstract();
abstract::AbstractBasePtr ToAbstract() override;

// return the graph's output, or nullptr if not yet deduced
AnfNodePtr output() const;


+ 6
- 0
mindspore/core/ir/meta_func_graph.cc View File

@@ -19,9 +19,15 @@
#include "ir/meta_func_graph.h"
#include "base/core_ops.h"
#include "utils/context/ms_context.h"
#include "abstract/abstract_function.h"

// namespace to support intermediate representation definition
namespace mindspore {

// Wraps this meta graph in a MetaFuncGraphAbstractClosure built from a shared
// pointer to this object.
abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
  auto self = shared_from_base<MetaFuncGraph>();
  return std::make_shared<abstract::MetaFuncGraphAbstractClosure>(self);
}

FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);


+ 1
- 1
mindspore/core/ir/meta_func_graph.h View File

@@ -49,7 +49,7 @@ class MetaFuncGraph : public FuncGraphBase {
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
return args_spec_list;
}
abstract::AbstractBasePtr ToAbstract() override;
const std::vector<Signature> &signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
// Generate a Graph for the given abstract arguments.


+ 7
- 0
mindspore/core/ir/primitive.cc View File

@@ -17,8 +17,15 @@
#include "ir/primitive.h"

#include <utility>
#include "abstract/abstract_function.h"


namespace mindspore {

// Wraps this primitive in a PrimitiveAbstractClosure. The closure's second
// constructor argument is passed as nullptr.
abstract::AbstractBasePtr Primitive::ToAbstract() {
  auto self = shared_from_base<Primitive>();
  return std::make_shared<abstract::PrimitiveAbstractClosure>(self, nullptr);
}

bool Primitive::operator==(const Value &other) const {
if (other.isa<Primitive>()) {
auto other_prim = static_cast<const Primitive &>(other);


+ 1
- 1
mindspore/core/ir/primitive.h View File

@@ -57,7 +57,7 @@ class Primitive : public Named {
record_evaluate_add_attr_(false) {}

MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToAbstract();
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
void BeginRecordAddAttr() {


+ 1
- 1
mindspore/train/amp.py View File

@@ -102,7 +102,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
def construct(self, data, label):
out = self._backbone(data)
label = F.mixed_precision_cast(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label)
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
if cast_model_type == mstype.float16:


+ 3
- 2
tests/perf_test/bert/test_bert_train.py View File

@@ -25,7 +25,8 @@ from mindspore import Tensor
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn import learning_rate_schedule as lr_schedules
from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.ops import operations as P
from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from ...dataset_mock import MindData
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph

@@ -100,7 +101,7 @@ def get_config(version='base', batch_size=1):


class BertLearningRate(lr_schedules.LearningRateSchedule):
def __init__(self, decay_steps, warmup_steps=0, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
super(BertLearningRate, self).__init__()
self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)


+ 10
- 12
tests/st/networks/models/bert/src/bert_model.py View File

@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
def __init__(self, length, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._length = length
self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
self._max_relative_position = max_relative_position
self._min_relative_position = -max_relative_position
self.range_length = -length + 1

self.tile = P.Tile()
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
max_relative_position=max_relative_position)
self.reshape = P.Reshape()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.one_hot = nn.OneHot(depth=self.vocab_size)
self.shape = P.Shape()
self.gather = P.GatherV2() # index_select
self.matmul = P.BatchMatMul()
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
flat_relative_positions_matrix)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
@@ -372,11 +370,11 @@ class SaturateCast(nn.Cell):
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
min_type = np.finfo(np_type).min
max_type = np.finfo(np_type).max
min_type = float(np.finfo(np_type).min)
max_type = float(np.finfo(np_type).max)

self.tensor_min_type = Tensor([min_type], dtype=src_type)
self.tensor_max_type = Tensor([max_type], dtype=src_type)
self.tensor_min_type = min_type
self.tensor_max_type = max_type

self.min_op = P.Minimum()
self.max_op = P.Maximum()
@@ -442,7 +440,7 @@ class BertAttention(nn.Cell):
self.has_attention_mask = has_attention_mask
self.use_relative_positions = use_relative_positions

self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
@@ -471,7 +469,7 @@ class BertAttention(nn.Cell):
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
self.multiply_data = -10000.0
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()



+ 2
- 0
tests/ut/python/ops/test_ops_attr_infer.py View File

@@ -15,6 +15,7 @@
""" test nn ops """
import numpy as np
from numpy.random import normal
import pytest

import mindspore.nn as nn
import mindspore.context as context
@@ -311,6 +312,7 @@ def test_op_with_arg_as_input():

# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
@pytest.mark.skip("poly in infer")
def test_partial_as_arg():
class PartialArgNet(nn.Cell):
def __init__(self):


Loading…
Cancel
Save