From 1aff986cf735fa9c81344f2bf982c76492718a5d Mon Sep 17 00:00:00 2001 From: Brendan Mulcahy Date: Sun, 24 Nov 2019 11:49:35 -0500 Subject: [PATCH] Nearly working tf.scan (does not build) --- src/TensorFlowNET.Core/APIs/tf.scan.cs | 35 +++ .../Operations/functional_ops.cs | 199 ++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 src/TensorFlowNET.Core/APIs/tf.scan.cs create mode 100644 src/TensorFlowNET.Core/Operations/functional_ops.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.scan.cs b/src/TensorFlowNET.Core/APIs/tf.scan.cs new file mode 100644 index 00000000..439b0512 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.scan.cs @@ -0,0 +1,35 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow
{
    public partial class tensorflow
    {
        /// <summary>
        /// Applies <paramref name="fn"/> repeatedly to the tensors unpacked from
        /// <paramref name="elems"/> along dimension 0, threading an accumulator
        /// through each step. Thin forwarding wrapper over
        /// <see cref="functional_ops.scan"/>.
        /// </summary>
        /// <param name="fn">Step function invoked as fn(accumulator, element);
        /// its return value becomes the accumulator for the next step.</param>
        /// <param name="elems">Tensor unpacked along its first dimension.</param>
        /// <param name="initializer">Optional initial accumulator; when null the
        /// first (or last, if <paramref name="reverse"/>) element seeds the scan.</param>
        /// <param name="parallel_iterations">Iterations allowed to run in parallel
        /// inside the while loop.</param>
        /// <param name="back_prop">Enable gradient flow through the loop.</param>
        /// <param name="swap_memory">Allow GPU-CPU memory swapping in the loop.</param>
        /// <param name="infer_shape">Require all accumulator steps to share a shape.</param>
        /// <param name="reverse">Scan from the last element toward the first.</param>
        /// <param name="name">Optional name scope for the created ops.</param>
        /// <returns>Tensor of all accumulator values stacked along dimension 0.</returns>
        public Tensor scan(
            Func<Tensor, Tensor, Tensor> fn,
            Tensor elems,
            IInitializer initializer = null,
            int parallel_iterations = 10,
            bool back_prop = true,
            bool swap_memory = false,
            bool infer_shape = true,
            bool reverse = false,
            string name = null)
            => functional_ops.scan(fn, elems, initializer, parallel_iterations,
                back_prop, swap_memory, infer_shape, reverse, name);
    }
}
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow
{
    public class functional_ops
    {
        /// <summary>
        /// Port of Python <c>tf.scan</c>: repeatedly applies <paramref name="fn"/>
        /// to (accumulator, element) pairs, where elements come from unpacking
        /// <paramref name="elems"/> along dimension 0, and returns all intermediate
        /// accumulator values stacked along dimension 0.
        /// </summary>
        /// <param name="fn">Step function fn(accumulator, element) -> next accumulator.</param>
        /// <param name="elems">Tensor unpacked along its first dimension.</param>
        /// <param name="initializer">Optional initial accumulator. When null, the
        /// first (or last, if <paramref name="reverse"/>) element seeds the scan and
        /// the loop starts at index 1.</param>
        /// <param name="parallel_iterations">Parallelism hint for the while loop.</param>
        /// <param name="back_prop">Enable gradients through the loop.</param>
        /// <param name="swap_memory">Allow GPU-CPU memory swapping in the loop.</param>
        /// <param name="infer_shape">Require consistent accumulator shapes.</param>
        /// <param name="reverse">Scan from the last element toward the first.</param>
        /// <param name="name">Optional name scope.</param>
        /// <returns>Stacked accumulator values (packed like the initializer/elems).</returns>
        public static Tensor scan(
            Func<Tensor, Tensor, Tensor> fn,
            Tensor elems,
            IInitializer initializer = null,
            int parallel_iterations = 10,
            bool back_prop = true,
            bool swap_memory = false,
            bool infer_shape = true,
            bool reverse = false,
            string name = null)
        {
            bool input_is_sequence = nest.is_sequence(elems);

            // Flatten/pack helpers mirroring the Python implementation: a bare
            // tensor is wrapped in a single-element list so the loop body can
            // treat every input uniformly.
            List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> { x };
            object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x;

            bool output_is_sequence;
            Func<Tensor, List<Tensor>> output_flatten;
            if (initializer == null)
            {
                // No initializer: output structure follows the input structure.
                output_is_sequence = input_is_sequence;
                output_flatten = input_flatten;
            }
            else
            {
                output_is_sequence = nest.is_sequence(initializer);
                output_flatten = x => output_is_sequence ? nest.flatten(x) : new List<Tensor> { x };
            }

            object output_pack(List<Tensor> x)
                => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0];

            var elems_flat = input_flatten(elems);

            bool in_graph_mode = true; // TODO: replace with !context.executing_eagerly() once eager mode exists.

            return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope =>
            {
                // NOTE(review): the Python original also pins varscope caching_device
                // here in graph mode; deferred until variable_scope support lands.

                elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();

                // Number of steps; the Python original falls back to a dynamic
                // shape op when this is not statically known — TODO port that path.
                var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);

                // TensorArrays are always flat.
                var elems_ta = elems_flat.Select(elem => new TensorArray(
                    elem.dtype,
                    size: tf.constant(n),
                    dynamic_size: false,
                    element_shape: elem.shape[0], // NOTE(review): Python uses elem.shape[1:]; confirm the shape-slicing API
                    infer_shape: true)).ToList();

                // unstack threads state through its return value in the Python
                // original; keep the returned TensorArray instead of discarding it.
                for (int index = 0; index < elems_ta.Count; index++)
                {
                    elems_ta[index] = elems_ta[index].unstack(elems_flat[index]);
                }

                List<Tensor> a_flat;
                int i;
                if (initializer == null)
                {
                    // Seed the accumulator with the first (or last, if reversed)
                    // element and start the loop one step in.
                    a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
                    i = 1;
                }
                else
                {
                    // NOTE(review): IInitializer is not a Tensor, so `as Tensor`
                    // yields null unless a Tensor was passed — confirm intended type.
                    var initializer_flat = output_flatten(initializer as Tensor);
                    a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
                    i = 0;
                }

                var accs_ta = a_flat.Select(init => new TensorArray(
                    dtype: init.dtype,
                    size: tf.constant(n),
                    element_shape: infer_shape ? init.shape : null,
                    dynamic_size: false,
                    infer_shape: infer_shape)).ToList();

                if (initializer == null)
                {
                    // Record the seed value at slot 0 (or n-1 when reversed),
                    // keeping the TensorArray returned by write.
                    for (int index = 0; index < accs_ta.Count; index++)
                    {
                        accs_ta[index] = accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]);
                    }
                }

                // One scan step: read element _i, apply fn, record the result.
                (int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas)
                {
                    var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.convert_to_tensor(_i))).ToList());
                    var packed_a = output_pack(a_flat_);
                    // NOTE(review): casts assume non-sequence inputs; pack helpers
                    // return lists for sequences — revisit when nest support widens.
                    var a_out = fn((Tensor)packed_a, (Tensor)packed_elems);

                    var flat_a_out = output_flatten(a_out);
                    for (int j = 0; j < tas.Count; j++)
                    {
                        // Write at the loop counter _i (the original wrote at the
                        // captured start offset `i`, clobbering one slot every
                        // iteration) and keep the returned TensorArray.
                        tas[j] = tas[j].write(tf.convert_to_tensor(_i), flat_a_out[j]);
                    }

                    // `_i--` / `_i++` yield the PRE-step value, so the original
                    // never advanced; compute the successor explicitly.
                    var next_i = reverse ? _i - 1 : _i + 1;
                    return (next_i, flat_a_out, tas);
                }

                int initial_i;
                Func<int, Tensor> condition;
                if (reverse)
                {
                    initial_i = n - 1 - i;
                    // Python: lambda i, _1, _2: i >= 0
                    condition = x => tf.convert_to_tensor(x >= 0);
                }
                else
                {
                    initial_i = i;
                    // Python: lambda i, _1, _2: i < n
                    condition = x => tf.convert_to_tensor(x < n);
                }

                // NOTE(review): Python while_loop returns the full loop-var tuple
                // (i, a_flat, tas); confirm this overload projects the TensorArrays.
                List<TensorArray> r_a =
                    control_flow_ops.while_loop(
                        condition,
                        compute,
                        (initial_i, a_flat, accs_ta),
                        parallel_iterations: parallel_iterations,
                        back_prop: back_prop,
                        swap_memory: swap_memory,
                        maximum_iterations: n);

                var results_flat = r_a.Select(r => r.stack()).ToList();

                var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0]));

                // Python iterates elems_flat[1:]; the first element already
                // seeded n_static above, so skip it here.
                foreach (var elem in elems_flat.Skip(1))
                {
                    n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0])));
                }

                foreach (Tensor r in results_flat)
                {
                    r.set_shape(new TensorShape(n_static).concatenate(r.shape[0])); // NOTE(review): Python concatenates r.shape[1:]; confirm slice API
                }

                // TODO: restore varscope caching_device when the graph-mode
                // variable-scope handling above is ported.

                // output_pack returns object; the lambda must yield a Tensor.
                return (Tensor)output_pack(results_flat);
            });
        }
    }
}