Browse Source

fix code review

tags/v1.4.0
lianliguang 4 years ago
parent
commit
1b41219c46
13 changed files with 83 additions and 142 deletions
  1. +0
    -3
      mindspore/ccsrc/backend/session/kernel_graph.cc
  2. +2
    -3
      mindspore/ccsrc/pipeline/jit/parse/resolve.h
  3. +1
    -2
      mindspore/ccsrc/pipeline/jit/resource_base.h
  4. +13
    -41
      mindspore/core/abstract/abstract_function.cc
  5. +4
    -0
      mindspore/core/abstract/abstract_value.cc
  6. +2
    -5
      mindspore/core/abstract/analysis_context.cc
  7. +17
    -21
      mindspore/core/abstract/prim_arrays.cc
  8. +1
    -0
      mindspore/core/abstract/prim_maths.cc
  9. +6
    -32
      mindspore/core/abstract/prim_nn.cc
  10. +17
    -10
      mindspore/core/abstract/prim_others.cc
  11. +2
    -2
      mindspore/core/abstract/prim_statement.cc
  12. +10
    -2
      mindspore/core/abstract/prim_structures.cc
  13. +8
    -21
      mindspore/core/abstract/utils.cc

+ 0
- 3
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -556,8 +556,6 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
auto new_parameter = NewParameter(abstract);
MS_EXCEPTION_IF_NULL(new_parameter);

// If the default (parameter == nullptr) is not used, it means creating a new parameter from an old parameter
if (parameter != nullptr) {
new_parameter->set_name(parameter->name());
@@ -574,7 +572,6 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
ParameterPtr new_parameter = add_parameter();
new_parameter->set_abstract(abstract);
MS_EXCEPTION_IF_NULL(new_parameter);
// create kernel_info form new parameter
SetKernelInfoForNode(new_parameter);
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());


+ 2
- 3
mindspore/ccsrc/pipeline/jit/parse/resolve.h View File

@@ -130,9 +130,6 @@ class SymbolResolver {

AnfNodePtr resolved_node() { return resolved_node_; }

// Resolve result
py::object result_;

private:
// namespace where the symbol locates
NameSpacePtr namespace_;
@@ -140,6 +137,8 @@ class SymbolResolver {
SymbolPtr symbol_;
// the node that has been resolved
AnfNodePtr resolved_node_;
// Resolve result
py::object result_;
};
using SymbolResolverPtr = std::shared_ptr<SymbolResolver>;
// Resolve symbol in namespace.


+ 1
- 2
mindspore/ccsrc/pipeline/jit/resource_base.h View File

@@ -51,10 +51,9 @@ class ResourceBase {

bool HasResult(const std::string &key) const { return results_.count(key) != 0; }

std::unordered_map<std::string, Any> results_;

protected:
FuncGraphManagerPtr manager_;
std::unordered_map<std::string, Any> results_;
};

using ResourceBasePtr = std::shared_ptr<pipeline::ResourceBase>;


+ 13
- 41
mindspore/core/abstract/abstract_function.cc View File

@@ -22,7 +22,6 @@ namespace mindspore {
namespace abstract {
class Evaluator;
class AnalysisEngine;

AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) {
if (func_list.size() == 1) {
return func_list[0];
@@ -102,7 +101,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
return std::make_shared<AbstractFuncUnion>(this_func, other);
}
auto other_union = dyn_cast<AbstractFuncUnion>(other);
MS_EXCEPTION_IF_NULL(other);
MS_EXCEPTION_IF_NULL(other_union);
if (other_union->IsSuperSet(this_func)) {
return other;
}
@@ -110,7 +109,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) {
}

void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const {
for (AbstractFuncAtomPtr poss : func_list_) {
for (const AbstractFuncAtomPtr &poss : func_list_) {
visit_func(poss);
}
}
@@ -123,15 +122,12 @@ bool AbstractFuncUnion::operator==(const AbstractFunction &other) const {
if (func_list_.size() != other_union->func_list_.size()) {
return false;
}
if (func_list_ == other_union->func_list_) {
return true;
}
return false;
return func_list_ == other_union->func_list_;
}

std::size_t AbstractFuncUnion::hash() const {
std::size_t hash_sum = 0;
for (auto f : func_list_) {
for (const auto &f : func_list_) {
MS_EXCEPTION_IF_NULL(f);
hash_sum = hash_combine(hash_sum, f->hash());
}
@@ -144,10 +140,7 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
}
auto other_prim = static_cast<const PrimitiveAbstractClosure *>(&other);
MS_EXCEPTION_IF_NULL(prim_);
if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) {
return true;
}
return false;
return (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id());
}

std::size_t PrimitiveAbstractClosure::hash() const {
@@ -165,11 +158,8 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
return false;
}
auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ &&
tracking_id() == other_fg->tracking_id()) {
return true;
}
return false;
return func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ &&
tracking_id() == other_fg->tracking_id();
}

std::size_t FuncGraphAbstractClosure::hash() const {
@@ -195,10 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con
return false;
}
auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) {
return true;
}
return false;
return meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id();
}

std::size_t MetaFuncGraphAbstractClosure::hash() const {
@@ -226,10 +213,7 @@ bool PartialAbstractClosure::operator==(const AbstractFunction &other) const {
if (args_spec_list_.size() != other_partial->args_spec_list_.size()) {
return false;
}
if (args_spec_list_ == other_partial->args_spec_list_) {
return true;
}
return false;
return args_spec_list_ == other_partial->args_spec_list_;
}

std::size_t PartialAbstractClosure::hash() const {
@@ -255,10 +239,7 @@ bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) cons
return false;
}
auto other_transformed = static_cast<const JTransformedAbstractClosure *>(&other);
if (fn_ == other_transformed->fn_) {
return true;
}
return false;
return fn_ == other_transformed->fn_;
}

std::size_t JTransformedAbstractClosure::hash() const {
@@ -278,10 +259,7 @@ bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) {
return false;
}
if (args_spec_list_ == other_virtual->args_spec_list_) {
return true;
}
return false;
return args_spec_list_ == other_virtual->args_spec_list_;
}

std::size_t VirtualAbstractClosure::hash() const {
@@ -319,10 +297,7 @@ bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) co
if (args_spec_list_.size() != other_typed->args_spec_list_.size()) {
return false;
}
if (args_spec_list_ == other_typed->args_spec_list_) {
return true;
}
return false;
return args_spec_list_ == other_typed->args_spec_list_;
}

std::size_t TypedPrimitiveAbstractClosure::hash() const {
@@ -346,10 +321,7 @@ std::string TypedPrimitiveAbstractClosure::ToString() const {
}

bool DummyAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<DummyAbstractClosure>()) {
return false;
}
return true;
return other.isa<DummyAbstractClosure>();
}
} // namespace abstract
} // namespace mindspore

+ 4
- 0
mindspore/core/abstract/abstract_value.cc View File

@@ -238,6 +238,7 @@ std::string AbstractError::ToString() const {
std::ostringstream buffer;
auto value_track = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
MS_EXCEPTION_IF_NULL(node_);
buffer << type_name() << "("
<< "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
return buffer.str();
@@ -594,6 +595,8 @@ bool AbstractTensor::equal_to(const AbstractTensor &other) const {
if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
is_value_equal = true;
}
MS_EXCEPTION_IF_NULL(element_);
MS_EXCEPTION_IF_NULL(other.element_);
return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
}

@@ -912,6 +915,7 @@ AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
if (other_jtagged == nullptr) {
AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other);
}
MS_EXCEPTION_IF_NULL(element_);
auto joined_elem = element_->Join(other_jtagged->element_);
return std::make_shared<AbstractJTagged>(joined_elem);
}


+ 2
- 5
mindspore/core/abstract/analysis_context.cc View File

@@ -85,7 +85,7 @@ AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &f
oss << "nullptr";
}
oss << " extant context list: {";
for (auto iter : extant_context_cache_) {
for (const auto &iter : extant_context_cache_) {
if (iter.first == nullptr) {
oss << " [graph: nullptr";
} else {
@@ -108,10 +108,7 @@ AnalysisContextPtr AnalysisContext::DummyContext() {
}

bool AnalysisContext::IsDummyContext() {
if (parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty()) {
return true;
}
return false;
return parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty();
}

const AnalysisContextPtr kDummyAnalysisContext =


+ 17
- 21
mindspore/core/abstract/prim_arrays.cc View File

@@ -75,6 +75,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });

ShapeVector res = BroadcastShape(shp_x, shp_y);
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
if (res.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
@@ -115,16 +116,17 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
(void)CheckDtypeSame(op_name, tensor_base, tensor);
(void)CheckShapeSame(op_name, tensor_base, tensor);
}

auto element = tensor_base->element();
MS_EXCEPTION_IF_NULL(element);
primitive->set_attr("N", MakeValue(SizeToLong(tuple_len)));
primitive->set_attr("T", tensor_base->element()->BuildType());
primitive->set_attr("T", element->BuildType());

AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
MS_EXCEPTION_IF_NULL(ret);
auto ret_shape_ptr = ret->shape();
MS_EXCEPTION_IF_NULL(ret_shape_ptr);
auto ret_shape = ret->shape()->shape();
(void)ret_shape.insert(ret_shape.begin() + axis_value, tuple_len);
auto ret_shape = ret_shape_ptr->shape();
(void)ret_shape.insert(ret_shape.begin() + axis_value, SizeToLong(tuple_len));
ret->set_shape(std::make_shared<Shape>(ret_shape));
return ret;
}
@@ -137,7 +139,8 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

auto shape = input->shape();
if (shape->shape().size() != 1) {
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
}
ShapeVector ids_shape = {Shape::SHP_ANY};
@@ -205,6 +208,10 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
CheckArgsSize(op_name + " dout", dout->elements(), size_expected);
auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
auto ids_shape = ids->shape();
auto ids_idx_shape = ids_idx->shape();
MS_EXCEPTION_IF_NULL(ids_shape);
MS_EXCEPTION_IF_NULL(ids_idx_shape);
if (ids->shape()->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "' input must be 1.";
}
@@ -278,7 +285,6 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri
const size_t size_expected = 3;
CheckArgsSize(op_name, args_spec_list, size_expected);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());
auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(segment_ids);
@@ -426,12 +432,10 @@ AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitiveP
const size_t size_expected = 5;
CheckArgsSize(op_name, args_spec_list, size_expected);
auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(hash_map);
MS_EXCEPTION_IF_NULL(hash_map->shape());

auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto indices_shp = indices->shape();
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(indices_shp);

ShapeVector shape;
@@ -466,12 +470,10 @@ AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const Primiti
CheckArgsSize(op_name, args_spec_list, size_expected);
auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto cache_table_shp = cache_table->shape();
MS_EXCEPTION_IF_NULL(cache_table);
MS_EXCEPTION_IF_NULL(cache_table_shp);

auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto swap_cache_idx_shp = swap_cache_idx->shape();
MS_EXCEPTION_IF_NULL(swap_cache_idx);
MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);

auto cache_table_shape = cache_table_shp->shape();
@@ -501,12 +503,6 @@ AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());

auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(indices->shape());

ShapeVector shape;
shape.emplace_back(1);
@@ -520,7 +516,6 @@ AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const Primitive
const std::string op_name = primitive->name();
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto input_x_shp = input_x->shape();
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x_shp);

ShapeVector shape;
@@ -648,6 +643,7 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv

MS_LOG(INFO) << "InferImplDynamicAssign " << args_spec_list[0];
auto type = args_spec_list[0]->BuildType();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() == kObjectTypeRefKey) {
return args_spec_list[1]->Broaden();
} else {
@@ -669,12 +665,10 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
const std::string op_name = primitive->name();
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto params_shp = params->shape();
MS_EXCEPTION_IF_NULL(params);
MS_EXCEPTION_IF_NULL(params_shp);
auto params_shape = params_shp->shape();
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto indices_shp = indices->shape();
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(indices_shp);
auto indices_shape = indices_shp->shape();
auto indices_max_shape = indices_shp->max_shape();
@@ -726,8 +720,8 @@ AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const Primitive
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input->shape());
auto shape = input->shape()->shape();

bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
ShapeVector tensor_shp({static_cast<int64_t>(shape.size())});
if (has_dyn_shape) {
@@ -750,6 +744,7 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto input_shp = input->shape()->shape();
ValuePtr perm = primitive->GetAttr("perm");
MS_EXCEPTION_IF_NULL(perm);
auto perm_val = perm->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(perm_val);
auto perm_val_data = perm_val->value();
@@ -788,6 +783,7 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
x_min_shape = x_shape;
}
ValuePtr sh = primitive->GetAttr("shape");
MS_EXCEPTION_IF_NULL(sh);
auto reshape_value_tuple = sh->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(reshape_value_tuple);
auto reshape_tuple = reshape_value_tuple->value();
@@ -866,7 +862,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr
ValuePtr axis = primitive->GetAttr("axis");
int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank)));
int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value();
int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num"));
if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) {
MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos]
<< " must be divisible by output_num = " << output_num_value;


+ 1
- 0
mindspore/core/abstract/prim_maths.cc View File

@@ -91,6 +91,7 @@ int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {

void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
bool keep_dims_value) {
MS_EXCEPTION_IF_NULL(axis);
if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
auto axis_ptr_list =
axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();


+ 6
- 32
mindspore/core/abstract/prim_nn.cc View File

@@ -57,12 +57,12 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
int64_t h_input = input_shape->shape()[H_INDEX];
int64_t w_input = input_shape->shape()[W_INDEX];

int64_t window = primitive->GetAttr("window")->cast<Int64ImmPtr>()->value();
int64_t stride = primitive->GetAttr("stride")->cast<Int64ImmPtr>()->value();
int64_t padding = primitive->GetAttr("pad")->cast<Int64ImmPtr>()->value();
int64_t nan_opt = primitive->GetAttr("nan_opt")->cast<Int64ImmPtr>()->value();
int64_t data_mode = primitive->GetAttr("data_mode")->cast<Int64ImmPtr>()->value();
int64_t ceil_mode = primitive->GetAttr("ceil_mode")->cast<Int64ImmPtr>()->value();
int64_t window = GetValue<int64_t>(primitive->GetAttr("window"));
int64_t stride = GetValue<int64_t>(primitive->GetAttr("stride"));
int64_t padding = GetValue<int64_t>(primitive->GetAttr("pad"));
int64_t nan_opt = GetValue<int64_t>(primitive->GetAttr("nan_opt"));
int64_t data_mode = GetValue<int64_t>(primitive->GetAttr("data_mode"));
int64_t ceil_mode = GetValue<int64_t>(primitive->GetAttr("ceil_mode"));

if (stride <= 0) {
MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should be greater than 0";
@@ -124,32 +124,6 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP
return ret;
}

void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
// check dimension, x > 1, others equal 1
const std::string op_name = primitive->name();
for (std::size_t i = 0; i < args_spec_list.size(); ++i) {
AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack());
if (arg_shape == nullptr) {
MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString();
}

if (i == 0) {
if (arg_shape->shape().size() < 2) {
MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
<< "] should be TensorShape with dimension greater than 1, but shape: "
<< arg_shape->ToString();
}
continue;
}

if (arg_shape->shape().size() != 1) {
MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
<< "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString();
}
}
}

AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: five tensors(x, gamma, beta, mean, variance).


+ 17
- 10
mindspore/core/abstract/prim_others.cc View File

@@ -100,6 +100,7 @@ AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &p

AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) {
ValuePtr name_value = prim->GetAttr("tag");
MS_EXCEPTION_IF_NULL(name_value);
auto name = name_value->cast<StringImmPtr>();
if (name == nullptr) {
MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
@@ -132,7 +133,9 @@ AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr
if (type->type_id() != kObjectTypeRef) {
MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString();
}
return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>();
MS_EXCEPTION_IF_NULL(abs_ref);
return abs_ref->ref();
}

AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &,
@@ -146,7 +149,9 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP
if (type->type_id() != kObjectTypeRef) {
return args_spec_list[0];
}
return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>();
MS_EXCEPTION_IF_NULL(abs_ref);
return abs_ref->ref();
}

AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -185,6 +190,7 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
return args_spec_list[0]->Broaden();
}

@@ -213,14 +219,16 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv
<< values_shp[0] << ", but got " << indices_shp[0];
}

for (auto elem_type : dense_shape->ElementsType()) {
for (const auto &elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
auto dense_shape_value = dense_shape->BuildValue();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
auto dense_shape_valuetuple = dense_shape_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_valuetuple);
auto shp = dense_shape_valuetuple->value();
ShapeVector dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int64_t {
@@ -229,7 +237,7 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv
});
if (dense_shape_vec.size() != values_shp.size()) {
MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values "
<< values_shp.size() << ", but got " << dense_shape_value->size();
<< values_shp.size() << ", but got " << dense_shape_valuetuple->size();
}
for (size_t i = 0; i < dense_shape_vec.size(); i++) {
if (dense_shape_vec[i] < 0) {
@@ -321,7 +329,7 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
<< values_shp[0] << ", but got " << indices_shp[0];
}

for (auto elem_type : dense_shape->ElementsType()) {
for (const auto &elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
@@ -502,9 +510,8 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
MS_EXCEPTION_IF_NULL(input_x);
auto attr = primitive->GetAttr("dst_type");
if (attr == nullptr) {
auto input_dtype = args_spec_list[1];
MS_EXCEPTION_IF_NULL(input_dtype);
attr = input_dtype->BuildValue();
auto type_abs = CheckArg<AbstractType>(op_name, args_spec_list, 1);
attr = type_abs->BuildValue();
MS_EXCEPTION_IF_NULL(attr);
primitive->set_attr("dst_type", attr);
}


+ 2
- 2
mindspore/core/abstract/prim_statement.cc View File

@@ -74,7 +74,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
abstract::CheckArgsSize(op_name, args_spec_list, kSwitchLayerInputNum);
auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto &input_shape = index->shape()->shape();
if (input_shape.size() != 0) {
if (!input_shape.empty()) {
MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size()
<< " dimension tensor";
}
@@ -86,7 +86,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
AbstractBasePtrList branches = branches_abs->elements();
const size_t maximum_layer_num = 1000;
if (branches.size() < 1 || branches.size() > maximum_layer_num) {
if (branches.empty() || branches.size() > maximum_layer_num) {
MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
<< branches.size() << " branches.";
}


+ 10
- 2
mindspore/core/abstract/prim_structures.cc View File

@@ -70,6 +70,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);

ValuePtr keyPtr = key->BuildValue();
MS_EXCEPTION_IF_NULL(keyPtr);
if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
}
@@ -86,6 +87,7 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);

ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
@@ -110,6 +112,7 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr
slice_args.push_back(args_spec_list[index]);
} else if (args_spec_list[index]->isa<AbstractScalar>()) {
ValuePtr scalar_value = args_spec_list[index]->cast<AbstractScalarPtr>()->BuildValue();
MS_EXCEPTION_IF_NULL(scalar_value);
if (scalar_value->isa<IntergerImm>()) {
slice_args.push_back(args_spec_list[index]);
} else if (scalar_value->isa<BoolImm>()) {
@@ -122,8 +125,9 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr
} else if (args_spec_list[index]->isa<AbstractTensor>()) {
auto arg = args_spec_list[index]->cast<AbstractTensorPtr>();
TypePtr tensor_dtype = arg->element()->BuildType();

auto value = arg->BuildValue()->cast<tensor::TensorPtr>();
auto build_value = arg->BuildValue();
MS_EXCEPTION_IF_NULL(build_value);
auto value = build_value->cast<tensor::TensorPtr>();
if (value == nullptr) {
MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must be a const tensor.";
}
@@ -159,6 +163,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

ValuePtr index_value = index->BuildValue();
MS_EXCEPTION_IF_NULL(index_value);
if (!index_value->isa<Int64Imm>()) {
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
// and continue
@@ -191,6 +196,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

ValuePtr index_value = index->BuildValue();
MS_EXCEPTION_IF_NULL(index_value);
if (!index_value->isa<Int64Imm>()) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
<< index_value->ToString();
@@ -237,6 +243,7 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
@@ -259,6 +266,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);

ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}


+ 8
- 21
mindspore/core/abstract/utils.cc View File

@@ -68,7 +68,7 @@ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, c
continue;
}
if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
if (shape1->min_shape().size() <= i || shape1->max_shape().size() <= i) {
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
@@ -77,7 +77,7 @@ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, c
continue;
}
if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
if (shape2->min_shape().size() <= i || shape2->max_shape().size() <= i) {
MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
@@ -86,11 +86,11 @@ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, c
continue;
}
// both shapes contains dynamic shape
if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
if (shape1->min_shape().size() <= i || shape1->max_shape().size() <= i) {
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
if (shape2->min_shape().size() <= i || shape2->max_shape().size() <= i) {
MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
@@ -109,10 +109,10 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
// lengths of two shapes are not same, join failed
if (shape1->shape().size() != shape2->shape().size()) {
// special case: shape(1), shape() -> shape(1)
if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().empty()) {
return shape1;
}
if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().empty()) {
return shape2;
}
return nullptr;
@@ -138,7 +138,7 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
}

AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() < 1) {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 param, while the input size is " << args_spec_list.size()
<< ".";
}
@@ -205,7 +205,7 @@ bool CheckType(const TypePtr &expected_type, const TypePtr &x) {

TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
MS_EXCEPTION_IF_NULL(predicate);
for (auto arg_type : args_type_list) {
for (const auto &arg_type : args_type_list) {
MS_EXCEPTION_IF_NULL(arg_type);
if (!CheckType(predicate, arg_type)) {
MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
@@ -294,19 +294,6 @@ ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
return shp;
}

ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x,
const AbstractTensorPtr &tensor_y) {
mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape();
mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape();
// if is the same shape ,just return the x_shape
if (*tensor_x_shape == *tensor_y_shape) {
return tensor_x_shape;
}
auto x_shape = tensor_x_shape->shape();
auto y_shape = tensor_y_shape->shape();
return std::make_shared<Shape>(RealBroadcast(op, x_shape, y_shape));
}

size_t TypeIdSize(const TypeId data_type) {
const size_t unsupported_type_error = 0;
auto iter = type_map.find(data_type);


Loading…
Cancel
Save