|
|
|
@@ -15,7 +15,9 @@ |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using NumSharp; |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
|
{ |
|
|
|
@@ -37,10 +39,21 @@ namespace Tensorflow |
|
|
|
|
|
|
|
public virtual NDArray build_results(List<NDArray> values) |
|
|
|
{ |
|
|
|
var type = values[0].GetType(); |
|
|
|
var nd = new NDArray(type, values.Count); |
|
|
|
nd.ReplaceData(values.ToArray()); |
|
|
|
return nd; |
|
|
|
// if they're all scalar value |
|
|
|
bool isAllScalars = values.Count(x => x.ndim == 0) == values.Count; |
|
|
|
if (isAllScalars) |
|
|
|
{ |
|
|
|
var type = values[0].dtype; |
|
|
|
switch(Type.GetTypeCode(type)) |
|
|
|
{ |
|
|
|
case TypeCode.Single: |
|
|
|
return np.array(values.Select(x => x.GetSingle(0)).ToArray()); |
|
|
|
default: |
|
|
|
throw new NotImplementedException("build_results"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return np.stack(values.ToArray()); |
|
|
|
} |
|
|
|
|
|
|
|
public virtual List<ITensorOrOperation> unique_fetches() |
|
|
|
|