Browse Source

change key from uint64 to int64

feature/build-system-rewrite
chenfei 4 years ago
parent
commit
b13dc1ede3
3 changed files with 9 additions and 6 deletions
  1. +4
    -4
      mindspore/ccsrc/frontend/optimizer/environ_conversion.cc
  2. +3
    -0
      mindspore/ccsrc/utils/func_graph_analyzer.cc
  3. +2
    -2
      mindspore/core/ir/tensor.cc

+ 4
- 4
mindspore/ccsrc/frontend/optimizer/environ_conversion.cc View File

@@ -115,13 +115,13 @@ TypeId GetValueType(const CNodePtr &cnode) {

AnfNodePtr GetTransformedKeyNode(const AnfNodePtr &old_key_node, SymbolicKeyConversionMap *para_symbol_map) {
const auto &symbolic_key_inst = GetValueNode<SymbolicKeyInstancePtr>(old_key_node);
uint64_t transformed_key = 0;
int64_t transformed_key = 0;
auto &symbolic_key_map = *para_symbol_map;
auto iter = symbolic_key_map.find(symbolic_key_inst);
if (iter != symbolic_key_map.end()) {
transformed_key = iter->second;
} else {
static uint64_t key_counter = 0;
static int64_t key_counter = 0;
transformed_key = ++key_counter;
symbolic_key_map.emplace(std::make_pair(symbolic_key_inst, transformed_key));
}
@@ -163,12 +163,12 @@ bool EnvironConversion(const pipeline::ResourcePtr &resource) {
prim = std::make_shared<Primitive>(prim::kEnvironGet);
}
MS_EXCEPTION_IF_NULL(prim);
prim->set_attr(attr_name, MakeValue(static_cast<int>(type_id)));
prim->set_attr(attr_name, MakeValue(static_cast<int64_t>(type_id)));
transformed_prim_node = NewValueNode(prim);
txn.SetEdge(node, kPrimitiveOffset, transformed_prim_node);
}
} else {
prim->set_attr(attr_name, MakeValue(static_cast<int>(type_id)));
prim->set_attr(attr_name, MakeValue(static_cast<int64_t>(type_id)));
}
// Abstract of Environ & Value will be set by previous TransformNodeAbstract function.
// Key


+ 3
- 0
mindspore/ccsrc/utils/func_graph_analyzer.cc View File

@@ -492,6 +492,9 @@ ValueGetterPtr DirectValueGetter::Visit(int64_t index, const std::shared_ptr<Has
return shared_from_this();
}
std::vector<FuncClosurePtr> DirectValueGetter::GetFuncGraphs() {
if (!IsValueNode<FuncGraph>(anf_node_)) {
MS_LOG(EXCEPTION) << "Expect a func graph value node, but got an illegal value node:" << anf_node_->DebugString();
}
if (func_graphs_.empty()) {
func_graphs_.emplace_back(std::make_shared<FuncClosure>(GetValueNode<FuncGraphPtr>(anf_node_),
std::vector<size_t>(), std::vector<CNodePtr>()));


+ 2
- 2
mindspore/core/ir/tensor.cc View File

@@ -536,7 +536,7 @@ Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId sr
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}

Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt64), {static_cast<int>(input.size())}),
data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
id_(MakeId()) {}

@@ -546,7 +546,7 @@ Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
id_(MakeId()) {}

Tensor::Tensor(int64_t input, const TypePtr &data_type)
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt64), {}),
data_(MakeTensorData(data_type_, {}, input)),
id_(MakeId()) {}



Loading…
Cancel
Save