Browse Source

fix dtypes.as_numpy_datatype for bool

tags/v0.9
Oceania2018 6 years ago
parent
commit
1edf86a07f
3 changed files with 19 additions and 8 deletions
  1. +6
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  3. +11
    -4
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 6
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -16,14 +16,16 @@ namespace Tensorflow
name = scope; name = scope;
var xs = ops.convert_n_to_tensor(data); var xs = ops.convert_n_to_tensor(data);
condition = ops.convert_to_tensor(condition, name: "Condition"); condition = ops.convert_to_tensor(condition, name: "Condition");
Func<Operation[]> true_assert = () => new Operation[]
Func<Operation[]> true_assert = () =>
{ {
gen_logging_ops._assert(condition, data, summarize, name: "Assert")
var assert = gen_logging_ops._assert(condition, data, summarize, name: "Assert");
return new Operation[] { assert };
}; };


Func<Operation[]> false_assert = () => new Operation[]
Func<Operation[]> false_assert = () =>
{ {
gen_control_flow_ops.no_op()
var op = gen_control_flow_ops.no_op();
return new Operation[] { op };
}; };


var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard"); var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard");


+ 2
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -10,6 +10,8 @@ namespace Tensorflow
{ {
switch (type) switch (type)
{ {
case TF_DataType.TF_BOOL:
return typeof(bool);
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
return typeof(int); return typeof(int);
case TF_DataType.TF_INT16: case TF_DataType.TF_INT16:


+ 11
- 4
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -47,16 +47,23 @@ namespace Tensorflow
var tensor_dtype = tensor.Dtype.as_numpy_dtype(); var tensor_dtype = tensor.Dtype.as_numpy_dtype();


if (tensor.TensorContent.Length > 0) if (tensor.TensorContent.Length > 0)
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype)
.reshape(shape);
{
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape);
}
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
; ;
else if (tensor.Dtype == DataType.DtFloat) else if (tensor.Dtype == DataType.DtFloat)
; ;
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
{
if (tensor.IntVal.Count == 1) if (tensor.IntVal.Count == 1)
return np.repeat(np.array(tensor.IntVal[0]), Convert.ToInt32(num_elements))
.reshape(shape);
return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape);
}
else if (tensor.Dtype == DataType.DtBool)
{
if (tensor.BoolVal.Count == 1)
return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape);
}


throw new NotImplementedException("MakeNdarray"); throw new NotImplementedException("MakeNdarray");
} }


Loading…
Cancel
Save