| @@ -3,7 +3,6 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using TF_DataType = Tensorflow.DataType; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -72,12 +72,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var base_type = value.dtype; | |||||
| // base type | |||||
| if ((int)value.dtype > 100) | |||||
| { | |||||
| base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString()); | |||||
| } | |||||
| var base_type = value.dtype.as_base_dtype(); | |||||
| input_types.Add(base_type); | input_types.Add(base_type); | ||||
| } | } | ||||
| } | } | ||||
| @@ -151,7 +147,7 @@ namespace Tensorflow | |||||
| public DataType _MakeType(TF_DataType v, AttrDef attr_def) | public DataType _MakeType(TF_DataType v, AttrDef attr_def) | ||||
| { | { | ||||
| return v.as_datatype_enum(); | |||||
| return v.as_base_dtype().as_datatype_enum(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,6 +24,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("as_numpy_datatype failed"); | throw new NotImplementedException("as_numpy_datatype failed"); | ||||
| } | } | ||||
| } | } | ||||
| public static TF_DataType as_dtype(Type type) | public static TF_DataType as_dtype(Type type) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | TF_DataType dtype = TF_DataType.DtInvalid; | ||||
| @@ -62,5 +63,12 @@ namespace Tensorflow | |||||
| return dtype; | return dtype; | ||||
| } | } | ||||
| public static TF_DataType as_base_dtype(this TF_DataType type) | |||||
| { | |||||
| return (int)type > 100 ? | |||||
| (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) : | |||||
| type; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -16,7 +16,7 @@ namespace Tensorflow | |||||
| private Operation _initializer_op; | private Operation _initializer_op; | ||||
| public Operation initializer => _initializer_op; | public Operation initializer => _initializer_op; | ||||
| public Operation op => _initializer_op; | |||||
| public Operation op => _variable.op; | |||||
| public string name => _variable.name; | public string name => _variable.name; | ||||
| @@ -77,7 +77,7 @@ namespace Tensorflow | |||||
| var shape = _initial_value.shape; | var shape = _initial_value.shape; | ||||
| dtype = _initial_value.dtype; | dtype = _initial_value.dtype; | ||||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||||
| _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name); | |||||
| } | } | ||||
| // Manually overrides the variable's shape with the initial value's. | // Manually overrides the variable's shape with the initial value's. | ||||
| @@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Add() | public void Add() | ||||
| { | { | ||||
| var x = tf.Variable(0, name: "x"); | |||||
| var x = tf.Variable(10, name: "x"); | |||||
| var model = tf.global_variables_initializer(); | var model = tf.global_variables_initializer(); | ||||