|
|
|
@@ -110,7 +110,21 @@ namespace Tensorflow |
|
|
|
if (values is TensorProto tp) |
|
|
|
return tp; |
|
|
|
|
|
|
|
dtype = values.GetDataType(); |
|
|
|
var origin_dtype = values.GetDataType(); |
|
|
|
if (dtype == TF_DataType.DtInvalid) |
|
|
|
dtype = origin_dtype; |
|
|
|
else if(origin_dtype != dtype) |
|
|
|
{ |
|
|
|
var new_system_dtype = dtype.as_system_dtype(); |
|
|
|
if (values is long[] long_values) |
|
|
|
{ |
|
|
|
if (dtype == TF_DataType.TF_INT32) |
|
|
|
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); |
|
|
|
} |
|
|
|
else |
|
|
|
values = Convert.ChangeType(values, new_system_dtype); |
|
|
|
} |
|
|
|
|
|
|
|
shape = shape ?? values.GetShape(); |
|
|
|
var tensor_proto = new TensorProto |
|
|
|
{ |
|
|
|
|