diff --git a/Nebula.Launcher/Nebula.Launcher.csproj b/Nebula.Launcher/Nebula.Launcher.csproj index b1f9424..9e75467 100644 --- a/Nebula.Launcher/Nebula.Launcher.csproj +++ b/Nebula.Launcher/Nebula.Launcher.csproj @@ -32,6 +32,7 @@ + diff --git a/Nebula.Runner/Nebula.Runner.csproj b/Nebula.Runner/Nebula.Runner.csproj index 321412a..2cf5a68 100644 --- a/Nebula.Runner/Nebula.Runner.csproj +++ b/Nebula.Runner/Nebula.Runner.csproj @@ -14,6 +14,7 @@ + diff --git a/Nebula.Runner/Services/HarmonyService.cs b/Nebula.Runner/Services/HarmonyService.cs index 691955e..5ff38f0 100644 --- a/Nebula.Runner/Services/HarmonyService.cs +++ b/Nebula.Runner/Services/HarmonyService.cs @@ -1,11 +1,10 @@ -using System.Data; using HarmonyLib; using Nebula.Shared; namespace Nebula.Runner.Services; [ServiceRegister] -public class HarmonyService(ReflectionService reflectionService) +public class HarmonyService { private HarmonyInstance? _instance; @@ -25,21 +24,6 @@ public class HarmonyService(ReflectionService reflectionService) throw new Exception(); _instance = new HarmonyInstance(); - UnShittyWizard(); - } - - /// - /// Я помню пенис большой,Я помню пенис большой, Я помню пенис большой, я помню.... - /// - private void UnShittyWizard() - { - var method = reflectionService.GetType("Robust.Client.GameController").TypeInitializer; - _instance!.Harmony.Patch(method, new HarmonyMethod(Prefix)); - } - - static bool Prefix() - { - return false; } } diff --git a/Nebula.Runner/Services/ReflectionService.cs b/Nebula.Runner/Services/ReflectionService.cs index d1432e2..cd2eff8 100644 --- a/Nebula.Runner/Services/ReflectionService.cs +++ b/Nebula.Runner/Services/ReflectionService.cs @@ -1,33 +1,29 @@ using System.Reflection; using Nebula.Shared; -using Nebula.Shared.FileApis; using Nebula.Shared.Services; namespace Nebula.Runner.Services; [ServiceRegister] -public class ReflectionService(AssemblyService assemblyService) +public class ReflectionService { - private Dictionary _typeCache = new(); + private readonly Dictionary _typeCache = new(); + + public ReflectionService(AssemblyService assemblyService) + { + assemblyService.OnAssemblyLoaded += OnAssemblyLoaded; + } + + private void OnAssemblyLoaded(Assembly obj) + { + RegisterAssembly(obj); + } public void RegisterAssembly(Assembly robustAssembly) { _typeCache.Add(robustAssembly.GetName().Name!, robustAssembly); } - - public void RegisterRobustAssemblies(AssemblyApi engine) - { - RegisterAssembly(GetRobustAssembly("Robust.Shared", engine)); - RegisterAssembly(GetRobustAssembly("Robust.Client", engine)); - } - - private Assembly GetRobustAssembly(string assemblyName, AssemblyApi engine) - { - if(!assemblyService.TryOpenAssembly(assemblyName, engine, out var assembly)) - throw new Exception($"Unable to locate {assemblyName}.dll in engine build!"); - return assembly; - } - + public Type? GetTypeImp(string name) { foreach (var (prefix,assembly) in _typeCache) @@ -51,7 +47,7 @@ public class ReflectionService(AssemblyService assemblyService) : assembly.GetType(name)!; } - private string ExtrackPrefix(string path) + public string ExtrackPrefix(string path) { var sp = path.Split("."); return sp[0] + "." + sp[1]; diff --git a/Nebula.Runner/Services/RunnerService.cs b/Nebula.Runner/Services/RunnerService.cs index 3b1d711..777b843 100644 --- a/Nebula.Runner/Services/RunnerService.cs +++ b/Nebula.Runner/Services/RunnerService.cs @@ -1,6 +1,5 @@ using System.Globalization; using System.Reflection; -using System.Reflection.Emit; using HarmonyLib; using Nebula.Shared; using Nebula.Shared.Models; @@ -18,9 +17,10 @@ public sealed class RunnerService( EngineService engineService, AssemblyService assemblyService, ReflectionService reflectionService, - HarmonyService harmonyService) + HarmonyService harmonyService, + ScriptService scriptService) { - private ILogger _logger = debugService.GetLogger("RunnerService"); + private readonly ILogger _logger = debugService.GetLogger("RunnerService"); private bool MetricEnabled = false; //TODO: ADD METRIC THINKS LATER public async Task Run(string[] runArgs, RobustBuildInfo buildInfo, IRedialApi redialApi, @@ -58,6 +58,24 @@ public sealed class RunnerService( var args = new MainArgs(runArgs, engine, redialApi, extraMounts); + + var assemblyManifest = hashApi.Manifest.Where(p => + p.Key.StartsWith("Assemblies/")) + .Select(p => + { + return p.Value with { Path = Path.GetFileNameWithoutExtension(p.Key) }; + }).ToList(); + + var assembliesHash = contentService.CreateHashApi(assemblyManifest); + + var contentAssemblyApi = assemblyService.Mount(assembliesHash); + + foreach (var file in contentAssemblyApi.AllFiles.Where(p => Path.GetExtension(p) == ".dll")) + { + var newExt = Path.GetFileNameWithoutExtension(file); + if(!assemblyService.TryOpenAssembly(newExt, contentAssemblyApi, out _)) throw new Exception("Assembly not found: " + newExt); + } + if (!assemblyService.TryOpenAssembly(varService.GetConfigValue(CurrentConVar.RobustAssemblyName)!, engine, out var clientAssembly)) throw new Exception("Unable to locate Robust.Client.dll in engine build!"); @@ -68,7 +86,6 @@ public sealed class RunnerService( if(!assemblyService.TryOpenAssembly("Prometheus.NetStandard", engine, out var prometheusAssembly)) return; - reflectionService.RegisterRobustAssemblies(engine); harmonyService.CreateInstance(); IDisposable? metricServer = null; @@ -78,12 +95,18 @@ public sealed class RunnerService( MetricsEnabledPatcher.ApplyPatch(reflectionService, harmonyService); metricServer = RunHelper.RunMetric(prometheusAssembly); } - + + scriptService.LoadScripts(); await Task.Run(() => loader.Main(args), cancellationToken); metricServer?.Dispose(); } + + private void CacheAssembly() + { + + } } public static class MetricsEnabledPatcher diff --git a/Nebula.Runner/Services/ScriptService.cs b/Nebula.Runner/Services/ScriptService.cs new file mode 100644 index 0000000..3295455 --- /dev/null +++ b/Nebula.Runner/Services/ScriptService.cs @@ -0,0 +1,181 @@ +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; +using HarmonyLib; +using Nebula.Shared; +using Nebula.Shared.FileApis; +using Nebula.Shared.Services; +using NLua; + +namespace Nebula.Runner.Services; + +[ServiceRegister] +public class ScriptService +{ + private readonly HarmonyService _harmonyService; + private readonly ReflectionService _reflectionService; + private readonly AssemblyService _assemblyService; + + private readonly FileApi _scriptFileApi; + + private static Dictionary _scriptCache = []; + private static Dictionary _assemblyLoadingQuery = []; + + public ScriptService(HarmonyService harmonyService, ReflectionService reflectionService, FileService fileService, AssemblyService assemblyService) + { + _harmonyService = harmonyService; + _reflectionService = reflectionService; + _assemblyService = assemblyService; + + _scriptFileApi = fileService.CreateFileApi("scripts"); + _assemblyService.OnAssemblyLoaded += OnAssemblyLoaded; + } + + private void OnAssemblyLoaded(Assembly obj) + { + var objName = obj.GetName().Name ?? string.Empty; + if (!_assemblyLoadingQuery.TryGetValue(objName, out var a)) return; + Console.WriteLine("Inject assembly: " + objName); + a(); + _assemblyLoadingQuery.Remove(objName); + } + + public void LoadScripts() + { + Console.WriteLine("Loading scripts... " + _scriptFileApi.EnumerateDirectories("").Count()); + foreach (var dir in _scriptFileApi.EnumerateDirectories("")) + { + LoadScript(dir); + } + } + + public void LoadScript(string name) + { + Console.WriteLine($"Reading script {name}"); + var manifests = ReadManifest(name); + + foreach (var entry in manifests) + { + if (entry.TypeInitializer.HasValue) LoadTypeInitializer(entry.TypeInitializer.Value, name); + if (entry.Method.HasValue) LoadMethod(entry.Method.Value, name); + } + } + + private void LoadTypeInitializer(ScriptMethodInjectItem item, string name) + { + Console.WriteLine($"Loading Initializer injection {name}..."); + var assemblyName = _reflectionService.ExtrackPrefix(item.Method.Class); + + if (!_assemblyService.Assemblies.Select(a => a.GetName().Name).Contains(assemblyName)) + { + _assemblyLoadingQuery.Add(assemblyName, () => LoadTypeInitializer(item, name)); + return; + } + + var targetType = _reflectionService.GetType(item.Method.Class); + var method = targetType.TypeInitializer; + InitialiseShared(method!, name, item); + } + + private void LoadMethod(ScriptMethodInjectItem item, string name) + { + Console.WriteLine($"Loading method injection {name}..."); + var assemblyName = _reflectionService.ExtrackPrefix(item.Method.Class); + + if (!_assemblyService.Assemblies.Select(a => a.GetName().Name).Contains(assemblyName)) + { + _assemblyLoadingQuery.Add(assemblyName, () => LoadMethod(item, name)); + return; + } + + var targetType = _reflectionService.GetType(item.Method.Class); + var method = targetType.GetMethod(item.Method.Method, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); + InitialiseShared(method!, name, item); + } + + private void InitialiseShared(MethodBase method, string scriptName, ScriptMethodInjectItem item) + { + var scriptCode = File.ReadAllText(Path.Combine(_scriptFileApi.RootPath, scriptName, item.Script.LuaFile)); + + var methodInfo = method as MethodInfo; + HarmonyMethod dynamicPatch; + + if (methodInfo == null || methodInfo.ReturnType == typeof(void)) + dynamicPatch = new HarmonyMethod(typeof(ScriptService).GetMethod(nameof(LuaPrefix), BindingFlags.Static | BindingFlags.NonPublic)); + else + dynamicPatch = new HarmonyMethod(typeof(ScriptService).GetMethod(nameof(LuaPrefixResult), BindingFlags.Static | BindingFlags.NonPublic)); + + _scriptCache[method] = new ScriptManifestDict(scriptCode, item); + + _harmonyService.Instance.Harmony.Patch(method, prefix: dynamicPatch); + Console.WriteLine($"Injected {scriptName}"); + } + + private ScriptEntry[] ReadManifest(string scriptName) + { + if(!_scriptFileApi.TryOpen(Path.Join(scriptName, "MANIFEST.json"), out var stream)) + throw new FileNotFoundException(Path.Join(scriptName, "MANIFEST.json") + " not found manifest!"); + + return JsonSerializer.Deserialize(stream) ?? []; + } + + private static bool LuaPrefix(MethodBase __originalMethod, object __instance) + { + if (!_scriptCache.TryGetValue(__originalMethod, out var luaCode)) + return true; + + using var lua = new Lua(); + + lua["this"] = __instance; + + var results = lua.DoString(luaCode.Code); + + if (results is { Length: > 0 } && results[0] is bool b) + return b; + + return luaCode.ScriptMethodInjectItem.ContinueAfterInject; + } + + private static bool LuaPrefixResult(MethodBase __originalMethod, object __instance, ref object __result) + { + if (!_scriptCache.TryGetValue(__originalMethod, out var luaCode)) + return true; + + using var lua = new Lua(); + + lua["this"] = __instance; + lua["result"] = __result; + + var results = lua.DoString(luaCode.Code); + + if (lua["result"] != null) + __result = lua["result"]; + + if (results is { Length: > 0 } && results[0] is bool b) + return b; + + return luaCode.ScriptMethodInjectItem.ContinueAfterInject; + } +} + +public record struct ScriptManifestDict(string Code, ScriptMethodInjectItem ScriptMethodInjectItem); + +public record struct ScriptEntry( + [property: JsonPropertyName("method")] ScriptMethodInjectItem? Method, + [property: JsonPropertyName("type_initializer")] ScriptMethodInjectItem? TypeInitializer + ); + +public record struct ScriptMethodInjectItem( + [property: JsonPropertyName("method")] ScriptMethodInfo Method, + [property: JsonPropertyName("continue")] bool ContinueAfterInject, + [property: JsonPropertyName("script")] LuaMethodEntry Script + ); + +public record struct ScriptMethodInfo( + [property: JsonPropertyName("class")] string Class, + [property: JsonPropertyName("method")] string Method + ); + +public record struct LuaMethodEntry( + [property: JsonPropertyName("lua_file")] string LuaFile + ); \ No newline at end of file diff --git a/Nebula.Shared/FileApis/FileApi.cs b/Nebula.Shared/FileApis/FileApi.cs index c257c03..f426477 100644 --- a/Nebula.Shared/FileApis/FileApi.cs +++ b/Nebula.Shared/FileApis/FileApi.cs @@ -75,6 +75,11 @@ public sealed class FileApi : IReadWriteFileApi return File.Exists(fullPath); } + public IEnumerable EnumerateDirectories(string path) + { + return Directory.GetDirectories(Path.Join(RootPath, path)).Select(p=>p.Replace(RootPath,"").Substring(1)); + } + private IEnumerable GetAllFiles(){ if(!Directory.Exists(RootPath)) return []; diff --git a/Nebula.Shared/Services/AssemblyService.cs b/Nebula.Shared/Services/AssemblyService.cs index db534e9..ae1a24e 100644 --- a/Nebula.Shared/Services/AssemblyService.cs +++ b/Nebula.Shared/Services/AssemblyService.cs @@ -12,15 +12,23 @@ namespace Nebula.Shared.Services; [ServiceRegister] public class AssemblyService { - private readonly List _assemblies = new(); + private readonly Dictionary _assemblyCache = new(); private readonly ILogger _logger; private readonly HashSet _resolvingAssemblies = new(); + private List _mountedApis = []; + + public Action? OnAssemblyLoaded; + public IReadOnlyList Assemblies => _assemblyCache.Values.ToList().AsReadOnly(); + public AssemblyService(DebugService debugService) { _logger = debugService.GetLogger(this); + AssemblyLoadContext.Default.ResolvingUnmanagedDll += LoadContextOnResolvingUnmanaged; + AssemblyLoadContext.Default.Resolving += (context, name) => OnAssemblyResolving(context, name); + ZstdImportResolver.ResolveLibrary += (name, assembly1, path) => { if (name.Equals("SharpZstd.Native")) @@ -36,21 +44,41 @@ public class AssemblyService }; } - public IReadOnlyList Assemblies => _assemblies; + private Assembly? OnAssemblyResolving(AssemblyLoadContext context, AssemblyName name) + { + if (_resolvingAssemblies.Contains(name.FullName)) + { + _logger.Debug($"Already resolving {name.Name}, skipping."); + return null; // Prevent recursive resolution + } + + Assembly? assembly; + + if (_assemblyCache.TryGetValue(name.Name ?? "", out assembly)) + { + return assembly; + } + + foreach (var api in _mountedApis) + { + if((assembly = OnAssemblyResolving(context, name, api)) != null) + return assembly; + } + + return null; + } public AssemblyApi Mount(IFileApi fileApi) { var asmApi = new AssemblyApi(fileApi); - AssemblyLoadContext.Default.Resolving += (context, name) => OnAssemblyResolving(context, name, asmApi); - AssemblyLoadContext.Default.ResolvingUnmanagedDll += LoadContextOnResolvingUnmanaged; - + _mountedApis.Add(asmApi); return asmApi; } public bool TryGetLoader(Assembly clientAssembly, [NotNullWhen(true)] out ILoaderEntryPoint? loader) { loader = null; - // Find ILoaderEntryPoint with the LoaderEntryPointAttribute + var attrib = clientAssembly.GetCustomAttribute(); if (attrib == null) { @@ -79,9 +107,11 @@ public class AssemblyService assembly = AssemblyLoadContext.Default.LoadFromStream(asm, pdb); _logger.Log("LOADED ASSEMBLY " + name); - - - if (!_assemblies.Contains(assembly)) _assemblies.Add(assembly); + + if (_assemblyCache.TryAdd(name, assembly)) + { + OnAssemblyLoaded?.Invoke(assembly); + } asm.Dispose(); pdb?.Dispose(); @@ -103,21 +133,18 @@ public class AssemblyService private Assembly? OnAssemblyResolving(AssemblyLoadContext context, AssemblyName name, AssemblyApi assemblyApi) { - if (_resolvingAssemblies.Contains(name.FullName)) + lock (_resolvingAssemblies) { - _logger.Debug($"Already resolving {name.Name}, skipping."); - return null; // Prevent recursive resolution - } - - try - { - _resolvingAssemblies.Add(name.FullName); - _logger.Debug($"Resolving assembly from FileAPI: {name.Name}"); - return TryOpenAssembly(name.Name!, assemblyApi, out var assembly) ? assembly : null; - } - finally - { - _resolvingAssemblies.Remove(name.FullName); + try + { + _resolvingAssemblies.Add(name.FullName); + _logger.Debug($"Resolving assembly from FileAPI: {name.Name}"); + return TryOpenAssembly(name.Name!, assemblyApi, out var assembly) ? assembly : null; + } + finally + { + _resolvingAssemblies.Remove(name.FullName); + } } } diff --git a/Nebula.Shared/Services/FileService.cs b/Nebula.Shared/Services/FileService.cs index 10d95d4..0bdc555 100644 --- a/Nebula.Shared/Services/FileService.cs +++ b/Nebula.Shared/Services/FileService.cs @@ -23,7 +23,7 @@ public class FileService Directory.CreateDirectory(RootPath); } - public IReadWriteFileApi CreateFileApi(string path) + public FileApi CreateFileApi(string path) { _logger.Debug($"Creating file api for {path}"); return new FileApi(Path.Join(RootPath, path)); diff --git a/Nebula.sln.DotSettings.user b/Nebula.sln.DotSettings.user index 114714c..a431ec7 100644 --- a/Nebula.sln.DotSettings.user +++ b/Nebula.sln.DotSettings.user @@ -1,5 +1,9 @@  ForceIncluded + ForceIncluded + ForceIncluded + ForceIncluded + ForceIncluded ForceIncluded ForceIncluded ForceIncluded @@ -38,6 +42,7 @@ ForceIncluded ForceIncluded ForceIncluded + ForceIncluded ForceIncluded ForceIncluded ForceIncluded