diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 79416930..8ed7c17e 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -16,14 +16,16 @@ namespace Tensorflow name = scope; var xs = ops.convert_n_to_tensor(data); condition = ops.convert_to_tensor(condition, name: "Condition"); - Func true_assert = () => new Operation[] + Func true_assert = () => { - gen_logging_ops._assert(condition, data, summarize, name: "Assert") + var assert = gen_logging_ops._assert(condition, data, summarize, name: "Assert"); + return new Operation[] { assert }; }; - Func false_assert = () => new Operation[] + Func false_assert = () => { - gen_control_flow_ops.no_op() + var op = gen_control_flow_ops.no_op(); + return new Operation[] { op }; }; var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard"); diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index f14affbc..34d6a8f7 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -10,6 +10,8 @@ namespace Tensorflow { switch (type) { + case TF_DataType.TF_BOOL: + return typeof(bool); case TF_DataType.TF_INT32: return typeof(int); case TF_DataType.TF_INT16: diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 870db57b..a2830fc1 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -47,16 +47,23 @@ namespace Tensorflow var tensor_dtype = tensor.Dtype.as_numpy_dtype(); if (tensor.TensorContent.Length > 0) - return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype) - .reshape(shape); + { + return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape); + } else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) ; else if (tensor.Dtype == DataType.DtFloat) ; else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) + { if (tensor.IntVal.Count == 1) - return np.repeat(np.array(tensor.IntVal[0]), Convert.ToInt32(num_elements)) - .reshape(shape); + return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape); + } + else if (tensor.Dtype == DataType.DtBool) + { + if (tensor.BoolVal.Count == 1) + return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape); + } throw new NotImplementedException("MakeNdarray"); }