| @@ -1,17 +1,17 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| @@ -19,6 +19,7 @@ using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Numerics; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| @@ -31,18 +32,18 @@ namespace Tensorflow | |||
| protected bool _closed; | |||
| protected int _current_version; | |||
| protected byte[] _target; | |||
| protected IntPtr _session; | |||
| public Status Status; | |||
| protected IntPtr _session; | |||
| public Status Status; | |||
| public Graph graph => _graph; | |||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
| { | |||
| { | |||
| _graph = g is null ? ops.get_default_graph() : g; | |||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||
| SessionOptions newOpts = null; | |||
| if (opts == null) | |||
| if (opts == null) | |||
| newOpts = c_api.TF_NewSessionOptions(); | |||
| Status = new Status(); | |||
| @@ -50,7 +51,7 @@ namespace Tensorflow | |||
| _session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); | |||
| // dispose newOpts | |||
| if (opts == null) | |||
| if (opts == null) | |||
| c_api.TF_DeleteSessionOptions(newOpts); | |||
| Status.Check(true); | |||
| @@ -63,7 +64,7 @@ namespace Tensorflow | |||
| public virtual NDArray run(object fetches, Hashtable feed_dict = null) | |||
| { | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
| return _run(fetches, feed_items); | |||
| } | |||
| @@ -86,57 +87,8 @@ namespace Tensorflow | |||
| foreach (var (subfeed, subfeed_val) in feed_fn(feed)) | |||
| { | |||
| var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); | |||
| var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); | |||
| switch (subfeed_val) | |||
| { | |||
| case IntPtr val: | |||
| feed_dict_tensor[subfeed_t] = val; | |||
| break; | |||
| case NDArray val: | |||
| feed_dict_tensor[subfeed_t] = val; | |||
| break; | |||
| case float val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case double val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case short val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case int val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case long val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case long[] val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case int[] val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case string val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case byte[] val: | |||
| feed_dict_tensor[subfeed_t] = np.array(val); | |||
| break; | |||
| case char[] val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case bool val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| case bool[] val: | |||
| feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
| break; | |||
| default: | |||
| Console.WriteLine($"can't handle data type of subfeed_val"); | |||
| throw new NotImplementedException("_run subfeed"); | |||
| } | |||
| //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used | |||
| feed_dict_tensor[subfeed_t] = subfeed_val; | |||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||
| } | |||
| } | |||
| @@ -175,26 +127,78 @@ namespace Tensorflow | |||
| /// </returns> | |||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
| { | |||
| var feeds = feed_dict.Select(x => | |||
| var feeds = feed_dict.Select(x => | |||
| { | |||
| if (x.Key is Tensor tensor) | |||
| { | |||
| switch (x.Value) | |||
| { | |||
| case IntPtr pointer: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer); | |||
| case Tensor t1: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1); | |||
| case NDArray nd: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd, tensor.dtype)); | |||
| case int intVal: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal)); | |||
| case float floatVal: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal)); | |||
| case double doubleVal: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal)); | |||
| #if _REGEN | |||
| %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||
| %foreach types% | |||
| case #1 v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case #1[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| % | |||
| #else | |||
| case sbyte v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case sbyte[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case byte v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case byte[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case short v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case short[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ushort v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ushort[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case int v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case int[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case uint v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case uint[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case long v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case long[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ulong v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case ulong[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case float v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case float[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case double v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case double[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Complex v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Complex[] v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| #endif | |||
| case bool v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); | |||
| case string v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case IntPtr v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
| case Tensor v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||
| case NDArray v: | |||
| return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||
| default: | |||
| throw new NotImplementedException("feed_dict data type"); | |||
| throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); | |||
| } | |||
| } | |||
| throw new NotImplementedException("_do_run.feed_dict"); | |||