diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs new file mode 100644 index 00000000..fca6e4f3 --- /dev/null +++ b/LLama/Grammar/GrammarParser.cs @@ -0,0 +1,328 @@ +using LLama.Native; +using System; +using System.Collections.Generic; + +namespace LLama.Grammar +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.cpp + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + internal class GrammarParser + { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + public Tuple> DecodeUTF8(ReadOnlyMemory src) + { + ReadOnlySpan span = src.Span; + int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + + byte firstByte = (byte)span[0]; + byte highbits = (byte)(firstByte >> 4); + int len = lookup[highbits]; + byte mask = (byte)((1 << (8 - len)) - 1); + uint value = (uint)(firstByte & mask); + + int end = len; + int pos = 1; + + for (; pos < end && pos < src.Length; pos++) + { + value = (uint)((value << 6) + ((byte)span[pos] & 0x3F)); + } + + ReadOnlyMemory nextSpan = src.Slice(pos); + + return new Tuple>(value, nextSpan); + } + + public uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) + { + uint nextId = (uint)state.SymbolIds.Count; + string key = src.Slice(0, len).ToString(); + + if (state.SymbolIds.TryGetValue(key, out uint existingId)) + { + return existingId; + } + else + { + state.SymbolIds[key] = nextId; + return nextId; + } + } + + public uint GenerateSymbolId(ParseState state, string baseName) + { + uint nextId = (uint)state.SymbolIds.Count; + string key = $"{baseName}_{nextId}"; + state.SymbolIds[key] = nextId; + return nextId; + } + + public void AddRule(ParseState state, uint ruleId, List rule) + { + while (state.Rules.Count <= ruleId) + { + state.Rules.Add(new List()); + } + + state.Rules[(int)ruleId] = rule; + } + + public bool IsWordChar(char c) + { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + public Tuple> ParseHex(ReadOnlyMemory src, int size) + { + int pos = 0; + int end = size; + uint value = 0; + + ReadOnlySpan srcSpan = src.Span; + + for (; pos < end && pos < src.Length; pos++) + { + value <<= 4; + char c = srcSpan[pos]; + if ('a' <= c && c <= 'f') + { + value += (uint)(c - 'a' + 10); + } + else if ('A' <= c && c <= 'F') + { + value += (uint)(c - 'A' + 10); + } + else if ('0' <= c && c <= '9') + { + value += (uint)(c - '0'); + } + else + { + break; + } + } + + if (pos != end) + { + throw new InvalidOperationException($"Expecting {size} hex chars at {src.ToString()}"); + } + + return new Tuple>(value, src.Slice(pos)); + } + + public ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) + { + int pos = 0; + while (pos < src.Length && + (src[pos] == ' ' || src[pos] == '\t' || src[pos] == '#' || + (newlineOk && (src[pos] == '\r' || src[pos] == '\n')))) + { + if (src[pos] == '#') + { + while (pos < src.Length && src[pos] != '\r' && src[pos] != '\n') + { + pos++; + } + } + else + { + pos++; + } + } + return src.Slice(pos); + } + + public ReadOnlySpan ParseName(ReadOnlySpan src) + { + int pos = 0; + while (pos < src.Length && IsWordChar(src[pos])) + { + pos++; + } + if (pos == 0) + { + throw new InvalidOperationException($"Expecting name at {src.ToString()}"); + } + return src.Slice(pos); + } + + public Tuple> ParseChar(ReadOnlyMemory src) + { + ReadOnlySpan span = src.Span; + + if (span[0] == '\\') + { + switch (span[1]) + { + case 'x': + return ParseHex(src.Slice(2), 2); + case 'u': + return ParseHex(src.Slice(2), 4); + case 'U': + return ParseHex(src.Slice(2), 8); + case 't': + return new Tuple>('\t', src.Slice(2)); + case 'r': + return new Tuple>('\r', src.Slice(2)); + case 'n': + return new Tuple>('\n', src.Slice(2)); + case '\\': + case '"': + case '[': + case ']': + return new Tuple>(span[1], src.Slice(2)); + default: + throw new Exception("Unknown escape at " + src.ToString()); + } + } + else if (!span.IsEmpty) + { + return DecodeUTF8(src); + } + + throw new Exception("Unexpected end of input"); + } + + public ReadOnlySpan ParseSequence( + ref ParseState state, + ReadOnlyMemory src, + string ruleName, + List outElements, + bool isNested) + { + int lastSymStart = outElements.Count; + var pos = src.Span; + + while (!pos.IsEmpty) + { + if (pos[0] == '"') // literal string + { + pos = pos.Slice(1); + lastSymStart = outElements.Count; + + while (pos[0] != '"') + { + var charPair = ParseChar(src); + pos = charPair.Item2.Span; + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair.Item1 }); + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (pos[0] == '[') // char range(s) + { + pos = pos.Slice(1); + var startType = LLamaGrammarElementType.CHAR; + + if (pos[0] == '^') + { + pos = pos.Slice(1); + startType = LLamaGrammarElementType.CHAR_NOT; + } + + lastSymStart = outElements.Count; + + while (pos[0] != ']') + { + var charPair = ParseChar(src); + pos = charPair.Item2.Span; + var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; + + outElements.Add(new LLamaGrammarElement { Type = type, Value = charPair.Item1 }); + + if (pos[0] == '-' && pos[1] != ']') + { + var endCharPair = ParseChar(src.Slice(1)); + pos = endCharPair.Item2.Span; + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR_RNG_UPPER, Value = endCharPair.Item1 }); + } + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (IsWordChar(pos[0])) // rule reference + { + var nameEnd = ParseName(pos); + uint refRuleId = GetSymbolId(state, pos, nameEnd.Length); + pos = ParseSpace(nameEnd, isNested); + lastSymStart = outElements.Count; + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = refRuleId }); + } + else if (pos[0] == '(') // grouping + { + // parse nested alternates into synthesized rule + pos = ParseSpace(pos.Slice(1), true); + uint subRuleId = GenerateSymbolId(state, ruleName); + pos = ParseAlternates(state, pos, ruleName, subRuleId, true); + lastSymStart = outElements.Count; + // output reference to synthesized rule + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + if (pos[0] != ')') + { + throw new Exception($"Expecting ')' at {new string(pos.ToArray())}"); + } + pos = ParseSpace(pos.Slice(1), isNested); + } + else if (pos[0] == '*' || pos[0] == '+' || pos[0] == '?') // repetition operator + { + if (lastSymStart == outElements.Count) + { + throw new Exception($"Expecting preceding item to */+/? at {new string(pos.ToArray())}"); + } + + // apply transformation to previous symbol (lastSymStart to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + uint subRuleId = GenerateSymbolId(state, ruleName); + + List subRule = new List(); + + // add preceding symbol to generated rule + subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); + + if (pos[0] == '*' || pos[0] == '+') + { + // cause generated rule to recurse + subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + } + + // mark start of alternate def + subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); + + if (pos[0] == '+') + { + // add preceding symbol as alternate only for '+' (otherwise empty) + subRule.AddRange(outElements.GetRange(lastSymStart, outElements.Count - lastSymStart)); + } + + subRule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.END, Value = 0 }); + + AddRule(state, subRuleId, subRule); + + // in original rule, replace previous symbol with reference to generated rule + outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart); + outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.RULE_REF, Value = subRuleId }); + + pos = ParseSpace(pos.Slice(1), isNested); + + } + else + { + break; + } + } + + return pos; + } + + public ReadOnlySpan ParseAlternates(ParseState state, ReadOnlySpan pos, string ruleName, uint subRuleId, bool v) + { + throw new NotImplementedException(); + } + } +} diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs new file mode 100644 index 00000000..bd5eaf41 --- /dev/null +++ b/LLama/Grammar/ParseState.cs @@ -0,0 +1,26 @@ +using LLama.Native; +using System; +using System.Collections.Generic; + +namespace LLama.Grammar +{ + /// + /// Source: + /// https://github.com/ggerganov/llama.cpp/blob/6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9/common/grammar-parser.h + /// + /// The commit hash from URL is the actual commit hash that reflects current C# code. + /// + internal class ParseState + { + public Dictionary SymbolIds { get; } = new Dictionary(); + public List> Rules { get; } = new List>(); + + public IEnumerable> CRules() + { + foreach (var rule in Rules) + { + yield return rule; + } + } + } +}