Browse Source

fix make_tensor_proto when dtype is different.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
56a25a15c5
3 changed files with 20 additions and 4 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  2. +15
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  3. +4
    -2
      src/TensorFlowNET.Core/ops.cs

+ 1
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -245,7 +245,7 @@ namespace Tensorflow
-1 => "<unknown>", -1 => "<unknown>",
0 => "()", 0 => "()",
1 => $"({dims[0]},)", 1 => $"({dims[0]},)",
_ => $"{string.Join(", ", _dims).Replace("-1", "None")}"
_ => $"({string.Join(", ", _dims).Replace("-1", "None")})"
}; };
} }
} }

+ 15
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -110,7 +110,21 @@ namespace Tensorflow
if (values is TensorProto tp) if (values is TensorProto tp)
return tp; return tp;


dtype = values.GetDataType();
var origin_dtype = values.GetDataType();
if (dtype == TF_DataType.DtInvalid)
dtype = origin_dtype;
else if(origin_dtype != dtype)
{
var new_system_dtype = dtype.as_system_dtype();
if (values is long[] long_values)
{
if (dtype == TF_DataType.TF_INT32)
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
}
else
values = Convert.ChangeType(values, new_system_dtype);
}

shape = shape ?? values.GetShape(); shape = shape ?? values.GetShape();
var tensor_proto = new TensorProto var tensor_proto = new TensorProto
{ {


+ 4
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -123,6 +123,9 @@ namespace Tensorflow
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = preferred_dtype; dtype = preferred_dtype;


if (dtype == TF_DataType.DtInvalid)
dtype = value.GetDataType();

if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode)
{ {
if (tf.executing_eagerly()) if (tf.executing_eagerly())
@@ -173,8 +176,7 @@ namespace Tensorflow
if (dtype == TF_DataType.TF_STRING) if (dtype == TF_DataType.TF_STRING)
return ret; return ret;


var original_dtype = value.GetDataType();
if (dtype != TF_DataType.DtInvalid && dtype != original_dtype)
if (dtype != ret.dtype)
ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name); ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name);


return ret; return ret;


Loading…
Cancel
Save