| @@ -1,17 +1,17 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using NumSharp; | using NumSharp; | ||||
| @@ -19,6 +19,7 @@ using System; | |||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Numerics; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -31,18 +32,18 @@ namespace Tensorflow | |||||
| protected bool _closed; | protected bool _closed; | ||||
| protected int _current_version; | protected int _current_version; | ||||
| protected byte[] _target; | protected byte[] _target; | ||||
| protected IntPtr _session; | |||||
| public Status Status; | |||||
| protected IntPtr _session; | |||||
| public Status Status; | |||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | ||||
| { | |||||
| { | |||||
| _graph = g is null ? ops.get_default_graph() : g; | _graph = g is null ? ops.get_default_graph() : g; | ||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | _target = UTF8Encoding.UTF8.GetBytes(target); | ||||
| SessionOptions newOpts = null; | SessionOptions newOpts = null; | ||||
| if (opts == null) | |||||
| if (opts == null) | |||||
| newOpts = c_api.TF_NewSessionOptions(); | newOpts = c_api.TF_NewSessionOptions(); | ||||
| Status = new Status(); | Status = new Status(); | ||||
| @@ -50,7 +51,7 @@ namespace Tensorflow | |||||
| _session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); | _session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); | ||||
| // dispose newOpts | // dispose newOpts | ||||
| if (opts == null) | |||||
| if (opts == null) | |||||
| c_api.TF_DeleteSessionOptions(newOpts); | c_api.TF_DeleteSessionOptions(newOpts); | ||||
| Status.Check(true); | Status.Check(true); | ||||
| @@ -63,7 +64,7 @@ namespace Tensorflow | |||||
| public virtual NDArray run(object fetches, Hashtable feed_dict = null) | public virtual NDArray run(object fetches, Hashtable feed_dict = null) | ||||
| { | { | ||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | ||||
| return _run(fetches, feed_items); | return _run(fetches, feed_items); | ||||
| } | } | ||||
| @@ -86,57 +87,8 @@ namespace Tensorflow | |||||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | ||||
| { | { | ||||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | ||||
| var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | |||||
| switch (subfeed_val) | |||||
| { | |||||
| case IntPtr val: | |||||
| feed_dict_tensor[subfeed_t] = val; | |||||
| break; | |||||
| case NDArray val: | |||||
| feed_dict_tensor[subfeed_t] = val; | |||||
| break; | |||||
| case float val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case double val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case short val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case int val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case long val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case long[] val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case int[] val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case string val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case byte[] val: | |||||
| feed_dict_tensor[subfeed_t] = np.array(val); | |||||
| break; | |||||
| case char[] val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case bool val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| case bool[] val: | |||||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||||
| break; | |||||
| default: | |||||
| Console.WriteLine($"can't handle data type of subfeed_val"); | |||||
| throw new NotImplementedException("_run subfeed"); | |||||
| } | |||||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | ||||
| } | } | ||||
| } | } | ||||
| @@ -175,26 +127,78 @@ namespace Tensorflow | |||||
| /// </returns> | /// </returns> | ||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | ||||
| { | { | ||||
| var feeds = feed_dict.Select(x => | |||||
| var feeds = feed_dict.Select(x => | |||||
| { | { | ||||
| if (x.Key is Tensor tensor) | if (x.Key is Tensor tensor) | ||||
| { | { | ||||
| switch (x.Value) | switch (x.Value) | ||||
| { | { | ||||
| case IntPtr pointer: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer); | |||||
| case Tensor t1: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1); | |||||
| case NDArray nd: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd, tensor.dtype)); | |||||
| case int intVal: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal)); | |||||
| case float floatVal: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal)); | |||||
| case double doubleVal: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal)); | |||||
| #if _REGEN | |||||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||||
| %foreach types% | |||||
| case #1 v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case #1[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| % | |||||
| #else | |||||
| case sbyte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case sbyte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case byte[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case short[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ushort[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case int[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case uint[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case long[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case ulong[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case float[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case double[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Complex[] v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| #endif | |||||
| case bool v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||||
| case string v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case IntPtr v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| case Tensor v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| case NDArray v: | |||||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| default: | default: | ||||
| throw new NotImplementedException("feed_dict data type"); | |||||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||||
| } | } | ||||
| } | } | ||||
| throw new NotImplementedException("_do_run.feed_dict"); | throw new NotImplementedException("_do_run.feed_dict"); | ||||