| @@ -0,0 +1,47 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Queues; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A FIFOQueue that supports batching variable-sized tensors by padding. | |||||
| /// </summary> | |||||
| /// <param name="capacity"></param> | |||||
| /// <param name="dtypes"></param> | |||||
| /// <param name="shapes"></param> | |||||
| /// <param name="names"></param> | |||||
| /// <param name="shared_name"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public PaddingFIFOQueue PaddingFIFOQueue(int capacity, | |||||
| TF_DataType[] dtypes, | |||||
| TensorShape[] shapes, | |||||
| string[] names = null, | |||||
| string shared_name = null, | |||||
| string name = "padding_fifo_queue") | |||||
| => new PaddingFIFOQueue(capacity, | |||||
| dtypes, | |||||
| shapes, | |||||
| names, | |||||
| shared_name: shared_name, | |||||
| name: name); | |||||
| } | |||||
| } | |||||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.OpDef.Types; | using static Tensorflow.OpDef.Types; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Google.Protobuf; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -194,7 +195,9 @@ namespace Tensorflow | |||||
| if (attrs.ContainsKey(key)) | if (attrs.ContainsKey(key)) | ||||
| { | { | ||||
| attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); | attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); | ||||
| } else { | |||||
| } | |||||
| else | |||||
| { | |||||
| if (attr_def.DefaultValue == null) | if (attr_def.DefaultValue == null) | ||||
| { | { | ||||
| throw new TypeError("Missing required positional argument " + key); | throw new TypeError("Missing required positional argument " + key); | ||||
| @@ -311,6 +314,16 @@ namespace Tensorflow | |||||
| input_types.AddRange(base_types); | input_types.AddRange(base_types); | ||||
| } | } | ||||
| public ByteString _MakeStr(string value, AttrDef attr_def) | |||||
| { | |||||
| return ByteString.CopyFromUtf8(value ?? string.Empty); | |||||
| } | |||||
| public TensorShapeProto _MakeShape(TensorShape shape, AttrDef attr_def) | |||||
| { | |||||
| return shape.as_proto(); | |||||
| } | |||||
| public DataType _MakeType(TF_DataType v, AttrDef attr_def) | public DataType _MakeType(TF_DataType v, AttrDef attr_def) | ||||
| { | { | ||||
| return v.as_base_dtype().as_datatype_enum(); | return v.as_base_dtype().as_datatype_enum(); | ||||
| @@ -330,7 +343,7 @@ namespace Tensorflow | |||||
| switch (attr_def.Type) | switch (attr_def.Type) | ||||
| { | { | ||||
| case "string": | case "string": | ||||
| attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
| attr_value.S = _MakeStr((string)value, attr_def); | |||||
| break; | break; | ||||
| case "type": | case "type": | ||||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | attr_value.Type = _MakeType((TF_DataType)value, attr_def); | ||||
| @@ -363,6 +376,9 @@ namespace Tensorflow | |||||
| else if (value is int[] val3) | else if (value is int[] val3) | ||||
| attr_value.Shape = tensor_util.as_shape(val3); | attr_value.Shape = tensor_util.as_shape(val3); | ||||
| break; | |||||
| case "list(shape)": | |||||
| attr_value.List.Shape.AddRange((value as TensorShape[]).Select(x => _MakeShape(x, attr_def))); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | ||||
| @@ -0,0 +1,33 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Queues | |||||
| { | |||||
| /// <summary> | |||||
| /// A FIFOQueue that supports batching variable-sized tensors by padding. | |||||
| /// </summary> | |||||
| public class PaddingFIFOQueue : QueueBase | |||||
| { | |||||
| public PaddingFIFOQueue(int capacity, | |||||
| TF_DataType[] dtypes, | |||||
| TensorShape[] shapes, | |||||
| string[] names = null, | |||||
| string shared_name = null, | |||||
| string name = "padding_fifo_queue") | |||||
| : base(dtypes: dtypes, shapes: shapes, names: names) | |||||
| { | |||||
| _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( | |||||
| component_types: dtypes, | |||||
| shapes: shapes, | |||||
| capacity: capacity, | |||||
| shared_name: shared_name, | |||||
| name: name); | |||||
| _name = _queue_ref.op.name.Split('/').Last(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,56 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Queues | |||||
| { | |||||
| public class QueueBase | |||||
| { | |||||
| protected TF_DataType[] _dtypes; | |||||
| protected TensorShape[] _shapes; | |||||
| protected string[] _names; | |||||
| protected Tensor _queue_ref; | |||||
| protected string _name; | |||||
| public QueueBase(TF_DataType[] dtypes, TensorShape[] shapes, string[] names) | |||||
| { | |||||
| _dtypes = dtypes; | |||||
| _shapes = shapes; | |||||
| _names = names; | |||||
| } | |||||
| public Operation enqueue(Tensor val, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, $"{_name}_enqueue", val), scope => | |||||
| { | |||||
| var vals = new[] { val }; | |||||
| if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) | |||||
| return gen_data_flow_ops.queue_enqueue_v2(_queue_ref, vals, name: scope); | |||||
| else | |||||
| return gen_data_flow_ops.queue_enqueue(_queue_ref, vals, name: scope); | |||||
| }); | |||||
| } | |||||
| public Tensor[] dequeue_many(int n, string name = null) | |||||
| { | |||||
| if (name == null) | |||||
| name = $"{_name}_DequeueMany"; | |||||
| var ret = gen_data_flow_ops.queue_dequeue_many_v2(_queue_ref, n: n, component_types: _dtypes, name: name); | |||||
| //var op = ret[0].op; | |||||
| //var cv = tensor_util.constant_value(op.inputs[1]); | |||||
| //var batch_dim = new Dimension(cv); | |||||
| return _dequeue_return_value(ret); | |||||
| } | |||||
| public Tensor[] _dequeue_return_value(Tensor[] tensors) | |||||
| { | |||||
| if (_names != null) | |||||
| throw new NotImplementedException(""); | |||||
| return tensors; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -22,10 +22,9 @@ namespace Tensorflow | |||||
| public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) | public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) | ||||
| { | { | ||||
| var _attr_N = indices.Length; | |||||
| var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data }); | var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, | public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| @@ -45,5 +44,58 @@ namespace Tensorflow | |||||
| return (null, null); | return (null, null); | ||||
| } | } | ||||
| public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | |||||
| int capacity = -1, string container = "", string shared_name = "", | |||||
| string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("PaddingFIFOQueueV2", name, new | |||||
| { | |||||
| component_types, | |||||
| shapes, | |||||
| capacity, | |||||
| container, | |||||
| shared_name | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | |||||
| { | |||||
| handle, | |||||
| components, | |||||
| timeout_ms | |||||
| }); | |||||
| return _op; | |||||
| } | |||||
| public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("QueueEnqueueV2", name, new | |||||
| { | |||||
| handle, | |||||
| components, | |||||
| timeout_ms | |||||
| }); | |||||
| return _op; | |||||
| } | |||||
| public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("QueueDequeueManyV2", name, new | |||||
| { | |||||
| handle, | |||||
| n, | |||||
| component_types, | |||||
| timeout_ms | |||||
| }); | |||||
| return _op.outputs; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,36 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class QueueTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void PaddingFIFOQueue() | |||||
| { | |||||
| var numbers = tf.placeholder(tf.int32); | |||||
| var queue = tf.PaddingFIFOQueue(capacity: 10, dtypes: new[] { tf.int32 }, shapes: new[] { new TensorShape(-1) }); | |||||
| var enqueue = queue.enqueue(numbers); | |||||
| var dequeue_many = queue.dequeue_many(n: 3); | |||||
| using(var sess = tf.Session()) | |||||
| { | |||||
| sess.run(enqueue, (numbers, new[] { 1 })); | |||||
| sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||||
| sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||||
| var result = sess.run(dequeue_many[0]); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||