|
|
|
@@ -100,8 +100,8 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
param_value->set_tensor_addr(tensor_data); |
|
|
|
tensor_size = shape_size * sizeof(float); |
|
|
|
param_value->SetTensorData(tensor_data, tensor_size); |
|
|
|
} else if (type == kNumberTypeInt32) { |
|
|
|
auto tensor_data = new (std::nothrow) int[shape_size]; |
|
|
|
if (tensor_proto.int_val_size() == 1) { |
|
|
|
@@ -118,8 +118,8 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
param_value->set_tensor_addr(tensor_data); |
|
|
|
tensor_size = shape_size * sizeof(int); |
|
|
|
param_value->SetTensorData(tensor_data, tensor_size); |
|
|
|
} else if (type == kNumberTypeBool) { |
|
|
|
auto tensor_data = new (std::nothrow) int[shape_size]; |
|
|
|
if (tensor_proto.bool_val_size() == 1) { |
|
|
|
@@ -128,8 +128,8 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value |
|
|
|
tensor_data[i] = value; |
|
|
|
} |
|
|
|
} |
|
|
|
param_value->set_tensor_addr(tensor_data); |
|
|
|
tensor_size = shape_size * sizeof(int); |
|
|
|
param_value->SetTensorData(tensor_data, tensor_size); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupport dataType: " << type; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -138,7 +138,6 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value |
|
|
|
std::vector<int> param_shape(shape_vector->begin(), shape_vector->end()); |
|
|
|
param_value->set_tensor_shape(param_shape); |
|
|
|
param_value->set_tensor_type(type); |
|
|
|
param_value->set_tensor_size(tensor_size); |
|
|
|
param_value->set_format(schema::Format::Format_NHWC); |
|
|
|
parameter->set_default_param(param_value); |
|
|
|
parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); |
|
|
|
|