diff --git a/docs/source/FrontCover.md b/docs/source/FrontCover.md index e1eacc28..a8567154 100644 --- a/docs/source/FrontCover.md +++ b/docs/source/FrontCover.md @@ -1,5 +1,5 @@ -# The Definitive Guide to Tensorflow.NET -# Tensorflow.NET 权威指南 +# The Definitive Guide to TensorFlow.NET +# TensorFlow.NET 权威指南 diff --git a/docs/source/Preface.md b/docs/source/Preface.md index ad91bf5c..6cb29357 100644 --- a/docs/source/Preface.md +++ b/docs/source/Preface.md @@ -14,9 +14,9 @@ -Why do I start the Tensorflow.NET project? +Why do I start the TensorFlow.NET project? -我为什么会写Tensorflow.NET? +我为什么会写TensorFlow.NET? 再过几天就是2018年圣诞节,看着孩子一天天长大并懂事,感慨时间过得太快。IT技术更新换代比以往任何时候都更快,各种前后端技术纷纷涌现。大数据,人工智能和区块链,容器技术和微服务,分布式计算和无服务器技术,让人眼花缭乱。Amazon AI服务接口宣称不需要具有任何机器学习经验的工程师就能使用,让像我这样刚静下心来学习了两年并打算将来转行做AI架构的想法泼了一桶凉水。 diff --git a/src/TensorFlowNET.Core/APIs/tf.constant.cs b/src/TensorFlowNET.Core/APIs/tf.constant.cs index b43d611d..d60fb50e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.constant.cs +++ b/src/TensorFlowNET.Core/APIs/tf.constant.cs @@ -9,12 +9,7 @@ namespace Tensorflow { public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) { - var t = constant_op.Create(nd, name, verify_shape); - /*var graph = tf.get_default_graph(); - var tensor = new Tensor(nd); - var op = graph.NewOperation("Const", name, tensor);*/ - - return t; + return constant_op.Create(nd, name, verify_shape); } } } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 1ae4e951..530469ea 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -59,7 +59,7 @@ namespace Tensorflow if (!String.IsNullOrEmpty(input_arg.TypeAttr)) { - attrs[input_arg.TypeAttr] = DataType.DtFloat; + attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype; } if (input_arg.IsRef) @@ -92,7 +92,7 @@ namespace Tensorflow switch (attr_def.Type) { case "type": - attr_value.Type = _MakeType(value, attr_def); + attr_value.Type = _MakeType((TF_DataType)value, attr_def); break; case "shape": attr_value.Shape = new TensorShapeProto(); @@ -127,9 +127,9 @@ namespace Tensorflow return op; } - public DataType _MakeType(Object v, AttrDef attr_def) + public DataType _MakeType(TF_DataType v, AttrDef attr_def) { - return DataType.DtFloat; + return v.as_datatype_enum(); } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 1a52b3cf..fbff32e0 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -37,7 +37,6 @@ namespace Tensorflow private static OpDefLibrary _InitOpDefLibrary() { - // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); var bytes = File.ReadAllBytes("Operations/op_list_proto_array.bin"); var op_list = OpList.Parser.ParseFrom(bytes); var op_def_lib = new OpDefLibrary(); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 9d5dbc21..8ebe36bf 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -17,9 +17,7 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); - var tensor = new Tensor(_op, 0, TF_DataType.TF_FLOAT); - - return tensor; + return new Tensor(_op, 0, _op.OutputType(0)); } private static OpDefLibrary _InitOpDefLibrary() diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index 3344d5af..02b7ca11 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -28,14 +28,20 @@ namespace Tensorflow var op_desc = graph.NewOperation(node_def.Op, node_def.Name); // Add inputs - if(inputs != null && inputs.Count > 0) + if(inputs != null) { - /*foreach (var op_input in inputs) + foreach (var op_input in inputs) { - c_api.TF_AddInput(op_desc, op_input._as_tf_output()); - }*/ - - c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); + bool isList = false; + if (!isList) + { + c_api.TF_AddInput(op_desc, op_input._as_tf_output()); + } + else + { + c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); + } + } } var status = new Status(); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 85343ba8..0998bb7d 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -131,6 +131,9 @@ namespace Tensorflow case TF_DataType.TF_FLOAT: result[i] = *(float*)c_api.TF_TensorData(output_values[i]); break; + case TF_DataType.TF_INT32: + result[i] = *(int*)c_api.TF_TensorData(output_values[i]); + break; default: throw new NotImplementedException("can't get output"); break; diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index dfd557ca..379eea92 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -18,6 +18,7 @@ Docs: https://tensorflownet.readthedocs.io 0.0.2.0 + 7.2 diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs new file mode 100644 index 00000000..28678d73 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Tensor + { + public static Tensor operator +(Tensor t1, Tensor t2) + { + return gen_math_ops.add(t1, t2); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index a6765a8f..0c955505 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -11,7 +11,7 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public class Tensor : IDisposable + public partial class Tensor : IDisposable { private readonly IntPtr _handle; @@ -22,7 +22,8 @@ namespace Tensorflow public object value; public int value_index { get; } - public TF_DataType dtype => _handle == IntPtr.Zero ? TF_DataType.DtInvalid : c_api.TF_TensorType(_handle); + private TF_DataType _dtype = TF_DataType.DtInvalid; + public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); public ulong dataTypeSize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize; @@ -118,6 +119,7 @@ namespace Tensorflow { this.op = op; this.value_index = value_index; + this._dtype = dtype; } public TF_Output _as_tf_output() diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs new file mode 100644 index 00000000..556017f9 --- /dev/null +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Examples +{ + /// + /// Basic Operations example using TensorFlow library. + /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/basic_operations.py + /// + public class BasicOperations : IExample + { + public void Run() + { + // Basic constant operations + // The value returned by the constructor represents the output + // of the Constant op. + var a = tf.constant(2); + var b = tf.constant(3); + + // Launch the default graph. + using (var sess = tf.Session()) + { + Console.WriteLine("a=2, b=3"); + Console.WriteLine($"Addition with constants: {sess.run(a + b)}"); + //Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}"); + } + } + } +} diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 9bbcd376..41fd4e54 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -11,6 +11,9 @@ namespace TensorFlowNET.Examples var assembly = Assembly.GetEntryAssembly(); foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) { + if (args.Length > 0 && !args.Contains(type.Name)) + continue; + var example = (IExample)Activator.CreateInstance(type); try @@ -22,6 +25,8 @@ namespace TensorFlowNET.Examples Console.WriteLine(ex); } } + + Console.ReadLine(); } } }