@@ -3,7 +3,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection ;
using System.Threading ;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -14,94 +14,136 @@ namespace Discord.Net.SourceGenerators.Serialization
{
public void Execute(GeneratorExecutionContext context)
{
if (!context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(
"build_property.DiscordNet_SerializationGenerator_OptionsTypeNamespace",
out var serializerOptionsNamespace))
throw new InvalidOperationException(
"Missing output namespace. Set DiscordNet_SerializationGenerator_OptionsTypeNamespace in your project file.");
bool searchThroughReferencedAssemblies =
context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(
"build_property.DiscordNet_SerializationGenerator_SearchThroughReferencedAssemblies",
out var _);
var generateSerializerAttribute = context.Compilation
.GetTypeByMetadataName(
"Discord.Net.Serialization.GenerateSerializerAttribute");
var discriminatedUnionSymbol = context.Compilation
.GetTypeByMetadataName(
"Discord.Net.Serialization.DiscriminatedUnionAttribute");
var discriminatedUnionMemberSymbol = context.Compilation
.GetTypeByMetadataName(
"Discord.Net.Serialization.DiscriminatedUnionMemberAttribute");
Debug.Assert(generateSerializerAttribute != null);
Debug.Assert(discriminatedUnionSymbol != null);
Debug.Assert(discriminatedUnionMemberSymbol != null);
Debug.Assert(context.SyntaxContextReceiver != null);
var receiver = (SyntaxReceiver)context.SyntaxContextReceiver!;
var converters = new List<string>();
var symbolsToBuild = receiver.GetSerializedTypes(
context.Compilation);
foreach (var @class in receiver.Classes)
if (searchThroughReferencedAssembli es)
{
var semanticModel = context.Compilation.GetSemanticModel(
@class.SyntaxTree);
var visitor = new VisibleTypeVisitor(context.CancellationToken);
foreach (var module in context.Compilation.Assembly.Modules)
foreach (var reference in module.ReferencedAssemblySymbols)
visitor.Visit(reference);
if (semanticModel.GetDeclaredSymbol(@class) is
not INamedTypeSymbol classSymbol)
throw new InvalidOperationException(
"Could not find named type symbol for " +
$"{@class.Identifier}");
symbolsToBuild = symbolsToBuild
.Concat(visitor.GetVisibleTypes());
}
context.AddSource(
$"Converters.{classSymbol.Name}",
GenerateConverter(classSymbol));
var types = SerializedTypeUtils.BuildTypeTrees(
generateSerializerAttribute: generateSerializerAttribute!,
discriminatedUnionSymbol: discriminatedUnionSymbol!,
discriminatedUnionMemberSymbol: discriminatedUnionMemberSymbol!,
symbolsToBuild: symbolsToBuild);
converters.Add($"{classSymbol.Name}Converter");
foreach (var type in types)
{
context.AddSource($"Converters.{type.ConverterTypeName}",
type.GenerateSourceCode(serializerOptionsNamespace));
if (type is DiscriminatedUnionSerializedType duDeclaration)
foreach (var member in duDeclaration.Members)
context.AddSource(
$"Converters.{type.ConverterTypeName}.{member.ConverterTypeName}",
member.GenerateSourceCode(serializerOptionsNamespace));
}
context.AddSource("SerializerOptions.Complete",
GenerateSerializerOptionsSourceCode(converters));
context.AddSource("SerializerOptions",
GenerateSerializerOptionsSourceCode(
serializerOptionsNamespace, types));
}
public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForPostInitialization(PostInitialize);
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
}
public static void PostInitialize(
GeneratorPostInitializationContext context)
=> context.AddSource("SerializerOptions.Template",
GenerateSerializerOptionsTemplateSourceCode());
=> context.RegisterForSyntaxNotifications(
() => new SyntaxReceiver());
internal class SyntaxReceiver : ISyntaxContextReceiver
private class SyntaxReceiver : ISyntaxContextReceiver
{
public List<ClassDeclarationSyntax> Classes { get; } = new();
private readonly Dictionary<string, INamedTypeSymbol> _interestingAttributes
= new();
private readonly List<SyntaxNode> _classes;
public void OnVisitSyntaxNode(GeneratorSyntaxContext context )
public SyntaxReceiver()
{
_ = GetOrAddAttribute(_interestingAttributes,
context.SemanticModel,
"Discord.Net.Serialization.DiscriminatedUnionAttribute");
_ = GetOrAddAttribute(_interestingAttributes,
context.SemanticModel,
"Discord.Net.Serialization.DiscriminatedUnionMemberAttribute");
_classes = new();
}
if (context.Node is ClassDeclarationSyntax classDecl
&& classDecl.AttributeLists is
SyntaxList<AttributeListSyntax> attrList
&& attrList.Any(
list => list.Attributes
.Any(a => IsInterestingAttribute(a,
context.SemanticModel,
_interestingAttributes.Values))))
public IEnumerable<INamedTypeSymbol> GetSerializedTypes(
Compilation compilation)
{
foreach (var @class in _classes)
{
Classes.Add(classDecl);
var semanticModel = compilation.GetSemanticModel(
@class.SyntaxTree);
if (semanticModel.GetDeclaredSymbol(@class) is
INamedTypeSymbol classSymbol)
yield return classSymbol;
}
}
private static INamedTypeSymbol GetOrAddAttribute(
Dictionary<string, INamedTypeSymbol> cache,
SemanticModel model, string name )
private INamedTypeSymbol? _generateSerializerAttributeSymbol;
public void OnVisitSyntaxNode(GeneratorSyntaxContext context )
{
if (!cache.TryGetValue(name, out var type))
_generateSerializerAttributeSymbol ??=
context.SemanticModel.Compilation.GetTypeByMetadataName(
"Discord.Net.Serialization.GenerateSerializerAttribute");
Debug.Assert(_generateSerializerAttributeSymbol != null);
if (context.Node is ClassDeclarationSyntax classDeclaration
&& classDeclaration.AttributeLists is
SyntaxList<AttributeListSyntax> classAttributeLists
&& classAttributeLists.Any(
list => list.Attributes.Any(
n => IsAttribute(n, context.SemanticModel,
_generateSerializerAttributeSymbol!))))
{
type = model.Compilation.GetTypeByMetadataName(name);
Debug.Assert(type != null);
cache.Add(name, type!);
_classes.Add(classDeclaration);
}
else if (context.Node is RecordDeclarationSyntax recordDeclaration
&& recordDeclaration.AttributeLists is
SyntaxList<AttributeListSyntax> recordAttributeLists
&& recordAttributeLists.Any(
list => list.Attributes.Any(
n => IsAttribute(n, context.SemanticModel,
_generateSerializerAttributeSymbol!))))
{
_classes.Add(recordDeclaration);
}
return type!;
}
private static bool IsInterestingAttribute(
AttributeSyntax attribute, SemanticModel model,
IEnumerable<INamedTypeSymbol> interestingAttributes)
{
var typeInfo = model.GetTypeInfo(attribute.Name);
static bool IsAttribute(AttributeSyntax attribute,
SemanticModel model, INamedTypeSymbol expected)
{
var typeInfo = model.GetTypeInfo(attribute.Name);
return interestingAttributes.Any (
x => SymbolEqualityComparer.Default
.Equals(typeInfo.Type, x));
return SymbolEqualityComparer.Default.Equals(
typeInfo.Type, expected);
}
}
}
}