| @@ -42,8 +42,8 @@ class GetRefParamEliminater : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| PatternNode<AnfNodePtr> x; | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x); | |||
| return nullptr; | |||
| } | |||
| }; | |||
| @@ -128,7 +128,8 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { | |||
| std::vector<std::pair<std::string, ValuePtr>> key_values; | |||
| for (auto item : dict_values) { | |||
| if (!py::isinstance<py::str>(item.first)) { | |||
| MS_LOG(EXCEPTION) << "The key of dict is only support str."; | |||
| MS_LOG(ERROR) << "The key of dict is only support str."; | |||
| return false; | |||
| } | |||
| std::string key = py::str(item.first); | |||
| ValuePtr out = nullptr; | |||
| @@ -158,7 +159,7 @@ void ConvertDataClass(py::object obj, ValuePtr *const data) { | |||
| } | |||
| bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { | |||
| MS_LOG(DEBUG) << "Converting primitive object"; | |||
| MS_LOG(DEBUG) << "Converting primitive object" << use_signature; | |||
| // need check the primitive is class type or instance | |||
| auto obj_type = data_converter::GetObjType(obj); | |||
| @@ -184,6 +185,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = | |||
| } else { | |||
| *data = primitive; | |||
| } | |||
| MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString(); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -389,12 +391,12 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python | |||
| std::string obj_id = results[0] + python_mod_get_parse_method; | |||
| std::string obj_key = results[1]; | |||
| FuncGraphPtr func_graph = nullptr; | |||
| Any value = Any(); | |||
| ValuePtr value = nullptr; | |||
| bool is_cache = data_converter::GetObjectValue(obj_id, &value); | |||
| if (is_cache) { | |||
| if (value.is<FuncGraphPtr>()) { | |||
| if (value && value->isa<FuncGraph>()) { | |||
| MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; | |||
| func_graph = value.cast<FuncGraphPtr>(); | |||
| func_graph = value->cast<FuncGraphPtr>(); | |||
| return func_graph; | |||
| } | |||
| } | |||
| @@ -415,10 +417,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python | |||
| return func_graph; | |||
| } | |||
| namespace data_converter { | |||
| static std::unordered_map<std::string, Any> object_map_ = std::unordered_map<std::string, Any>(); | |||
| static std::unordered_map<std::string, ValuePtr> object_map_; | |||
| static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_ = | |||
| std::unordered_map<std::string, std::vector<FuncGraphPtr>>(); | |||
| static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_; | |||
| void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { | |||
| object_graphs_map_[obj_key].push_back(data); | |||
| @@ -430,8 +431,8 @@ const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() | |||
| return object_graphs_map_; | |||
| } | |||
| void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } | |||
| bool GetObjectValue(const std::string &obj_key, Any *const data) { | |||
| void CacheObjectValue(const std::string &obj_key, const ValuePtr &data) { object_map_[obj_key] = data; } | |||
| bool GetObjectValue(const std::string &obj_key, ValuePtr *const data) { | |||
| if (object_map_.count(obj_key)) { | |||
| *data = object_map_[obj_key]; | |||
| return true; | |||
| @@ -32,8 +32,8 @@ namespace mindspore { | |||
| namespace parse { | |||
| // data convert for parse | |||
| namespace data_converter { | |||
| void CacheObjectValue(const std::string &obj_key, const Any &data); | |||
| bool GetObjectValue(const std::string &obj_key, Any *const data); | |||
| void CacheObjectValue(const std::string &obj_key, const ValuePtr &data); | |||
| bool GetObjectValue(const std::string &obj_key, ValuePtr *const data); | |||
| void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); | |||
| @@ -82,6 +82,9 @@ std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecialize | |||
| if (iter != specializations_.end()) { | |||
| return iter->second; | |||
| } | |||
| if (context->func_graph()) { | |||
| MS_LOG(EXCEPTION) << "Specialize inner error"; | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -539,8 +542,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; | |||
| // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early | |||
| if (status == kSpecializeFindUniqueArgvalPoly || | |||
| (func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || | |||
| func->abstract()->isa<PartialAbstractClosure>()))) { | |||
| (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) { | |||
| auto wrapped_node = BuildSpecializedParameterNode(new_node); | |||
| new_inputs[0] = wrapped_node; | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "ir/manager.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "abstract/abstract_function.h" | |||
| namespace mindspore { | |||
| /* | |||
| @@ -48,6 +49,11 @@ FuncGraph::FuncGraph() | |||
| debug_info_ = std::make_shared<GraphDebugInfo>(); | |||
| } | |||
| abstract::AbstractBasePtr FuncGraph::ToAbstract() { | |||
| auto temp_context = abstract::AnalysisContext::DummyContext(); | |||
| return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context); | |||
| } | |||
| AnfNodePtr FuncGraph::output() const { | |||
| // If return value is set, return should have two inputs. | |||
| if (return_ != nullptr && return_->inputs().size() == 2) { | |||
| @@ -149,6 +149,7 @@ class FuncGraph : public FuncGraphBase { | |||
| // get the graph's abstract | |||
| abstract::AbstractFunctionPtr abstract(); | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| // return the graph's output, or nullptr if not yet deduced | |||
| AnfNodePtr output() const; | |||
| @@ -19,9 +19,15 @@ | |||
| #include "ir/meta_func_graph.h" | |||
| #include "base/core_ops.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "abstract/abstract_function.h" | |||
| // namespace to support intermediate representation definition | |||
| namespace mindspore { | |||
| abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() { | |||
| return std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>()); | |||
| } | |||
| FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| @@ -49,7 +49,7 @@ class MetaFuncGraph : public FuncGraphBase { | |||
| virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { | |||
| return args_spec_list; | |||
| } | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||
| void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; } | |||
| // Generate a Graph for the given abstract arguments. | |||
| @@ -17,8 +17,15 @@ | |||
| #include "ir/primitive.h" | |||
| #include <utility> | |||
| #include "abstract/abstract_function.h" | |||
| namespace mindspore { | |||
| abstract::AbstractBasePtr Primitive::ToAbstract() { | |||
| return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr); | |||
| } | |||
| bool Primitive::operator==(const Value &other) const { | |||
| if (other.isa<Primitive>()) { | |||
| auto other_prim = static_cast<const Primitive &>(other); | |||
| @@ -57,7 +57,7 @@ class Primitive : public Named { | |||
| record_evaluate_add_attr_(false) {} | |||
| MS_DECLARE_PARENT(Primitive, Named); | |||
| abstract::AbstractBasePtr ToAbstract(); | |||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||
| std::string ToString() const override { return name(); } | |||
| void BeginRecordAddAttr() { | |||
| @@ -102,7 +102,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): | |||
| def construct(self, data, label): | |||
| out = self._backbone(data) | |||
| label = F.mixed_precision_cast(mstype.float32, label) | |||
| return self._loss_fn(F.cast(out, mstype.float32), label) | |||
| return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label) | |||
| validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) | |||
| if cast_model_type == mstype.float16: | |||
| @@ -25,7 +25,8 @@ from mindspore import Tensor | |||
| from mindspore.nn.optim import AdamWeightDecay | |||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
| from mindspore.nn import learning_rate_schedule as lr_schedules | |||
| from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell | |||
| from mindspore.ops import operations as P | |||
| from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell | |||
| from ...dataset_mock import MindData | |||
| from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph | |||
| @@ -100,7 +101,7 @@ def get_config(version='base', batch_size=1): | |||
| class BertLearningRate(lr_schedules.LearningRateSchedule): | |||
| def __init__(self, decay_steps, warmup_steps=0, learning_rate=0.1, end_learning_rate=0.0001, power=1.0): | |||
| def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0): | |||
| super(BertLearningRate, self).__init__() | |||
| self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps) | |||
| self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) | |||
| @@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell): | |||
| def __init__(self, length, max_relative_position): | |||
| super(RelaPosMatrixGenerator, self).__init__() | |||
| self._length = length | |||
| self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32) | |||
| self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32) | |||
| self._max_relative_position = max_relative_position | |||
| self._min_relative_position = -max_relative_position | |||
| self.range_length = -length + 1 | |||
| self.tile = P.Tile() | |||
| @@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell): | |||
| self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, | |||
| max_relative_position=max_relative_position) | |||
| self.reshape = P.Reshape() | |||
| self.one_hot = P.OneHot() | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0.0, mstype.float32) | |||
| self.one_hot = nn.OneHot(depth=self.vocab_size) | |||
| self.shape = P.Shape() | |||
| self.gather = P.GatherV2() # index_select | |||
| self.matmul = P.BatchMatMul() | |||
| @@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell): | |||
| if self.use_one_hot_embeddings: | |||
| flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) | |||
| one_hot_relative_positions_matrix = self.one_hot( | |||
| flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value) | |||
| flat_relative_positions_matrix) | |||
| embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) | |||
| my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) | |||
| embeddings = self.reshape(embeddings, my_shape) | |||
| @@ -372,11 +370,11 @@ class SaturateCast(nn.Cell): | |||
| def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): | |||
| super(SaturateCast, self).__init__() | |||
| np_type = mstype.dtype_to_nptype(dst_type) | |||
| min_type = np.finfo(np_type).min | |||
| max_type = np.finfo(np_type).max | |||
| min_type = float(np.finfo(np_type).min) | |||
| max_type = float(np.finfo(np_type).max) | |||
| self.tensor_min_type = Tensor([min_type], dtype=src_type) | |||
| self.tensor_max_type = Tensor([max_type], dtype=src_type) | |||
| self.tensor_min_type = min_type | |||
| self.tensor_max_type = max_type | |||
| self.min_op = P.Minimum() | |||
| self.max_op = P.Maximum() | |||
| @@ -442,7 +440,7 @@ class BertAttention(nn.Cell): | |||
| self.has_attention_mask = has_attention_mask | |||
| self.use_relative_positions = use_relative_positions | |||
| self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) | |||
| self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) | |||
| self.reshape = P.Reshape() | |||
| self.shape_from_2d = (-1, from_tensor_width) | |||
| self.shape_to_2d = (-1, to_tensor_width) | |||
| @@ -471,7 +469,7 @@ class BertAttention(nn.Cell): | |||
| self.trans_shape = (0, 2, 1, 3) | |||
| self.trans_shape_relative = (2, 0, 1, 3) | |||
| self.trans_shape_position = (1, 2, 0, 3) | |||
| self.multiply_data = Tensor([-10000.0,], dtype=compute_type) | |||
| self.multiply_data = -10000.0 | |||
| self.batch_num = batch_size * num_attention_heads | |||
| self.matmul = P.BatchMatMul() | |||
| @@ -15,6 +15,7 @@ | |||
| """ test nn ops """ | |||
| import numpy as np | |||
| from numpy.random import normal | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| @@ -311,6 +312,7 @@ def test_op_with_arg_as_input(): | |||
| # The partial application used as argument is not supported yet | |||
| # because of the limit of inference specialize system | |||
| @pytest.mark.skip("poly in infer") | |||
| def test_partial_as_arg(): | |||
| class PartialArgNet(nn.Cell): | |||
| def __init__(self): | |||