| @@ -245,7 +245,7 @@ namespace Tensorflow | |||||
| -1 => "<unknown>", | -1 => "<unknown>", | ||||
| 0 => "()", | 0 => "()", | ||||
| 1 => $"({dims[0]},)", | 1 => $"({dims[0]},)", | ||||
| _ => $"{string.Join(", ", _dims).Replace("-1", "None")}" | |||||
| _ => $"({string.Join(", ", _dims).Replace("-1", "None")})" | |||||
| }; | }; | ||||
| } | } | ||||
| } | } | ||||
| @@ -110,7 +110,21 @@ namespace Tensorflow | |||||
| if (values is TensorProto tp) | if (values is TensorProto tp) | ||||
| return tp; | return tp; | ||||
| dtype = values.GetDataType(); | |||||
| var origin_dtype = values.GetDataType(); | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = origin_dtype; | |||||
| else if(origin_dtype != dtype) | |||||
| { | |||||
| var new_system_dtype = dtype.as_system_dtype(); | |||||
| if (values is long[] long_values) | |||||
| { | |||||
| if (dtype == TF_DataType.TF_INT32) | |||||
| values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); | |||||
| } | |||||
| else | |||||
| values = Convert.ChangeType(values, new_system_dtype); | |||||
| } | |||||
| shape = shape ?? values.GetShape(); | shape = shape ?? values.GetShape(); | ||||
| var tensor_proto = new TensorProto | var tensor_proto = new TensorProto | ||||
| { | { | ||||
| @@ -123,6 +123,9 @@ namespace Tensorflow | |||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = preferred_dtype; | dtype = preferred_dtype; | ||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = value.GetDataType(); | |||||
| if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) | if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) | ||||
| { | { | ||||
| if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
| @@ -173,8 +176,7 @@ namespace Tensorflow | |||||
| if (dtype == TF_DataType.TF_STRING) | if (dtype == TF_DataType.TF_STRING) | ||||
| return ret; | return ret; | ||||
| var original_dtype = value.GetDataType(); | |||||
| if (dtype != TF_DataType.DtInvalid && dtype != original_dtype) | |||||
| if (dtype != ret.dtype) | |||||
| ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name); | ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name); | ||||
| return ret; | return ret; | ||||