| @@ -0,0 +1,263 @@ | |||||
| using Microsoft.CodeAnalysis.CSharp; | |||||
| using Protobuf.Text; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Reflection.Metadata.Ecma335; | |||||
| using System.Text; | |||||
| using System.Text.RegularExpressions; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.CodeGen | |||||
| { | |||||
| public class DescriptionGenerator | |||||
| { | |||||
| private static readonly string replaceStrInner = "~~%~~"; | |||||
| private static readonly string replaceStrInnerQuotationMarks = "^%^"; | |||||
| Dictionary<string, Dictionary<string, string>> _opDescriptions = new Dictionary<string, Dictionary<string, string>>(); | |||||
| Dictionary<string, OpDef> _opDescriptionDefs = new Dictionary<string, OpDef>(); | |||||
| public DescriptionGenerator(string apiDefDirectory) | |||||
| { | |||||
| DirectoryInfo directory = new DirectoryInfo(apiDefDirectory); | |||||
| int errors = 0; | |||||
| foreach (FileInfo file in directory.GetFiles()) | |||||
| { | |||||
| string target = file.Name.Split('.')[0].Split('_').Last(); | |||||
| OpDef op = null; | |||||
| try | |||||
| { | |||||
| op = ReadOpDefs(file.FullName).Op[0]; | |||||
| } | |||||
| catch | |||||
| { | |||||
| errors++; | |||||
| continue; | |||||
| } | |||||
| _opDescriptionDefs[target] = op; | |||||
| _opDescriptions[target] = new Dictionary<string, string>(); | |||||
| foreach (var arg in op.InputArg) | |||||
| { | |||||
| string argName = arg.Name; | |||||
| var token = SyntaxFactory.ParseToken(argName); | |||||
| if (token.IsKeyword()) | |||||
| { | |||||
| argName = $"{argName}_"; | |||||
| } | |||||
| _opDescriptions[target][argName] = arg.Description ?? ""; | |||||
| } | |||||
| foreach (var arg in op.Attr) | |||||
| { | |||||
| var token = SyntaxFactory.ParseToken(arg.Name); | |||||
| string realKey = arg.Name; | |||||
| if (token.IsKeyword()) | |||||
| { | |||||
| realKey += "_"; | |||||
| } | |||||
| _opDescriptions[target][realKey] = arg.Description ?? ""; | |||||
| } | |||||
| _opDescriptions[target]["SUMMARY"] = op.Summary ?? ""; | |||||
| _opDescriptions[target]["DESC"] = op.Description ?? ""; | |||||
| } | |||||
| Console.WriteLine($"Warning: {errors} description files cannot be analyzed! Please revise it if " + | |||||
| $"the failed files number is large, or ignore it."); | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| /// <param name="sb"></param> | |||||
| public void AppendDescription(OpDef fullOp, StringBuilder sb) | |||||
| { | |||||
| var opName = fullOp.Name; | |||||
| if(_opDescriptions.TryGetValue(opName, out var op)) | |||||
| { | |||||
| var def = _opDescriptionDefs[opName]; | |||||
| sb.AppendLine("/// <summary>"); | |||||
| sb.AppendLine($"/// {op["SUMMARY"]}"); | |||||
| sb.AppendLine("/// </summary>"); | |||||
| string totalDesc = op["DESC"]; | |||||
| if (!string.IsNullOrEmpty(totalDesc)) | |||||
| { | |||||
| totalDesc = totalDesc.Replace(replaceStrInnerQuotationMarks, "\""); | |||||
| sb.AppendLine("/// <remarks>"); | |||||
| string[] lines = totalDesc.Split(replaceStrInner); | |||||
| foreach (var line in lines) | |||||
| { | |||||
| sb.AppendLine($"/// {line}"); | |||||
| } | |||||
| sb.AppendLine("/// </remarks>"); | |||||
| } | |||||
| var argNames = GetInputArgNames(fullOp); | |||||
| foreach (var argName in argNames) | |||||
| { | |||||
| if(op.TryGetValue(argName, out var desc)) | |||||
| { | |||||
| desc = desc.Replace(replaceStrInnerQuotationMarks, "\""); | |||||
| string[] lines = desc.Split(replaceStrInner); | |||||
| sb.AppendLine($"/// <param name=\"{argName}\">"); | |||||
| foreach (var line in lines) | |||||
| { | |||||
| sb.AppendLine($"/// {line}"); | |||||
| } | |||||
| sb.AppendLine("/// </param>"); | |||||
| } | |||||
| else | |||||
| { | |||||
| sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||||
| } | |||||
| } | |||||
| List<string> returnValueDescs = new(); | |||||
| foreach (var arg in def.OutputArg) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(arg.Description)) | |||||
| { | |||||
| returnValueDescs.Add($"{arg.Name}: {arg.Description}"); | |||||
| } | |||||
| } | |||||
| string returnValueDesc = ""; | |||||
| if (returnValueDescs.Count > 0) | |||||
| { | |||||
| returnValueDesc = string.Join(" && ", returnValueDescs); | |||||
| } | |||||
| sb.AppendLine($"/// <returns>{returnValueDesc}</returns>"); | |||||
| } | |||||
| else | |||||
| { | |||||
| sb.AppendLine("/// <summary>"); | |||||
| sb.AppendLine($"///"); | |||||
| sb.AppendLine("/// </summary>"); | |||||
| var argNames = GetInputArgNames(fullOp); | |||||
| foreach (var argName in argNames) | |||||
| { | |||||
| sb.AppendLine($"/// <param name=\"{argName}\"></param>"); | |||||
| } | |||||
| sb.AppendLine($"/// <returns></returns>"); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op"> | |||||
| /// </param> | |||||
| /// <returns></returns> | |||||
| /// <remarks></remarks> | |||||
| public List<string> GetInputArgNames(OpDef op) | |||||
| { | |||||
| List<string> names = new(); | |||||
| foreach (var arg in op.InputArg) | |||||
| { | |||||
| string argName = arg.Name; | |||||
| var token = SyntaxFactory.ParseToken(argName); | |||||
| if (token.IsKeyword()) | |||||
| { | |||||
| argName = $"{argName}_"; | |||||
| } | |||||
| names.Add(argName); | |||||
| } | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
| foreach (var (key, typeStr, value) in attrValueDic) | |||||
| { | |||||
| var token = SyntaxFactory.ParseToken(key); | |||||
| string realKey = key; | |||||
| if (token.IsKeyword()) | |||||
| { | |||||
| realKey += "_"; | |||||
| } | |||||
| names.Add(realKey); | |||||
| } | |||||
| return names; | |||||
| } | |||||
| private static OpList ReadOpDefs(string path) | |||||
| { | |||||
| var text = File.ReadAllText(path); | |||||
| text = RemoveLintTags(text); | |||||
| text = PreProcessText(text); | |||||
| string pattern = @"<<END([\s\S]*?)END"; | |||||
| // 定义用于替换的字符串 | |||||
| string replaceStrPrefix = "\""; | |||||
| string replaceStrSuffix = "\""; | |||||
| // 将匹配到的文本段全部替换 | |||||
| string replacedText = Regex.Replace(text, pattern, match => { | |||||
| string matchedText = match.Value; | |||||
| string innerText = match.Groups[1].Value; | |||||
| innerText = innerText.Replace("\"", replaceStrInnerQuotationMarks) | |||||
| .Replace("\r\n", replaceStrInner).Replace("\n", replaceStrInner); // 替换内部换行符 | |||||
| return replaceStrPrefix + innerText + replaceStrSuffix; // 替换首尾 | |||||
| }, RegexOptions.Multiline); | |||||
| var opDefs = new TextParser(TextParser.Settings.Default.WithIgnoreUnknownFields(true)).Parse<OpList>(replacedText); | |||||
| return opDefs; | |||||
| } | |||||
| static string PreProcessText(string input) | |||||
| { | |||||
| int depth = 0; | |||||
| int endBlockDepth = -1; | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| for (int i = 0; i < input.Length; i++) | |||||
| { | |||||
| char c = input[i]; | |||||
| if (c == '{') | |||||
| { | |||||
| depth++; | |||||
| sb.Append(c); | |||||
| } | |||||
| else if (c == '}') | |||||
| { | |||||
| if (depth == endBlockDepth) | |||||
| { | |||||
| sb.Append("END\n"); | |||||
| endBlockDepth = -1; | |||||
| } | |||||
| sb.Append(c); | |||||
| depth--; | |||||
| } | |||||
| else if (c == '<' && i + 5 < input.Length && input.Substring(i, 5) == "<<END") | |||||
| { | |||||
| endBlockDepth = depth; | |||||
| sb.Append("<<END"); | |||||
| i += 4; | |||||
| } | |||||
| else if (c == 'E' && i + 3 < input.Length && input.Substring(i, 3) == "END") | |||||
| { | |||||
| endBlockDepth = -1; | |||||
| sb.Append("END"); | |||||
| i += 2; | |||||
| } | |||||
| else | |||||
| { | |||||
| sb.Append(c); | |||||
| } | |||||
| } | |||||
| string output = sb.ToString(); | |||||
| return output; | |||||
| } | |||||
| static string RemoveLintTags(string input) | |||||
| { | |||||
| string[] lines = input.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None); | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| foreach (string line in lines) | |||||
| { | |||||
| if (!line.TrimStart().StartsWith("# LINT")) | |||||
| { | |||||
| sb.AppendLine(line); | |||||
| } | |||||
| } | |||||
| return sb.ToString().TrimEnd(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -44,7 +44,7 @@ namespace Tensorflow.CodeGen | |||||
| // begin to write main body | // begin to write main body | ||||
| sb.AppendLine("var _ctx = tf.Context;"); | sb.AppendLine("var _ctx = tf.Context;"); | ||||
| var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
| // deal with dynamic default values. | // deal with dynamic default values. | ||||
| foreach(var (name, expr) in dynamicDefaultValues) | foreach(var (name, expr) in dynamicDefaultValues) | ||||
| { | { | ||||
| @@ -183,7 +183,7 @@ namespace Tensorflow.CodeGen | |||||
| sb.Append($"Tensor {argName}, "); | sb.Append($"Tensor {argName}, "); | ||||
| } | } | ||||
| } | } | ||||
| var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var dynamicDefaultValues); | |||||
| foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) | foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) | ||||
| { | { | ||||
| var token = SyntaxFactory.ParseToken(key); | var token = SyntaxFactory.ParseToken(key); | ||||
| @@ -226,7 +226,7 @@ namespace Tensorflow.CodeGen | |||||
| } | } | ||||
| sb.Append("}, attrs = new Dictionary<string, object>(){ "); | sb.Append("}, attrs = new Dictionary<string, object>(){ "); | ||||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
| foreach (var (key, _, _) in attrValueDic) | foreach (var (key, _, _) in attrValueDic) | ||||
| { | { | ||||
| sb.Append($"[\"{key}\"] = {key}, "); | sb.Append($"[\"{key}\"] = {key}, "); | ||||
| @@ -252,7 +252,7 @@ namespace Tensorflow.CodeGen | |||||
| } | } | ||||
| sb.Append($"{inputArgRealName}, "); | sb.Append($"{inputArgRealName}, "); | ||||
| } | } | ||||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
| foreach (var (key, _, _) in attrValueDic) | foreach (var (key, _, _) in attrValueDic) | ||||
| { | { | ||||
| string keyRealName = key; | string keyRealName = key; | ||||
| @@ -439,7 +439,7 @@ namespace Tensorflow.CodeGen | |||||
| sb.Append($"Tensor {argName}, "); | sb.Append($"Tensor {argName}, "); | ||||
| } | } | ||||
| } | } | ||||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
| foreach (var (key, typeStr, _) in attrValueDic) | foreach (var (key, typeStr, _) in attrValueDic) | ||||
| { | { | ||||
| var token = SyntaxFactory.ParseToken(key); | var token = SyntaxFactory.ParseToken(key); | ||||
| @@ -465,7 +465,7 @@ namespace Tensorflow.CodeGen | |||||
| } | } | ||||
| sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); | sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); | ||||
| } | } | ||||
| var attrValueDic = GetAttrsDefaultValue(op, out var _); | |||||
| var attrValueDic = Utils.GetAttrsDefaultValue(op, out var _); | |||||
| foreach (var (key, _, _) in attrValueDic) | foreach (var (key, _, _) in attrValueDic) | ||||
| { | { | ||||
| sb.AppendLine($"keywords[\"{key}\"] = {key};"); | sb.AppendLine($"keywords[\"{key}\"] = {key};"); | ||||
| @@ -473,195 +473,6 @@ namespace Tensorflow.CodeGen | |||||
| sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); | sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); | ||||
| } | } | ||||
| // name, type string, default value | |||||
| public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||||
| { | |||||
| dynamicDefaultValues = new(); | |||||
| List<(string, string, string)> res = new(); | |||||
| foreach (var attr in op.Attr) | |||||
| { | |||||
| if (attr.Type == "type") | |||||
| { | |||||
| bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||||
| if (!found) | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
| { | |||||
| string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||||
| string enumPath = typeof(TF_DataType).Name + "." + name; | |||||
| res.Add((attr.Name, "TF_DataType", enumPath)); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "int") | |||||
| { | |||||
| if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||||
| { | |||||
| res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "int", "0")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "float") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||||
| { | |||||
| res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "float", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "string") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
| { | |||||
| res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "string", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "bool") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||||
| { | |||||
| res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "bool", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "shape") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||||
| { | |||||
| if (attr.DefaultValue.Shape.UnknownRank) | |||||
| { | |||||
| res.Add((attr.Name, "Shape", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| Shape shape = new Shape(attr.DefaultValue.Shape); | |||||
| string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "Shape", $"null")); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "Shape", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(type)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
| { | |||||
| List<TF_DataType> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.Type) | |||||
| { | |||||
| values.Add(value.as_tf_dtype()); | |||||
| } | |||||
| string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "TF_DataType[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(shape)") | |||||
| { | |||||
| res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||||
| } | |||||
| else if (attr.Type == "list(string)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
| { | |||||
| List<string> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.S) | |||||
| { | |||||
| values.Add(value.ToStringUtf8()); | |||||
| } | |||||
| string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "string[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "string[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(int)") | |||||
| { | |||||
| if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| List<int> values = new(); | |||||
| foreach(var value in attr.DefaultValue.List.I) | |||||
| { | |||||
| values.Add((int)value); | |||||
| } | |||||
| string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "int[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "int[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(float)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| List<float> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.F) | |||||
| { | |||||
| values.Add(value); | |||||
| } | |||||
| string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "float[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "float[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "func") | |||||
| { | |||||
| res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||||
| } | |||||
| else if (attr.Type == "list(func)") | |||||
| { | |||||
| res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||||
| } | |||||
| else if (attr.Type == "tensor") | |||||
| { | |||||
| res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| return res; | |||||
| } | |||||
| private static bool HasRefArgs(OpDef op) | private static bool HasRefArgs(OpDef op) | ||||
| { | { | ||||
| return op.InputArg.Any(x => x.IsRef); | return op.InputArg.Any(x => x.IsRef); | ||||
| @@ -12,16 +12,18 @@ namespace Tensorflow.CodeGen | |||||
| private string _basePath; | private string _basePath; | ||||
| private Dictionary<string, OpDef> _opMap; | private Dictionary<string, OpDef> _opMap; | ||||
| private OpClassifier _opClassifier; | private OpClassifier _opClassifier; | ||||
| private FunctionGenerator _g = new(); | |||||
| private FunctionGenerator _fg = new(); | |||||
| private DescriptionGenerator _dg; | |||||
| public GenOpsWriter(string basePath, string pythonFilesDirectory, string opDefFilename) | |||||
| public GenOpsWriter(string basePath, string pythonFilesDirectory, string apiDefFilesDirectory, string opDefFilename) | |||||
| { | { | ||||
| _basePath = basePath; | _basePath = basePath; | ||||
| var opDefs = ReadAllOpDefs(opDefFilename); | |||||
| var opDefs = Utils.ReadAllOpDefs(opDefFilename); | |||||
| _opMap = opDefs.Op.ToDictionary( | _opMap = opDefs.Op.ToDictionary( | ||||
| x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); | |||||
| x => Utils.ConvertToUnderscore(x.Name), x => x); | |||||
| _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); | _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); | ||||
| _dg = new DescriptionGenerator(apiDefFilesDirectory); | |||||
| } | } | ||||
| public void WriteAll() | public void WriteAll() | ||||
| @@ -53,12 +55,17 @@ namespace Tensorflow.CodeGen | |||||
| if(_opMap.ContainsKey(funcName)) | if(_opMap.ContainsKey(funcName)) | ||||
| { | { | ||||
| var opDef = _opMap[funcName]; | var opDef = _opMap[funcName]; | ||||
| _g.AppendFunction(opDef, sb); | |||||
| // write the descriptions. | |||||
| _dg.AppendDescription(opDef, sb); | |||||
| // write the function body. | |||||
| _fg.AppendFunction(opDef, sb); | |||||
| } | } | ||||
| else if (funcName.StartsWith("_")) | else if (funcName.StartsWith("_")) | ||||
| { | { | ||||
| var opDef = _opMap[funcName.Substring(1)]; | var opDef = _opMap[funcName.Substring(1)]; | ||||
| _g.AppendFunction(opDef, sb); | |||||
| _fg.AppendFunction(opDef, sb); | |||||
| } | } | ||||
| } | } | ||||
| @@ -69,12 +76,5 @@ namespace Tensorflow.CodeGen | |||||
| File.WriteAllText(fullFilePath, sb.ToString()); | File.WriteAllText(fullFilePath, sb.ToString()); | ||||
| } | } | ||||
| } | } | ||||
| private OpList ReadAllOpDefs(string path) | |||||
| { | |||||
| var text = File.ReadAllText(path); | |||||
| var opDefs = OpList.Parser.ParseText(text); | |||||
| return opDefs; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -5,10 +5,9 @@ using System.Text; | |||||
| using System.Xml.Linq; | using System.Xml.Linq; | ||||
| using Tensorflow.CodeGen; | using Tensorflow.CodeGen; | ||||
| //Console.WriteLine(Utils.ConvertToUnderscore("LRN")); | |||||
| GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | ||||
| @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | ||||
| @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | |||||
| @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | ||||
| writer.WriteAll(); | writer.WriteAll(); | ||||
| @@ -9,11 +9,11 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.100.5" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" /> | <ProjectReference Include="..\..\protobuf.Text\src\protobuf.Text\protobuf.Text.csproj" /> | ||||
| <ProjectReference Include="..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| </Project> | </Project> | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using Protobuf.Text; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Reflection.Metadata.Ecma335; | using System.Reflection.Metadata.Ecma335; | ||||
| @@ -51,5 +52,201 @@ namespace Tensorflow.CodeGen | |||||
| return result.ToString(); | return result.ToString(); | ||||
| } | } | ||||
| public static OpList ReadAllOpDefs(string path) | |||||
| { | |||||
| var text = File.ReadAllText(path); | |||||
| var opDefs = OpList.Parser.ParseText(text); | |||||
| return opDefs; | |||||
| } | |||||
| // name, type string, default value | |||||
| public static List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary<string, string> dynamicDefaultValues) | |||||
| { | |||||
| dynamicDefaultValues = new(); | |||||
| List<(string, string, string)> res = new(); | |||||
| foreach (var attr in op.Attr) | |||||
| { | |||||
| if (attr.Type == "type") | |||||
| { | |||||
| bool found = op.InputArg.Any(x => x.TypeAttr == attr.Name); | |||||
| if (!found) | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
| { | |||||
| string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); | |||||
| string enumPath = typeof(TF_DataType).Name + "." + name; | |||||
| res.Add((attr.Name, "TF_DataType", enumPath)); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "TF_DataType", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "int") | |||||
| { | |||||
| if (op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) | |||||
| { | |||||
| res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "int", "0")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "float") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) | |||||
| { | |||||
| res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "float", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "string") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
| { | |||||
| res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "string", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "bool") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) | |||||
| { | |||||
| res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "bool", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "shape") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) | |||||
| { | |||||
| if (attr.DefaultValue.Shape.UnknownRank) | |||||
| { | |||||
| res.Add((attr.Name, "Shape", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| Shape shape = new Shape(attr.DefaultValue.Shape); | |||||
| string expression = $"new Shape({string.Join(", ", shape.dims)})"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "Shape", $"null")); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "Shape", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(type)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||||
| { | |||||
| List<TF_DataType> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.Type) | |||||
| { | |||||
| values.Add(value.as_tf_dtype()); | |||||
| } | |||||
| string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "TF_DataType[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(shape)") | |||||
| { | |||||
| res.Add((attr.Name, "Shape[]", "NOVALUE")); | |||||
| } | |||||
| else if (attr.Type == "list(string)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
| { | |||||
| List<string> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.S) | |||||
| { | |||||
| values.Add(value.ToStringUtf8()); | |||||
| } | |||||
| string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "string[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "string[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(int)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| List<int> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.I) | |||||
| { | |||||
| values.Add((int)value); | |||||
| } | |||||
| string expression = "new int[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "int[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "int[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "list(float)") | |||||
| { | |||||
| if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
| { | |||||
| List<float> values = new(); | |||||
| foreach (var value in attr.DefaultValue.List.F) | |||||
| { | |||||
| values.Add(value); | |||||
| } | |||||
| string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; | |||||
| dynamicDefaultValues[attr.Name] = expression; | |||||
| res.Add((attr.Name, "float[]", $"null")); | |||||
| } | |||||
| else | |||||
| { | |||||
| res.Add((attr.Name, "float[]", "NOVALUE")); | |||||
| } | |||||
| } | |||||
| else if (attr.Type == "func") | |||||
| { | |||||
| res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||||
| } | |||||
| else if (attr.Type == "list(func)") | |||||
| { | |||||
| res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||||
| } | |||||
| else if (attr.Type == "tensor") | |||||
| { | |||||
| res.Add((attr.Name, "TensorProto", "NOVALUE")); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| return res; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||