Files
NebulaLauncher/Nebula.Shared/Services/AssemblyService.cs

212 lines
6.5 KiB
C#
Raw Normal View History

2025-01-05 17:05:23 +03:00
using System.Diagnostics.CodeAnalysis;
2024-12-27 19:15:33 +03:00
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;
2025-01-05 17:05:23 +03:00
using Nebula.Shared.FileApis;
using Nebula.Shared.Services.Logging;
2024-12-27 19:15:33 +03:00
using Robust.LoaderApi;
2025-01-14 22:10:16 +03:00
using SharpZstd.Interop;
2024-12-27 19:15:33 +03:00
2025-01-05 17:05:23 +03:00
namespace Nebula.Shared.Services;
2024-12-27 19:15:33 +03:00
/// <summary>
/// Loads managed assemblies from mounted <see cref="AssemblyApi"/> sources into
/// <see cref="AssemblyLoadContext.Default"/>, caches them, and hooks native-library
/// resolution for the default context and for SharpZstd.
/// </summary>
[ServiceRegister]
public class AssemblyService
{
    // Assemblies loaded through TryOpenAssembly, keyed by the name that was passed to it.
    private readonly Dictionary<string, Assembly> _assemblyCache = new();
    private readonly ILogger _logger;

    // Full names currently being resolved; used to break re-entrant resolution of the same assembly.
    private readonly HashSet<string> _resolvingAssemblies = new();

    // Sources probed, in mount order, when the default load context cannot resolve an assembly.
    private readonly List<AssemblyApi> _mountedApis = [];

    // Guards _assemblyCache, _resolvingAssemblies and _mountedApis. Resolving events may be raised
    // on arbitrary threads, so every access to these collections goes through this lock.
    private readonly object _gate = new();

    /// <summary>Invoked once for each assembly that is newly added to the cache.</summary>
    public Action<Assembly>? OnAssemblyLoaded;

    /// <summary>A snapshot of all assemblies loaded through this service.</summary>
    public IReadOnlyList<Assembly> Assemblies
    {
        get
        {
            lock (_gate)
            {
                return _assemblyCache.Values.ToList().AsReadOnly();
            }
        }
    }

    /// <summary>
    /// Creates the service and subscribes to managed/unmanaged resolution events of the default
    /// load context and to SharpZstd's native import resolver.
    /// </summary>
    /// <param name="debugService">Source of this service's logger.</param>
    public AssemblyService(DebugService debugService)
    {
        _logger = debugService.GetLogger(this);

        AssemblyLoadContext.Default.ResolvingUnmanagedDll += LoadContextOnResolvingUnmanaged;
        AssemblyLoadContext.Default.Resolving += OnAssemblyResolving;

        ZstdImportResolver.ResolveLibrary += (name, assembly, searchPath) =>
        {
            // Only the SharpZstd native binary is handled here; anything else falls through.
            if (!name.Equals("SharpZstd.Native"))
                return IntPtr.Zero;

            _logger.Debug("RESOLVING SHARPZSTD THINK: " + name + " " + searchPath);

            GetRuntimeInfo(out var platform, out var architecture, out var extension);
            var fileName = GetDllName(platform, architecture, extension);

            return NativeLibrary.TryLoad(fileName, assembly, searchPath, out var handle)
                ? handle
                : IntPtr.Zero;
        };
    }

    /// <summary>
    /// Handler for <see cref="AssemblyLoadContext.Resolving"/>. Returns a cached assembly if one is
    /// registered under the requested simple name. Otherwise it tries each mounted
    /// <see cref="AssemblyApi"/> in order.
    /// </summary>
    /// <returns>The resolved assembly, or <c>null</c> to let resolution continue/fail.</returns>
    private Assembly? OnAssemblyResolving(AssemblyLoadContext context, AssemblyName name)
    {
        var simpleName = name.Name;
        if (string.IsNullOrEmpty(simpleName))
            return null; // Nothing to look up; probing would only try to open ".dll".

        lock (_gate)
        {
            if (_resolvingAssemblies.Contains(name.FullName))
            {
                _logger.Debug($"Already resolving {name.Name}, skipping.");
                return null; // Prevent recursive resolution
            }

            if (_assemblyCache.TryGetValue(simpleName, out var cached))
                return cached;

            _resolvingAssemblies.Add(name.FullName);
            try
            {
                // Iterate a snapshot: OnAssemblyLoaded handlers run during loading and could Mount().
                foreach (var api in _mountedApis.ToArray())
                {
                    _logger.Debug($"Resolving assembly from FileAPI: {name.Name}");
                    if (TryOpenAssembly(simpleName, api, out var assembly))
                        return assembly;
                }

                return null;
            }
            finally
            {
                _resolvingAssemblies.Remove(name.FullName);
            }
        }
    }

    /// <summary>
    /// Wraps <paramref name="fileApi"/> in an <see cref="AssemblyApi"/>. The wrapper is also
    /// registered as a source for assembly resolution.
    /// </summary>
    /// <returns>The newly mounted <see cref="AssemblyApi"/>.</returns>
    public AssemblyApi Mount(IFileApi fileApi)
    {
        var asmApi = new AssemblyApi(fileApi);
        lock (_gate)
        {
            _mountedApis.Add(asmApi);
        }

        return asmApi;
    }

    /// <summary>
    /// Reads <see cref="LoaderEntryPointAttribute"/> from <paramref name="clientAssembly"/> and
    /// instantiates the entry-point type it names via its parameterless constructor.
    /// </summary>
    /// <returns><c>true</c> and a loader when the attribute exists and its type implements
    /// <see cref="ILoaderEntryPoint"/>; otherwise <c>false</c>.</returns>
    public bool TryGetLoader(Assembly clientAssembly, [NotNullWhen(true)] out ILoaderEntryPoint? loader)
    {
        loader = null;

        var attrib = clientAssembly.GetCustomAttribute<LoaderEntryPointAttribute>();
        if (attrib == null)
        {
            _logger.Error("No LoaderEntryPointAttribute found on Robust.Client assembly!");
            return false;
        }

        var type = attrib.LoaderEntryPointType;
        if (!type.IsAssignableTo(typeof(ILoaderEntryPoint)))
        {
            _logger.Error($"Loader type '{type}' does not implement ILoaderEntryPoint!");
            return false;
        }

        loader = (ILoaderEntryPoint)Activator.CreateInstance(type)!;
        return true;
    }

    /// <summary>
    /// Opens <c>{name}.dll</c> (and <c>{name}.pdb</c>, if present) from <paramref name="assemblyApi"/>
    /// and loads it into the default load context. A newly cached assembly raises
    /// <see cref="OnAssemblyLoaded"/>.
    /// </summary>
    /// <returns><c>false</c> if the .dll could not be opened; exceptions from loading propagate.</returns>
    public bool TryOpenAssembly(string name, AssemblyApi assemblyApi, [NotNullWhen(true)] out Assembly? assembly)
    {
        if (!TryOpenAssemblyStream(name, assemblyApi, out var asm, out var pdb))
        {
            assembly = null;
            return false;
        }

        // Dispose both streams even if LoadFromStream throws (e.g. BadImageFormatException).
        using (asm)
        using (pdb)
        {
            assembly = AssemblyLoadContext.Default.LoadFromStream(asm, pdb);
        }

        _logger.Log("LOADED ASSEMBLY " + name);

        bool added;
        lock (_gate)
        {
            added = _assemblyCache.TryAdd(name, assembly);
        }

        if (added)
            OnAssemblyLoaded?.Invoke(assembly);

        return true;
    }

    /// <summary>
    /// Opens the <c>{name}.dll</c> stream and, best-effort, the matching <c>{name}.pdb</c> stream.
    /// The caller owns and must dispose the returned streams.
    /// </summary>
    /// <returns><c>true</c> when the .dll was opened; <paramref name="pdb"/> may still be <c>null</c>.</returns>
    public bool TryOpenAssemblyStream(string name, AssemblyApi assemblyApi, [NotNullWhen(true)] out Stream? asm,
        out Stream? pdb)
    {
        asm = null;
        pdb = null;

        if (!assemblyApi.TryOpen($"{name}.dll", out asm))
            return false;

        assemblyApi.TryOpen($"{name}.pdb", out pdb);
        return true;
    }

    /// <summary>
    /// Handler for <see cref="AssemblyLoadContext.ResolvingUnmanagedDll"/>. Tries to load
    /// <paramref name="unmanaged"/> from the directory containing this library's assembly.
    /// </summary>
    /// <returns>The native handle, or <see cref="IntPtr.Zero"/> when loading fails.</returns>
    private IntPtr LoadContextOnResolvingUnmanaged(Assembly assembly, string unmanaged)
    {
        // Assembly.Location is empty for single-file/in-memory loads, which makes GetDirectoryName
        // return null/empty; fall back to the application base directory instead of crashing.
        var ourDir = Path.GetDirectoryName(typeof(AssemblyApi).Assembly.Location);
        if (string.IsNullOrEmpty(ourDir))
            ourDir = AppContext.BaseDirectory;

        var libraryPath = Path.Combine(ourDir, unmanaged);
        _logger.Debug($"Loading dll lib: {libraryPath}");

        if (NativeLibrary.TryLoad(libraryPath, out var handle))
            return handle;

        _logger.Error($"Loading dll error! Not found: {libraryPath}");
        return IntPtr.Zero;
    }

    /// <summary>
    /// Builds the SharpZstd native library file name for the given runtime.
    /// </summary>
    /// <remarks>
    /// Only <paramref name="extension"/> currently affects the result; <paramref name="platform"/>
    /// and <paramref name="architecture"/> are accepted but unused.
    /// </remarks>
    public static string GetDllName(
        string platform,
        string architecture,
        string extension)
        => $"SharpZstd.Native.{extension}";

    /// <summary>
    /// Describes the current OS and process architecture in the identifiers used by
    /// <see cref="GetDllName"/>. Unknown operating systems are treated as Linux.
    /// </summary>
    /// <exception cref="PlatformNotSupportedException">The process architecture is not
    /// x64, x86, arm or arm64.</exception>
    public static void GetRuntimeInfo(
        out string platform,
        out string architecture,
        out string extension)
    {
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            platform = "win";
            extension = "dll";
        }
        else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            platform = "osx";
            extension = "dylib";
        }
        else
        {
            // Linux and any unrecognised OS.
            platform = "linux";
            extension = "so";
        }

        architecture = RuntimeInformation.ProcessArchitecture switch
        {
            Architecture.X64 => "x64",
            Architecture.X86 => "x86",
            Architecture.Arm => "arm",
            Architecture.Arm64 => "arm64",
            _ => throw new PlatformNotSupportedException("Unsupported process architecture.")
        };
    }
}