From 0ca9485deaf5fa28d25edd9aa0e57999ce376569 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 26 Jan 2019 07:27:07 -0600 Subject: [PATCH] as_base_type #136 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 1 - src/TensorFlowNET.Core/Operations/OpDefLibrary.cs | 10 +++------- src/TensorFlowNET.Core/Tensors/dtypes.cs | 8 ++++++++ src/TensorFlowNET.Core/Variables/RefVariable.cs | 4 ++-- test/TensorFlowNET.UnitTest/VariableTest.cs | 2 +- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 0bac4c8f..963a1a40 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; using System.Text; -using TF_DataType = Tensorflow.DataType; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index e0e68929..f4b5b08f 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -72,12 +72,8 @@ namespace Tensorflow } else { - var base_type = value.dtype; - // base type - if ((int)value.dtype > 100) - { - base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString()); - } + var base_type = value.dtype.as_base_dtype(); + input_types.Add(base_type); } } @@ -151,7 +147,7 @@ namespace Tensorflow public DataType _MakeType(TF_DataType v, AttrDef attr_def) { - return v.as_datatype_enum(); + return v.as_base_dtype().as_datatype_enum(); } } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index ace8b66c..fc625c44 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -24,6 +24,7 @@ namespace Tensorflow throw new NotImplementedException("as_numpy_datatype failed"); } } + public static TF_DataType as_dtype(Type type) { TF_DataType dtype = TF_DataType.DtInvalid; @@ -62,5 +63,12 @@ namespace Tensorflow return dtype; } + + public static TF_DataType as_base_dtype(this TF_DataType type) + { + return (int)type > 100 ? + (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) : + type; + } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 8a9a1377..e0e66869 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -16,7 +16,7 @@ namespace Tensorflow private Operation _initializer_op; public Operation initializer => _initializer_op; - public Operation op => _initializer_op; + public Operation op => _variable.op; public string name => _variable.name; @@ -77,7 +77,7 @@ namespace Tensorflow var shape = _initial_value.shape; dtype = _initial_value.dtype; - _variable = gen_state_ops.variable_v2(shape, dtype, name); + _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name); } // Manually overrides the variable's shape with the initial value's. diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 31ff8b40..9083045e 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Add() { - var x = tf.Variable(0, name: "x"); + var x = tf.Variable(10, name: "x"); var model = tf.global_variables_initializer();