Browse Source

session.run: added overload with Hashtable as feed dict

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
5b49f73eda
1 changed files with 14 additions and 0 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 14
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -1,5 +1,6 @@
using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
@@ -40,6 +41,13 @@ namespace Tensorflow
return _run(fetches, feed_dict);
}

public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}

private NDArray _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();
@@ -89,6 +97,12 @@ namespace Tensorflow
case byte[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool val:
feed_dict_tensor[subfeed_t] = (NDArray) val;
break;
case bool[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
default:
Console.WriteLine($"can't handle data type of subfeed_val");
throw new NotImplementedException("_run subfeed");


Loading…
Cancel
Save