| @@ -0,0 +1,263 @@ | |||
| using Microsoft.CodeAnalysis.CSharp; | |||
| using Protobuf.Text; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Reflection.Metadata.Ecma335; | |||
| using System.Text; | |||
| using System.Text.RegularExpressions; | |||
| using System.Threading.Tasks; | |||
| namespace Tensorflow.CodeGen | |||
| { | |||
| public class DescriptionGenerator | |||
| { | |||
| private static readonly string replaceStrInner = "~~%~~"; | |||
| private static readonly string replaceStrInnerQuotationMarks = "^%^"; | |||
| Dictionary<string, Dictionary<string, string>> _opDescriptions = new Dictionary<string, Dictionary<string, string>>(); | |||
| Dictionary<string, OpDef> _opDescriptionDefs = new Dictionary<string, OpDef>(); | |||
| public DescriptionGenerator(string apiDefDirectory) | |||
| { | |||
| DirectoryInfo directory = new DirectoryInfo(apiDefDirectory); | |||
| int errors = 0; | |||
| foreach (FileInfo file in directory.GetFiles()) | |||
| { | |||
| string target = file.Name.Split('.')[0].Split('_').Last(); | |||
| OpDef op = null; | |||
| try | |||
| { | |||
| op = ReadOpDefs(file.FullName).Op[0]; | |||
| } | |||
| catch | |||
| { | |||
| errors++; | |||
| continue; | |||
| } | |||
| _opDescriptionDefs[target] = op; | |||
| _opDescriptions[target] = new Dictionary<string, string>(); | |||
| foreach (var arg in op.InputArg) | |||
| { | |||
| string argName = arg.Name; | |||
| var token = SyntaxFactory.ParseToken(argName); | |||
| if (token.IsKeyword()) | |||
| { | |||
| argName = $"{argName}_"; | |||
| } | |||
| _opDescriptions[target][argName] = arg.Description ?? ""; | |||
| } | |||
| foreach (var arg in op.Attr) | |||
| { | |||
| var token = SyntaxFactory.ParseToken(arg.Name); | |||
| string realKey = arg.Name; | |||
| if (token.IsKeyword()) | |||
| { | |||
| realKey += "_"; | |||
| } | |||
| _opDescriptions[target][realKey] = arg.Description ?? ""; | |||
| } | |||
| _opDescriptions[target]["SUMMARY"] = op.Summary ?? ""; | |||
| _opDescriptions[target]["DESC"] = op.Description ?? ""; | |||
| } | |||
| Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " + | |||
| $"the failed files number is large, or ignore it."); | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <param name="sb"></param> | |||
| public void AppendDescription(OpDef fullOp, StringBuilder sb) | |||
| { | |||
| var opName = fullOp.Name; | |||
| if(_opDescriptions.TryGetValue(opName, out var op)) | |||
| { | |||
| var def = _opDescriptionDefs[opName]; | |||
| sb.AppendLine("/// <summary>"); | |||
| sb.AppendLine($"/// {op["SUMMARY"]}"); | |||
| sb.AppendLine("/// </summary>"); | |||
| string totalDesc = op["DESC"]; | |||
| if (!string.IsNullOrEmpty(totalDesc)) | |||
| { | |||
| totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\""); | |||
| sb.AppendLine("/// <remarks>"); | |||
| string[] lines = totalDesc.Split(replaceStrInner); | |||
| foreach (var line in lines) | |||
| { | |||
| sb.AppendLine($"/// {line}"); | |||
| } | |||
| sb.AppendLine("/// </remarks>"); | |||
| } | |||
| var argNames = GetInputArgNames(fullOp); | |||
| foreach (var argName in argNames) | |||
| { | |||
| if(op.TryGetValue(argName, out var desc)) | |||
| { | |||
| desc = desc.Replace(replaceStrInnerQuotationMarks, "\""); | |||
| string[] lines = desc.Split(replaceStrInner); | |||
| sb.AppendLine($"/// <param name=\"{argName}\">"); | |||
| foreach (var line in lines) | |||
| { | |||
| sb.AppendLine($"/// {line}"); | |||
| } | |||
| sb.AppendLine("/// </param>"); | |||
| } | |||
| else | |||
| { | |||
| sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||
| } | |||
| } | |||
| List<string> returnValueDescs = new(); | |||
| foreach (var arg in def.OutputArg) | |||
| { | |||
| if (!string.IsNullOrEmpty(arg.Description)) | |||
| { | |||
| returnValueDescs.Add($"{arg.Name}: {arg.Description}"); | |||
| } | |||
| } | |||
| string returnValueDesc = ""; | |||
| if (returnValueDescs.Count > 0) | |||
| { | |||
| returnValueDesc = string.Join(" && ", returnValueDescs); | |||
| } | |||
| sb.AppendLine($"/// <returns>{returnValueDesc}</returns>"); | |||
| } | |||
| else | |||
| { | |||
| sb.AppendLine("/// <summary>"); | |||
| sb.AppendLine($"///"); | |||
| sb.AppendLine("/// </summary>"); | |||
| var argNames = GetInputArgNames(fullOp); | |||
| foreach (var argName in argNames) | |||
| { | |||
| sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||
| } | |||
| sb.AppendLine($"/// <returns></returns>"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="op"> | |||
| /// </param> | |||
| /// <returns></returns> | |||
| /// <remarks></remarks> | |||
| public List<string> GetInputArgNames(OpDef op) | |||
| { | |||
| List<string> names = new(); | |||
| foreach (var arg in op.InputArg) | |||
| { | |||
| string argName = arg.Name; | |||
| var token = SyntaxFactory.ParseToken(argName); | |||
| if (token.IsKeyword()) | |||
| { | |||
| argName = $"{argName}_"; | |||
| } | |||
| names.Add(argName); | |||
| } | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
| foreach (var (key, typeStr, value) in attrValueDic) | |||
| { | |||
| var token = SyntaxFactory.ParseToken(key); | |||
| string realKey = key; | |||
| if (token.IsKeyword()) | |||
| { | |||
| realKey += "_"; | |||
| } | |||
| names.Add(realKey); | |||
| } | |||
| return names; | |||
| } | |||
| private static OpList ReadOpDefs(string path) | |||
| { | |||
| var text = File.ReadAllText(path); | |||
| text = RemoveLintTags(text); | |||
| text = PreProcessText(text); | |||
| string pattern = @"<<END([\s\S]*?)END"; | |||
| // 定义用于替换的字符串 | |||
| string replaceStrPrefix = "\""; | |||
| string replaceStrSuffix = "\""; | |||
| // 将匹配到的文本段全部替换 | |||
| string replacedText = Regex.Replace(text, pattern, match => { | |||
| string matchedText = match.Value; | |||
| string innerText = match.Groups[1].Value; | |||
| innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks) | |||
| .Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符 | |||
| return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾 | |||
| }, RegexOptions.Multiline); | |||
| var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse<OpList>(replacedText); | |||
| return opDefs; | |||
| } | |||
| static string PreProcessText(string input) | |||
| { | |||
| int depth = 0; | |||
| int endBlockDepth = -1; | |||
| StringBuilder sb = new StringBuilder(); | |||
| for (int i = 0; i < input.Length; i++) | |||
| { | |||
| char c = input[i]; | |||
| if (c == '{') | |||
| { | |||
| depth++; | |||
| sb.Append(c); | |||
| } | |||
| else if (c == '}') | |||
| { | |||
| if (depth == endBlockDepth) | |||
| { | |||
| sb.Append("END\n"); | |||
| endBlockDepth = -1; | |||
| } | |||
| sb.Append(c); | |||
| depth--; | |||
| } | |||
| else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "<<END") | |||
| { | |||
| endBlockDepth = depth; | |||
| sb.Append("<<END"); | |||
| i += 4; | |||
| } | |||
| else if (c == 'E' && i + 3 < input.Length && input.Substring(i, 3) == "END") | |||
| { | |||
| endBlockDepth = -1; | |||
| sb.Append("END"); | |||
| i += 2; | |||
| } | |||
| else | |||
| { | |||
| sb.Append(c); | |||
| } | |||
| } | |||
| string output = sb.ToString(); | |||
| return output; | |||
| } | |||
| static string RemoveLintTags(string input) | |||
| { | |||
| string[] lines = input.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None); | |||
| StringBuilder sb = new StringBuilder(); | |||
| foreach (string line in lines) | |||
| { | |||
| if (!line.TrimStart().StartsWith("# LINT")) | |||
| { | |||
| sb.AppendLine(line); | |||
| } | |||
| } | |||
| return sb.ToString().TrimEnd(); | |||
| } | |||
| } | |||
| } | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow.CodeGen | |||
| // begin to write main body | |||
| sb.AppendLine("var _ctx = tf.Context;"); | |||
| var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
| // deal with dynamic default values. | |||
| foreach(var (name, expr) in dynamicDefaultValues) | |||
| { | |||
| @@ -183,7 +183,7 @@ namespace Tensorflow.CodeGen | |||
| sb.Append($"Tensor {argName}, "); | |||
| } | |||
| } | |||
| var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||
| foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) | |||
| { | |||
| var token = SyntaxFactory.ParseToken(key); | |||
| @@ -226,7 +226,7 @@ namespace Tensorflow.CodeGen | |||
| } | |||
| sb.Append("}, attrs = new Dictionary<string, object>(){ "); | |||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
| foreach (var (key, _, _) in attrValueDic) | |||
| { | |||
| sb.Append($"[\"{key}\"] = {key}, "); | |||
| @@ -252,7 +252,7 @@ namespace Tensorflow.CodeGen | |||
| } | |||
| sb.Append($"{inputArgRealName}, "); | |||
| } | |||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
| foreach (var (key, _, _) in attrValueDic) | |||
| { | |||
| string keyRealName = key; | |||
| @@ -439,7 +439,7 @@ namespace Tensorflow.CodeGen | |||
| sb.Append($"Tensor {argName}, "); | |||
| } | |||
| } | |||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
| foreach (var (key, typeStr, _) in attrValueDic) | |||
| { | |||
| var token = SyntaxFactory.ParseToken(key); | |||
| @@ -465,7 +465,7 @@ namespace Tensorflow.CodeGen | |||
| } | |||
| sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); | |||
| } | |||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||
| foreach (var (key, _, _) in attrValueDic) | |||
| { | |||
| sb.AppendLine($"keywords[\"{key}\"] = {key};"); | |||
| @@ -473,195 +473,6 @@ namespace Tensorflow.CodeGen | |||
| sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); | |||
| } | |||
| // name, type string, default value | |||
| public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||
| { | |||
| dynamicDefaultValues = new(); | |||
| List<(string, string, string)> res = new(); | |||
| foreach (var attr in op.Attr) | |||
| { | |||
| if (attr.Type == "type") | |||
| { | |||
| bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||
| if (!found) | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
| { | |||
| string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||
| string enumPath = typeof(TF_DataType).Name + "." + name; | |||
| res.Add((attr.Name, "TF_DataType", enumPath)); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||
| } | |||
| } | |||
| } | |||
| else if (attr.Type == "int") | |||
| { | |||
| if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||
| { | |||
| continue; | |||
| } | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||
| { | |||
| res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "int", "0")); | |||
| } | |||
| } | |||
| else if (attr.Type == "float") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||
| { | |||
| res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "float", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "string") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
| { | |||
| res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "string", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "bool") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||
| { | |||
| res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "bool", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "shape") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||
| { | |||
| if (attr.DefaultValue.Shape.UnknownRank) | |||
| { | |||
| res.Add((attr.Name, "Shape", $"null")); | |||
| } | |||
| else | |||
| { | |||
| Shape shape = new Shape(attr.DefaultValue.Shape); | |||
| string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "Shape", $"null")); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "Shape", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(type)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
| { | |||
| List<TF_DataType> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.Type) | |||
| { | |||
| values.Add(value.as_tf_dtype()); | |||
| } | |||
| string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "TF_DataType[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(shape)") | |||
| { | |||
| res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||
| } | |||
| else if (attr.Type == "list(string)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
| { | |||
| List<string> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.S) | |||
| { | |||
| values.Add(value.ToStringUtf8()); | |||
| } | |||
| string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "string[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "string[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(int)") | |||
| { | |||
| if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| List<int> values = new(); | |||
| foreach(var value in attr.DefaultValue.List.I) | |||
| { | |||
| values.Add((int)value); | |||
| } | |||
| string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "int[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "int[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(float)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| List<float> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.F) | |||
| { | |||
| values.Add(value); | |||
| } | |||
| string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "float[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "float[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "func") | |||
| { | |||
| res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||
| } | |||
| else if (attr.Type == "list(func)") | |||
| { | |||
| res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||
| } | |||
| else if (attr.Type == "tensor") | |||
| { | |||
| res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| private static bool HasRefArgs(OpDef op) | |||
| { | |||
| return op.InputArg.Any(x => x.IsRef); | |||
| @@ -12,16 +12,18 @@ namespace Tensorflow.CodeGen | |||
| private string _basePath; | |||
| private Dictionary<string, OpDef> _opMap; | |||
| private OpClassifier _opClassifier; | |||
| private FunctionGenerator _g = new(); | |||
| private FunctionGenerator _fg = new(); | |||
| private DescriptionGenerator _dg; | |||
| public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename) | |||
| public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename) | |||
| { | |||
| _basePath = basePath; | |||
| var opDefs = ReadAllOpDefs(opDefFilename); | |||
| var opDefs = Utils.ReadAllOpDefs(opDefFilename); | |||
| _opMap = opDefs.Op.ToDictionary( | |||
| x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); | |||
| x => Utils.ConvertToUnderscore(x.Name), x => x); | |||
| _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); | |||
| _dg = new DescriptionGenerator(apiDefFilesDirectory); | |||
| } | |||
| public void WriteAll() | |||
| @@ -53,12 +55,17 @@ namespace Tensorflow.CodeGen | |||
| if(_opMap.ContainsKey(funcName)) | |||
| { | |||
| var opDef = _opMap[funcName]; | |||
| _g.AppendFunction(opDef, sb); | |||
| // write the descriptions. | |||
| _dg.AppendDescription(opDef, sb); | |||
| // write the function body. | |||
| _fg.AppendFunction(opDef, sb); | |||
| } | |||
| else if (funcName.StartsWith("_")) | |||
| { | |||
| var opDef = _opMap[funcName.Substring(1)]; | |||
| _g.AppendFunction(opDef, sb); | |||
| _fg.AppendFunction(opDef, sb); | |||
| } | |||
| } | |||
| @@ -69,12 +76,5 @@ namespace Tensorflow.CodeGen | |||
| File.WriteAllText(fullFilePath, sb.ToString()); | |||
| } | |||
| } | |||
| private OpList ReadAllOpDefs(string path) | |||
| { | |||
| var text = File.ReadAllText(path); | |||
| var opDefs = OpList.Parser.ParseText(text); | |||
| return opDefs; | |||
| } | |||
| } | |||
| } | |||
| @@ -5,10 +5,9 @@ using System.Text; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.CodeGen; | |||
| //Console.WriteLine(Utils.ConvertToUnderscore("LRN")); | |||
| GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | |||
| @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | |||
| @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | |||
| @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | |||
| writer.WriteAll(); | |||
| @@ -9,11 +9,11 @@ | |||
| <ItemGroup> | |||
| <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | |||
| <PackageReference Include="TensorFlow.NET" Version="0.100.5" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" /> | |||
| <ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using Protobuf.Text; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Reflection.Metadata.Ecma335; | |||
| @@ -51,5 +52,201 @@ namespace Tensorflow.CodeGen | |||
| return result.ToString(); | |||
| } | |||
| public static OpList ReadAllOpDefs(string path) | |||
| { | |||
| var text = File.ReadAllText(path); | |||
| var opDefs = OpList.Parser.ParseText(text); | |||
| return opDefs; | |||
| } | |||
| // name, type string, default value | |||
| public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||
| { | |||
| dynamicDefaultValues = new(); | |||
| List<(string, string, string)> res = new(); | |||
| foreach (var attr in op.Attr) | |||
| { | |||
| if (attr.Type == "type") | |||
| { | |||
| bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||
| if (!found) | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
| { | |||
| string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||
| string enumPath = typeof(TF_DataType).Name + "." + name; | |||
| res.Add((attr.Name, "TF_DataType", enumPath)); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||
| } | |||
| } | |||
| } | |||
| else if (attr.Type == "int") | |||
| { | |||
| if (op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||
| { | |||
| continue; | |||
| } | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||
| { | |||
| res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "int", "0")); | |||
| } | |||
| } | |||
| else if (attr.Type == "float") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||
| { | |||
| res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "float", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "string") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
| { | |||
| res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "string", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "bool") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||
| { | |||
| res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "bool", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "shape") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||
| { | |||
| if (attr.DefaultValue.Shape.UnknownRank) | |||
| { | |||
| res.Add((attr.Name, "Shape", $"null")); | |||
| } | |||
| else | |||
| { | |||
| Shape shape = new Shape(attr.DefaultValue.Shape); | |||
| string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "Shape", $"null")); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "Shape", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(type)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
| { | |||
| List<TF_DataType> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.Type) | |||
| { | |||
| values.Add(value.as_tf_dtype()); | |||
| } | |||
| string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "TF_DataType[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(shape)") | |||
| { | |||
| res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||
| } | |||
| else if (attr.Type == "list(string)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||
| { | |||
| List<string> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.S) | |||
| { | |||
| values.Add(value.ToStringUtf8()); | |||
| } | |||
| string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "string[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "string[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(int)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| List<int> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.I) | |||
| { | |||
| values.Add((int)value); | |||
| } | |||
| string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "int[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "int[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "list(float)") | |||
| { | |||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| List<float> values = new(); | |||
| foreach (var value in attr.DefaultValue.List.F) | |||
| { | |||
| values.Add(value); | |||
| } | |||
| string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||
| dynamicDefaultValues[attr.Name] = expression; | |||
| res.Add((attr.Name, "float[]", $"null")); | |||
| } | |||
| else | |||
| { | |||
| res.Add((attr.Name, "float[]", "NOVALUE")); | |||
| } | |||
| } | |||
| else if (attr.Type == "func") | |||
| { | |||
| res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||
| } | |||
| else if (attr.Type == "list(func)") | |||
| { | |||
| res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||
| } | |||
| else if (attr.Type == "tensor") | |||
| { | |||
| res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| } | |||
| } | |||