- tweak: refactor funny

This commit is contained in:
2025-01-05 17:05:23 +03:00
parent 5b24f915a2
commit 8619e248fd
67 changed files with 485 additions and 492 deletions

View File

@@ -0,0 +1,25 @@
using Nebula.Shared.Services;
namespace Nebula.Shared;
public static class CurrentConVar
{
public static readonly ConVar<string> EngineManifestUrl =
ConVarBuilder.Build("engine.manifestUrl", "https://robust-builds.cdn.spacestation14.com/manifest.json");
public static readonly ConVar<string> EngineModuleManifestUrl =
ConVarBuilder.Build("engine.moduleManifestUrl", "https://robust-builds.cdn.spacestation14.com/modules.json");
public static readonly ConVar<int> ManifestDownloadProtocolVersion =
ConVarBuilder.Build("engine.manifestDownloadProtocolVersion", 1);
public static readonly ConVar<string> RobustAssemblyName =
ConVarBuilder.Build("engine.robustAssemblyName", "Robust.Client");
public static readonly ConVar<string[]> Hub = ConVarBuilder.Build<string[]>("launcher.hub", [
"https://hub.spacestation14.com/api/servers"
]);
public static readonly ConVar<string[]> AuthServers = ConVarBuilder.Build<string[]>("launcher.authServers", [
"https://auth.spacestation14.com/"
]);
public static readonly ConVar<AuthLoginPassword[]> AuthProfiles = ConVarBuilder.Build<AuthLoginPassword[]>("auth.profiles", []);
public static readonly ConVar<AuthLoginPassword> AuthCurrent = ConVarBuilder.Build<AuthLoginPassword>("auth.current");
}

View File

@@ -0,0 +1,20 @@
using Robust.LoaderApi;
namespace Nebula.Shared.FileApis;
public class AssemblyApi : IFileApi
{
private readonly IFileApi _root;
public AssemblyApi(IFileApi root)
{
_root = root;
}
public bool TryOpen(string path, out Stream? stream)
{
return _root.TryOpen(path, out stream);
}
public IEnumerable<string> AllFiles => _root.AllFiles;
}

View File

@@ -0,0 +1,54 @@
using Nebula.Shared.FileApis.Interfaces;
namespace Nebula.Shared.FileApis;
public class FileApi : IReadWriteFileApi
{
public string RootPath;
public FileApi(string rootPath)
{
RootPath = rootPath;
}
public bool TryOpen(string path, out Stream? stream)
{
if (File.Exists(Path.Join(RootPath, path)))
{
stream = File.OpenRead(Path.Join(RootPath, path));
return true;
}
stream = null;
return false;
}
public bool Save(string path, Stream input)
{
var currPath = Path.Join(RootPath, path);
var dirInfo = new DirectoryInfo(Path.GetDirectoryName(currPath));
if (!dirInfo.Exists) dirInfo.Create();
using var stream = File.OpenWrite(currPath);
input.CopyTo(stream);
stream.Flush();
stream.Close();
return true;
}
public bool Remove(string path)
{
if (!Has(path)) return false;
File.Delete(Path.Join(RootPath, path));
return true;
}
public bool Has(string path)
{
var currPath = Path.Join(RootPath, path);
return File.Exists(currPath);
}
public IEnumerable<string> AllFiles => Directory.EnumerateFiles(RootPath, "*.*", SearchOption.AllDirectories);
}

View File

@@ -0,0 +1,30 @@
using Nebula.Shared.Models;
using Robust.LoaderApi;
namespace Nebula.Shared.FileApis;
public class HashApi : IFileApi
{
private readonly IFileApi _fileApi;
public Dictionary<string, RobustManifestItem> Manifest;
public HashApi(List<RobustManifestItem> manifest, IFileApi fileApi)
{
_fileApi = fileApi;
Manifest = new Dictionary<string, RobustManifestItem>();
foreach (var item in manifest) Manifest.TryAdd(item.Path, item);
}
public bool TryOpen(string path, out Stream? stream)
{
if (path[0] == '/') path = path.Substring(1);
if (Manifest.TryGetValue(path, out var a) && _fileApi.TryOpen(a.Hash, out stream))
return true;
stream = null;
return false;
}
public IEnumerable<string> AllFiles => Manifest.Keys;
}

View File

@@ -0,0 +1,7 @@
using Robust.LoaderApi;
namespace Nebula.Shared.FileApis.Interfaces;
public interface IReadWriteFileApi : IFileApi, IWriteFileApi
{
}

View File

@@ -0,0 +1,8 @@
namespace Nebula.Shared.FileApis.Interfaces;
public interface IWriteFileApi
{
public bool Save(string path, Stream input);
public bool Remove(string path);
public bool Has(string path);
}

View File

@@ -0,0 +1,61 @@
using System.Diagnostics.CodeAnalysis;
using System.IO.Compression;
using System.Runtime.InteropServices;
using Robust.LoaderApi;
namespace Nebula.Shared.FileApis;
public sealed class ZipFileApi : IFileApi
{
private readonly ZipArchive _archive;
private readonly string? _prefix;
public ZipFileApi(ZipArchive archive, string? prefix)
{
_archive = archive;
_prefix = prefix;
}
public bool TryOpen(string path, [NotNullWhen(true)] out Stream? stream)
{
var entry = _archive.GetEntry(_prefix != null ? _prefix + path : path);
if (entry == null)
{
stream = null;
return false;
}
stream = new MemoryStream();
lock (_archive)
{
using var zipStream = entry.Open();
zipStream.CopyTo(stream);
}
stream.Position = 0;
return true;
}
public IEnumerable<string> AllFiles
{
get
{
if (_prefix != null)
return _archive.Entries
.Where(e => e.Name != "" && e.FullName.StartsWith(_prefix))
.Select(e => e.FullName[_prefix.Length..]);
return _archive.Entries
.Where(e => e.Name != "")
.Select(entry => entry.FullName);
}
}
public static ZipFileApi FromPath(string path)
{
var zipArchive = new ZipArchive(File.OpenRead(path), ZipArchiveMode.Read);
var prefix = "";
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) prefix = "Space Station 14.app/Contents/Resources/";
return new ZipFileApi(zipArchive, prefix);
}
}

View File

@@ -0,0 +1,8 @@
namespace Nebula.Shared.Models.Auth;
public sealed record AuthenticateRequest(string? Username, Guid? UserId, string Password, string? TfaCode = null)
{
public AuthenticateRequest(string username, string password) : this(username, null, password)
{
}
}

View File

@@ -0,0 +1,3 @@
namespace Nebula.Shared.Models.Auth;
public sealed record AuthenticateResponse(string Token, string Username, Guid UserId, DateTimeOffset ExpireTime);

View File

@@ -0,0 +1,13 @@
namespace Nebula.Shared.Models.Auth;
public class LoginInfo
{
public Guid UserId { get; set; }
public string Username { get; set; } = default!;
public LoginToken Token { get; set; }
public override string ToString()
{
return $"{Username}/{UserId}";
}
}

View File

@@ -0,0 +1,13 @@
namespace Nebula.Shared.Models.Auth;
public readonly struct LoginToken
{
public readonly string Token;
public readonly DateTimeOffset ExpireTime;
public LoginToken(string token, DateTimeOffset expireTime)
{
Token = token;
ExpireTime = expireTime;
}
}

View File

@@ -0,0 +1,12 @@
namespace Nebula.Shared.Models;
public enum ContentCompressionScheme
{
None = 0,
Deflate = 1,
/// <summary>
/// ZStandard compression. In the future may use SS14 specific dictionary IDs in the frame header.
/// </summary>
ZStd = 2
}

View File

@@ -0,0 +1,13 @@
namespace Nebula.Shared.Models;
[Flags]
public enum DownloadStreamHeaderFlags
{
None = 0,
/// <summary>
/// If this flag is set on the download stream, individual files have been pre-compressed by the server.
/// This means each file has a compression header, and the launcher should not attempt to compress files itself.
/// </summary>
PreCompressed = 1 << 0
}

View File

@@ -0,0 +1,2 @@
namespace Nebula.Shared.Models;
public record ListItemTemplate(Type ModelType, string IconKey, string Label);

View File

@@ -0,0 +1,19 @@
using Robust.LoaderApi;
namespace Nebula.Shared.Models;
public sealed class MainArgs : IMainArgs
{
public MainArgs(string[] args, IFileApi fileApi, IRedialApi? redialApi, IEnumerable<ApiMount>? apiMounts)
{
Args = args;
FileApi = fileApi;
RedialApi = redialApi;
ApiMounts = apiMounts;
}
public string[] Args { get; }
public IFileApi FileApi { get; }
public IRedialApi? RedialApi { get; }
public IEnumerable<ApiMount>? ApiMounts { get; }
}

View File

@@ -0,0 +1,8 @@
namespace Nebula.Shared.Models;
public class RobustBuildInfo
{
public ServerInfo BuildInfo;
public RobustManifestInfo RobustManifestInfo;
public RobustUrl Url;
}

View File

@@ -0,0 +1,3 @@
namespace Nebula.Shared.Models;
public record struct RobustManifestInfo(Uri ManifestUri, Uri DownloadUri, string Hash);

View File

@@ -0,0 +1,3 @@
namespace Nebula.Shared.Models;
public record struct RobustManifestItem(string Hash, string Path, int Id);

View File

@@ -0,0 +1,73 @@
using System.Text.Json.Serialization;
namespace Nebula.Shared.Models;
public sealed record AuthInfo(
[property: JsonPropertyName("mode")] string Mode,
[property: JsonPropertyName("public_key")] string PublicKey);
public sealed record BuildInfo(
[property: JsonPropertyName("engine_version")] string EngineVersion,
[property: JsonPropertyName("fork_id")] string ForkId,
[property: JsonPropertyName("version")] string Version,
[property: JsonPropertyName("download_url")] string DownloadUrl,
[property: JsonPropertyName("manifest_download_url")] string ManifestDownloadUrl,
[property: JsonPropertyName("manifest_url")] string ManifestUrl,
[property: JsonPropertyName("acz")] bool Acz,
[property: JsonPropertyName("hash")] string Hash,
[property: JsonPropertyName("manifest_hash")] string ManifestHash);
public sealed record ServerLink(
[property: JsonPropertyName("name")] string Name,
[property: JsonPropertyName("icon")] string Icon,
[property: JsonPropertyName("url")] string Url);
public sealed record ServerInfo(
[property: JsonPropertyName("connect_address")] string ConnectAddress,
[property: JsonPropertyName("auth")] AuthInfo Auth,
[property: JsonPropertyName("build")] BuildInfo Build,
[property: JsonPropertyName("desc")] string Desc,
[property: JsonPropertyName("links")] List<ServerLink> Links);
public sealed record EngineVersionInfo(
[property: JsonPropertyName("insecure")] bool Insecure,
[property: JsonPropertyName("redirect")] string? RedirectVersion,
[property: JsonPropertyName("platforms")] Dictionary<string, EngineBuildInfo> Platforms);
public sealed class EngineBuildInfo
{
[JsonInclude] [JsonPropertyName("sha256")]
public string Sha256 = default!;
[JsonInclude] [JsonPropertyName("sig")]
public string Signature = default!;
[JsonInclude] [JsonPropertyName("url")]
public string Url = default!;
}
public sealed record ServerHubInfo(
[property: JsonPropertyName("address")] string Address,
[property: JsonPropertyName("statusData")] ServerStatus StatusData,
[property: JsonPropertyName("inferredTags")] List<string> InferredTags);
public sealed record ServerStatus(
[property: JsonPropertyName("map")] string Map,
[property: JsonPropertyName("name")] string Name,
[property: JsonPropertyName("tags")] List<string> Tags,
[property: JsonPropertyName("preset")] string Preset,
[property: JsonPropertyName("players")] int Players,
[property: JsonPropertyName("round_id")] int RoundId,
[property: JsonPropertyName("run_level")] int RunLevel,
[property: JsonPropertyName("panic_bunker")] bool PanicBunker,
[property: JsonPropertyName("round_start_time")] DateTime? RoundStartTime,
[property: JsonPropertyName("soft_max_players")] int SoftMaxPlayers);
public sealed record ModulesInfo(
[property: JsonPropertyName("modules")] Dictionary<string, Module> Modules);
public sealed record Module(
[property: JsonPropertyName("versions")] Dictionary<string, ModuleVersionInfo> Versions);
public sealed record ModuleVersionInfo(
[property: JsonPropertyName("platforms")] Dictionary<string, EngineBuildInfo> Platforms);

View File

@@ -0,0 +1,63 @@
using Nebula.Shared.Utils;
namespace Nebula.Shared.Models;
public class RobustUrl
{
public RobustUrl(string url)
{
if (!UriHelper.TryParseSs14Uri(url, out var uri))
throw new Exception("Invalid scheme");
Uri = uri;
HttpUri = UriHelper.GetServerApiAddress(Uri);
}
public Uri Uri { get; }
public Uri HttpUri { get; }
public RobustPath InfoUri => new(this, "info");
public RobustPath StatusUri => new(this, "status");
public override string ToString()
{
return Uri.ToString();
}
public static implicit operator Uri(RobustUrl url)
{
return url.HttpUri;
}
public static explicit operator RobustUrl(string url)
{
return new RobustUrl(url);
}
public static explicit operator RobustUrl(Uri uri)
{
return new RobustUrl(uri.ToString());
}
}
public class RobustPath
{
public string Path;
public RobustUrl Url;
public RobustPath(RobustUrl url, string path)
{
Url = url;
Path = path;
}
public override string ToString()
{
return ((Uri)this).ToString();
}
public static implicit operator Uri(RobustPath path)
{
return new Uri(path.Url, path.Url.HttpUri.PathAndQuery + path.Path);
}
}

View File

@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<EmbeddedResource Include="Utils\runtime.json" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="9.0.0" />
<PackageReference Include="libsodium" Version="1.0.20" />
<PackageReference Include="Robust.Natives" Version="0.1.1" />
<PackageReference Include="SharpZstd.Interop" Version="1.5.6" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Robust.LoaderApi\Robust.LoaderApi\Robust.LoaderApi.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,44 @@
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
namespace Nebula.Shared;
public static class ServiceExt
{
public static void AddServices(this IServiceCollection services)
{
foreach (var (type, inference) in GetServicesWithHelpAttribute(Assembly.GetExecutingAssembly()))
{
if (inference is null)
{
services.AddSingleton(type);
}
else
{
services.AddSingleton(inference, type);
}
}
}
private static IEnumerable<(Type,Type?)> GetServicesWithHelpAttribute(Assembly assembly) {
foreach(Type type in assembly.GetTypes())
{
var attr = type.GetCustomAttribute<ServiceRegisterAttribute>();
if (attr is not null) {
yield return (type, attr.Inference);
}
}
}
}
public sealed class ServiceRegisterAttribute : Attribute
{
public Type? Inference { get; }
public bool IsSingleton { get; }
public ServiceRegisterAttribute(Type? inference = null, bool isSingleton = true)
{
IsSingleton = isSingleton;
Inference = inference;
}
}

View File

@@ -0,0 +1,110 @@
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;
using Nebula.Shared.FileApis;
using Robust.LoaderApi;
namespace Nebula.Shared.Services;
[ServiceRegister]
public class AssemblyService
{
private readonly Dictionary<string,Assembly> _assemblies = new();
private readonly DebugService _debugService;
public AssemblyService(DebugService debugService)
{
_debugService = debugService;
}
//public IReadOnlyList<Assembly> Assemblies => _assemblies;
public AssemblyApi Mount(IFileApi fileApi, string apiName = "")
{
var asmApi = new AssemblyApi(fileApi);
AssemblyLoadContext.Default.Resolving += (context, name) => OnAssemblyResolving(context, name, asmApi, apiName);
AssemblyLoadContext.Default.ResolvingUnmanagedDll += LoadContextOnResolvingUnmanaged;
return asmApi;
}
public bool TryGetLoader(Assembly clientAssembly, [NotNullWhen(true)] out ILoaderEntryPoint? loader)
{
loader = null;
// Find ILoaderEntryPoint with the LoaderEntryPointAttribute
var attrib = clientAssembly.GetCustomAttribute<LoaderEntryPointAttribute>();
if (attrib == null)
{
Console.WriteLine("No LoaderEntryPointAttribute found on Robust.Client assembly!");
return false;
}
var type = attrib.LoaderEntryPointType;
if (!type.IsAssignableTo(typeof(ILoaderEntryPoint)))
{
Console.WriteLine("Loader type '{0}' does not implement ILoaderEntryPoint!", type);
return false;
}
loader = (ILoaderEntryPoint)Activator.CreateInstance(type)!;
return true;
}
public bool TryOpenAssembly(string name, AssemblyApi assemblyApi, [NotNullWhen(true)] out Assembly? assembly)
{
if (_assemblies.TryGetValue(name, out assembly))
{
return true;
}
if (!TryOpenAssemblyStream(name, assemblyApi, out var asm, out var pdb))
{
assembly = null;
return false;
}
assembly = AssemblyLoadContext.Default.LoadFromStream(asm, pdb);
_debugService.Log("LOADED ASSEMBLY " + name);
_assemblies.Add(name, assembly);
asm.Dispose();
pdb?.Dispose();
return true;
}
public bool TryOpenAssemblyStream(string name, AssemblyApi assemblyApi, [NotNullWhen(true)] out Stream? asm,
out Stream? pdb)
{
asm = null;
pdb = null;
if (!assemblyApi.TryOpen($"{name}.dll", out asm))
return false;
assemblyApi.TryOpen($"{name}.pdb", out pdb);
return true;
}
private Assembly? OnAssemblyResolving(AssemblyLoadContext context, AssemblyName name, AssemblyApi assemblyApi,
string apiName)
{
_debugService.Debug($"Resolving assembly from {apiName}: {name.Name}");
return TryOpenAssembly(name.Name!, assemblyApi, out var assembly) ? assembly : null;
}
private IntPtr LoadContextOnResolvingUnmanaged(Assembly assembly, string unmanaged)
{
var ourDir = Path.GetDirectoryName(typeof(AssemblyApi).Assembly.Location);
var a = Path.Combine(ourDir!, unmanaged);
_debugService.Debug($"Loading dll lib: {a}");
if (NativeLibrary.TryLoad(a, out var handle))
return handle;
return IntPtr.Zero;
}
}

View File

@@ -0,0 +1,65 @@
using System.Net.Http.Headers;
using Nebula.Shared.Models.Auth;
namespace Nebula.Shared.Services;
[ServiceRegister]
public partial class AuthService(
RestService restService,
DebugService debugService,
CancellationService cancellationService)
{
private readonly HttpClient _httpClient = new();
public CurrentAuthInfo? SelectedAuth { get; internal set; }
public string Reason = "";
public async Task<bool> Auth(AuthLoginPassword authLoginPassword)
{
var authServer = authLoginPassword.AuthServer;
var login = authLoginPassword.Login;
var password = authLoginPassword.Password;
debugService.Debug($"Auth to {authServer}api/auth/authenticate {login}");
var authUrl = new Uri($"{authServer}api/auth/authenticate");
var result =
await restService.PostAsync<AuthenticateResponse, AuthenticateRequest>(
new AuthenticateRequest(login, password), authUrl, cancellationService.Token);
if (result.Value is null)
{
Reason = result.Message;
return false;
}
SelectedAuth = new CurrentAuthInfo(result.Value.UserId,
new LoginToken(result.Value.Token, result.Value.ExpireTime), authLoginPassword);
return true;
}
public void ClearAuth()
{
SelectedAuth = null;
}
public async Task<bool> EnsureToken()
{
if (SelectedAuth is null) return false;
var authUrl = new Uri($"{SelectedAuth.AuthLoginPassword.AuthServer}api/auth/ping");
using var requestMessage = new HttpRequestMessage(HttpMethod.Get, authUrl);
requestMessage.Headers.Authorization = new AuthenticationHeaderValue("SS14Auth", SelectedAuth.Token.Token);
using var resp = await _httpClient.SendAsync(requestMessage, cancellationService.Token);
if (!resp.IsSuccessStatusCode) SelectedAuth = null;
return resp.IsSuccessStatusCode;
}
}
public sealed record CurrentAuthInfo(Guid UserId, LoginToken Token, AuthLoginPassword AuthLoginPassword);
public record AuthLoginPassword(string Login, string Password, string AuthServer);

View File

@@ -0,0 +1,14 @@
namespace Nebula.Shared.Services;
[ServiceRegister]
public class CancellationService
{
private CancellationTokenSource _cancellationTokenSource = new();
public CancellationToken Token => _cancellationTokenSource.Token;
public void Cancel()
{
_cancellationTokenSource.Cancel();
_cancellationTokenSource = new CancellationTokenSource();
}
}

View File

@@ -0,0 +1,116 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
namespace Nebula.Shared.Services;
public class ConVar<T>
{
public string Name { get; }
public Type Type => typeof(T);
public T? DefaultValue { get; }
public ConVar(string name, T? defaultValue = default)
{
Name = name ?? throw new ArgumentNullException(nameof(name));
DefaultValue = defaultValue;
}
}
public static class ConVarBuilder
{
public static ConVar<T> Build<T>(string name, T? defaultValue = default)
{
if (string.IsNullOrWhiteSpace(name))
throw new ArgumentException("ConVar name cannot be null or whitespace.", nameof(name));
return new ConVar<T>(name, defaultValue);
}
}
[ServiceRegister]
public class ConfigurationService
{
private readonly FileService _fileService;
private readonly DebugService _debugService;
public ConfigurationService(FileService fileService, DebugService debugService)
{
_fileService = fileService ?? throw new ArgumentNullException(nameof(fileService));
_debugService = debugService ?? throw new ArgumentNullException(nameof(debugService));
}
public T? GetConfigValue<T>(ConVar<T> conVar)
{
ArgumentNullException.ThrowIfNull(conVar);
try
{
if (_fileService.ConfigurationApi.TryOpen(GetFileName(conVar), out var stream))
{
using (stream)
{
var obj = JsonSerializer.Deserialize<T>(stream);
if (obj != null)
{
_debugService.Log($"Successfully loaded config: {conVar.Name}");
return obj;
}
}
}
}
catch (Exception e)
{
_debugService.Error($"Error loading config for {conVar.Name}: {e.Message}");
}
_debugService.Log($"Using default value for config: {conVar.Name}");
return conVar.DefaultValue;
}
public void SetConfigValue<T>(ConVar<T> conVar, T value)
{
ArgumentNullException.ThrowIfNull(conVar);
if (value == null) throw new ArgumentNullException(nameof(value));
if (!conVar.Type.IsInstanceOfType(value))
{
_debugService.Error($"Type mismatch for config {conVar.Name}. Expected {conVar.Type}, got {value.GetType()}.");
return;
}
try
{
_debugService.Log($"Saving config: {conVar.Name}");
var serializedData = JsonSerializer.Serialize(value);
using var stream = new MemoryStream();
using var writer = new StreamWriter(stream);
writer.Write(serializedData);
writer.Flush();
stream.Seek(0, SeekOrigin.Begin);
_fileService.ConfigurationApi.Save(GetFileName(conVar), stream);
}
catch (Exception e)
{
_debugService.Error($"Error saving config for {conVar.Name}: {e.Message}");
}
}
private static string GetFileName<T>(ConVar<T> conVar)
{
return $"{conVar.Name}.json";
}
}
public static class ConfigExtensions
{
public static bool TryGetConfigValue<T>(this ConfigurationService configurationService, ConVar<T> conVar, [NotNullWhen(true)] out T? value)
{
ArgumentNullException.ThrowIfNull(configurationService);
ArgumentNullException.ThrowIfNull(conVar);
value = configurationService.GetConfigValue(conVar);
return value != null;
}
}

View File

@@ -0,0 +1,253 @@
using System.Buffers.Binary;
using System.Globalization;
using System.Net.Http.Headers;
using System.Numerics;
using Nebula.Shared.FileApis.Interfaces;
using Nebula.Shared.Models;
using Nebula.Shared.Utils;
namespace Nebula.Shared.Services;
public partial class ContentService
{
public bool CheckManifestExist(RobustManifestItem item)
{
return fileService.ContentFileApi.Has(item.Hash);
}
public async Task<List<RobustManifestItem>> EnsureItems(ManifestReader manifestReader, Uri downloadUri,
CancellationToken cancellationToken)
{
List<RobustManifestItem> allItems = [];
List<RobustManifestItem> items = [];
while (manifestReader.TryReadItem(out var item))
{
if (cancellationToken.IsCancellationRequested)
{
debugService.Log("ensuring is cancelled!");
return [];
}
if (!CheckManifestExist(item.Value))
items.Add(item.Value);
allItems.Add(item.Value);
}
debugService.Log("Download Count:" + items.Count);
await Download(downloadUri, items, cancellationToken);
fileService.ManifestItems = allItems;
return allItems;
}
public async Task<List<RobustManifestItem>> EnsureItems(RobustManifestInfo info,
CancellationToken cancellationToken)
{
debugService.Log("Getting manifest: " + info.Hash);
if (fileService.ManifestFileApi.TryOpen(info.Hash, out var stream))
{
debugService.Log("Loading manifest from: " + info.Hash);
return await EnsureItems(new ManifestReader(stream), info.DownloadUri, cancellationToken);
}
debugService.Log("Fetching manifest from: " + info.ManifestUri);
var response = await _http.GetAsync(info.ManifestUri, cancellationToken);
if (!response.IsSuccessStatusCode) throw new Exception();
await using var streamContent = await response.Content.ReadAsStreamAsync(cancellationToken);
fileService.ManifestFileApi.Save(info.Hash, streamContent);
streamContent.Seek(0, SeekOrigin.Begin);
using var manifestReader = new ManifestReader(streamContent);
return await EnsureItems(manifestReader, info.DownloadUri, cancellationToken);
}
public async Task Unpack(RobustManifestInfo info, IWriteFileApi otherApi, CancellationToken cancellationToken)
{
debugService.Log("Unpack manifest files");
var items = await EnsureItems(info, cancellationToken);
foreach (var item in items)
if (fileService.ContentFileApi.TryOpen(item.Hash, out var stream))
{
debugService.Log($"Unpack {item.Hash} to: {item.Path}");
otherApi.Save(item.Path, stream);
stream.Close();
}
else
{
debugService.Error("OH FUCK!! " + item.Path);
}
}
public async Task Download(Uri contentCdn, List<RobustManifestItem> toDownload, CancellationToken cancellationToken)
{
if (toDownload.Count == 0 || cancellationToken.IsCancellationRequested)
{
debugService.Log("Nothing to download! Fuck this!");
return;
}
debugService.Log("Downloading from: " + contentCdn);
var requestBody = new byte[toDownload.Count * 4];
var reqI = 0;
foreach (var item in toDownload)
{
BinaryPrimitives.WriteInt32LittleEndian(requestBody.AsSpan(reqI, 4), item.Id);
reqI += 4;
}
var request = new HttpRequestMessage(HttpMethod.Post, contentCdn);
request.Headers.Add(
"X-Robust-Download-Protocol",
varService.GetConfigValue(CurrentConVar.ManifestDownloadProtocolVersion).ToString(CultureInfo.InvariantCulture));
request.Content = new ByteArrayContent(requestBody);
request.Content.Headers.ContentType = new MediaTypeHeaderValue("application/octet-stream");
request.Headers.AcceptEncoding.Add(new StringWithQualityHeaderValue("zstd"));
var response = await _http.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
if (cancellationToken.IsCancellationRequested)
{
debugService.Log("Downloading is cancelled!");
return;
}
response.EnsureSuccessStatusCode();
var stream = await response.Content.ReadAsStreamAsync();
var bandwidthStream = new BandwidthStream(stream);
stream = bandwidthStream;
if (response.Content.Headers.ContentEncoding.Contains("zstd"))
stream = new ZStdDecompressStream(stream);
await using var streamDispose = stream;
// Read flags header
var streamHeader = await stream.ReadExactAsync(4, null);
var streamFlags = (DownloadStreamHeaderFlags)BinaryPrimitives.ReadInt32LittleEndian(streamHeader);
var preCompressed = (streamFlags & DownloadStreamHeaderFlags.PreCompressed) != 0;
// compressContext.SetParameter(ZSTD_cParameter.ZSTD_c_nbWorkers, 4);
// If the stream is pre-compressed we need to decompress the blobs to verify BLAKE2B hash.
// If it isn't, we need to manually try re-compressing individual files to store them.
var compressContext = preCompressed ? null : new ZStdCCtx();
var decompressContext = preCompressed ? new ZStdDCtx() : null;
// Normal file header:
// <int32> uncompressed length
// When preCompressed is set, we add:
// <int32> compressed length
var fileHeader = new byte[preCompressed ? 8 : 4];
try
{
// Buffer for storing compressed ZStd data.
var compressBuffer = new byte[1024];
// Buffer for storing uncompressed data.
var readBuffer = new byte[1024];
var i = 0;
foreach (var item in toDownload)
{
if (cancellationToken.IsCancellationRequested)
{
debugService.Log("Downloading is cancelled!");
decompressContext?.Dispose();
compressContext?.Dispose();
return;
}
// Read file header.
await stream.ReadExactAsync(fileHeader, null);
var length = BinaryPrimitives.ReadInt32LittleEndian(fileHeader.AsSpan(0, 4));
EnsureBuffer(ref readBuffer, length);
var data = readBuffer.AsMemory(0, length);
// Data to write to database.
var compression = ContentCompressionScheme.None;
var writeData = data;
if (preCompressed)
{
// Compressed length from extended header.
var compressedLength = BinaryPrimitives.ReadInt32LittleEndian(fileHeader.AsSpan(4, 4));
if (compressedLength > 0)
{
EnsureBuffer(ref compressBuffer, compressedLength);
var compressedData = compressBuffer.AsMemory(0, compressedLength);
await stream.ReadExactAsync(compressedData, null);
// Decompress so that we can verify hash down below.
var decompressedLength = decompressContext!.Decompress(data.Span, compressedData.Span);
if (decompressedLength != data.Length)
throw new Exception($"Compressed blob {i} had incorrect decompressed size!");
// Set variables so that the database write down below uses them.
compression = ContentCompressionScheme.ZStd;
writeData = compressedData;
}
else
{
await stream.ReadExactAsync(data, null);
}
}
else
{
await stream.ReadExactAsync(data, null);
}
if (!preCompressed)
{
// File wasn't pre-compressed. We should try to manually compress it to save space in DB.
EnsureBuffer(ref compressBuffer, ZStd.CompressBound(data.Length));
var compressLength = compressContext!.Compress(compressBuffer, data.Span);
// Don't bother saving compressed data if it didn't save enough space.
if (compressLength + 10 < length)
{
// Set variables so that the database write down below uses them.
compression = ContentCompressionScheme.ZStd;
writeData = compressBuffer.AsMemory(0, compressLength);
}
}
using var fileStream = new MemoryStream(data.ToArray());
fileService.ContentFileApi.Save(item.Hash, fileStream);
debugService.Log("file saved:" + item.Path);
i += 1;
}
}
finally
{
decompressContext?.Dispose();
compressContext?.Dispose();
}
}
private static void EnsureBuffer(ref byte[] buf, int needsFit)
{
if (buf.Length >= needsFit)
return;
var newLen = 2 << BitOperations.Log2((uint)needsFit - 1);
buf = new byte[newLen];
}
}

View File

@@ -0,0 +1,30 @@
using System.Data;
using Nebula.Shared.Models;
namespace Nebula.Shared.Services;
[ServiceRegister]
public partial class ContentService(
RestService restService,
DebugService debugService,
ConfigurationService varService,
FileService fileService)
{
private readonly HttpClient _http = new();
public async Task<RobustBuildInfo> GetBuildInfo(RobustUrl url, CancellationToken cancellationToken)
{
var info = new RobustBuildInfo();
info.Url = url;
var bi = await restService.GetAsync<ServerInfo>(url.InfoUri, cancellationToken);
if (bi.Value is null) throw new NoNullAllowedException();
info.BuildInfo = bi.Value;
info.RobustManifestInfo = info.BuildInfo.Build.Acz
? new RobustManifestInfo(new RobustPath(info.Url, "manifest.txt"), new RobustPath(info.Url, "download"),
bi.Value.Build.ManifestHash)
: new RobustManifestInfo(new Uri(info.BuildInfo.Build.ManifestUrl),
new Uri(info.BuildInfo.Build.ManifestDownloadUrl), bi.Value.Build.ManifestHash);
return info;
}
}

View File

@@ -0,0 +1,75 @@
using Nebula.Shared.Services.Logging;
namespace Nebula.Shared.Services;
[ServiceRegister]
public class DebugService : IDisposable
{
public ILogger Logger;
private static string LogPath = Path.Combine(FileService.RootPath, "log");
public DateTime LogDate = DateTime.Now;
private FileStream LogStream;
private StreamWriter LogWriter;
public DebugService(ILogger logger)
{
Logger = logger;
//if (!Directory.Exists(LogPath))
// Directory.CreateDirectory(LogPath);
//var filename = String.Format("{0:yyyy-MM-dd}.txt", DateTime.Now);
//LogStream = File.Open(Path.Combine(LogPath, filename),
// FileMode.Append, FileAccess.Write);
//LogWriter = new StreamWriter(LogStream);
}
public void Debug(string message)
{
Log(LoggerCategory.Debug, message);
}
public void Error(string message)
{
Log(LoggerCategory.Error, message);
}
public void Log(string message)
{
Log(LoggerCategory.Log, message);
}
public void Error(Exception e)
{
Error(e.Message + "\r\n" + e.StackTrace);
if(e.InnerException != null)
Error(e.InnerException);
}
public void Dispose()
{
LogWriter.Dispose();
LogStream.Dispose();
}
private void Log(LoggerCategory category, string message)
{
Logger.Log(category, message);
//SaveToLog(category, message);
}
private void SaveToLog(LoggerCategory category, string message)
{
LogWriter.WriteLine($"[{category}] {message}");
LogWriter.Flush();
}
}
public enum LoggerCategory
{
Log,
Debug,
Error
}

View File

@@ -0,0 +1,198 @@
using System.Diagnostics.CodeAnalysis;
using Nebula.Shared.FileApis;
using Nebula.Shared.Models;
using Nebula.Shared.Utils;
namespace Nebula.Shared.Services;
[ServiceRegister]
public sealed class EngineService
{
private readonly AssemblyService _assemblyService;
private readonly DebugService _debugService;
private readonly FileService _fileService;
private readonly RestService _restService;
private readonly IServiceProvider _serviceProvider;
private readonly ConfigurationService _varService;
public Dictionary<string, Module> ModuleInfos = default!;
public Dictionary<string, EngineVersionInfo> VersionInfos = default!;
private Task _currInfoTask;
public EngineService(RestService restService, DebugService debugService, ConfigurationService varService,
FileService fileService, IServiceProvider serviceProvider, AssemblyService assemblyService)
{
_restService = restService;
_debugService = debugService;
_varService = varService;
_fileService = fileService;
_serviceProvider = serviceProvider;
_assemblyService = assemblyService;
_currInfoTask = Task.Run(() => LoadEngineManifest(CancellationToken.None));
}
public async Task LoadEngineManifest(CancellationToken cancellationToken)
{
var info = await _restService.GetAsync<Dictionary<string, EngineVersionInfo>>(
new Uri(_varService.GetConfigValue(CurrentConVar.EngineManifestUrl)!), cancellationToken);
var moduleInfo = await _restService.GetAsync<ModulesInfo>(
new Uri(_varService.GetConfigValue(CurrentConVar.EngineModuleManifestUrl)!), cancellationToken);
if (info.Value is null) return;
VersionInfos = info.Value;
if (moduleInfo.Value is null) return;
ModuleInfos = moduleInfo.Value.Modules;
foreach (var f in ModuleInfos.Keys) _debugService.Debug(f);
}
public EngineBuildInfo? GetVersionInfo(string version)
{
CheckAndWaitValidation();
if (!VersionInfos.TryGetValue(version, out var foundVersion))
return null;
if (foundVersion.RedirectVersion != null)
return GetVersionInfo(foundVersion.RedirectVersion);
var bestRid = RidUtility.FindBestRid(foundVersion.Platforms.Keys);
if (bestRid == null) bestRid = "linux-x64";
_debugService.Log("Selecting RID" + bestRid);
return foundVersion.Platforms[bestRid];
}
public bool TryGetVersionInfo(string version, [NotNullWhen(true)] out EngineBuildInfo? info)
{
info = GetVersionInfo(version);
return info != null;
}
public async Task<AssemblyApi?> EnsureEngine(string version)
{
_debugService.Log("Ensure engine " + version);
if (!TryOpen(version)) await DownloadEngine(version);
try
{
return _assemblyService.Mount(_fileService.OpenZip(version, _fileService.EngineFileApi),$"Engine.Ensure-{version}");
}
catch (Exception e)
{
_fileService.EngineFileApi.Remove(version);
throw;
}
}
public async Task DownloadEngine(string version)
{
if (!TryGetVersionInfo(version, out var info))
return;
_debugService.Log("Downloading engine version " + version);
using var client = new HttpClient();
var s = await client.GetStreamAsync(info.Url);
_fileService.EngineFileApi.Save(version, s);
await s.DisposeAsync();
}
public bool TryOpen(string version, [NotNullWhen(true)] out Stream? stream)
{
return _fileService.EngineFileApi.TryOpen(version, out stream);
}
public bool TryOpen(string version)
{
var a = TryOpen(version, out var stream);
if (a) stream!.Close();
return a;
}
public EngineBuildInfo? GetModuleBuildInfo(string moduleName, string version)
{
CheckAndWaitValidation();
if (!ModuleInfos.TryGetValue(moduleName, out var module) ||
!module.Versions.TryGetValue(version, out var value))
return null;
var bestRid = RidUtility.FindBestRid(value.Platforms.Keys);
if (bestRid == null) throw new Exception("No engine version available for our platform!");
return value.Platforms[bestRid];
}
public bool TryGetModuleBuildInfo(string moduleName, string version, [NotNullWhen(true)] out EngineBuildInfo? info)
{
info = GetModuleBuildInfo(moduleName, version);
return info != null;
}
public string ResolveModuleVersion(string moduleName, string engineVersion)
{
CheckAndWaitValidation();
var engineVersionObj = Version.Parse(engineVersion);
var module = ModuleInfos[moduleName];
var selectedVersion = module.Versions.Select(kv => new { Version = Version.Parse(kv.Key), kv.Key, kv.Value })
.Where(kv => engineVersionObj >= kv.Version)
.MaxBy(kv => kv.Version);
if (selectedVersion == null) throw new Exception();
return selectedVersion.Key;
}
public async Task<AssemblyApi?> EnsureEngineModules(string moduleName, string engineVersion)
{
var moduleVersion = ResolveModuleVersion(moduleName, engineVersion);
if (!TryGetModuleBuildInfo(moduleName, moduleVersion, out var buildInfo))
return null;
var fileName = ConcatName(moduleName, moduleVersion);
if (!TryOpen(fileName)) await DownloadEngineModule(moduleName, moduleVersion);
try
{
return _assemblyService.Mount(_fileService.OpenZip(fileName, _fileService.EngineFileApi),"Engine.EnsureModule");
}
catch (Exception e)
{
_fileService.EngineFileApi.Remove(fileName);
throw;
}
}
public async Task DownloadEngineModule(string moduleName, string moduleVersion)
{
if (!TryGetModuleBuildInfo(moduleName, moduleVersion, out var info))
return;
_debugService.Log("Downloading engine module version " + moduleVersion);
using var client = new HttpClient();
var s = await client.GetStreamAsync(info.Url);
_fileService.EngineFileApi.Save(ConcatName(moduleName, moduleVersion), s);
await s.DisposeAsync();
}
public string ConcatName(string moduleName, string moduleVersion)
{
return moduleName + "" + moduleVersion;
}
private void CheckAndWaitValidation()
{
if (_currInfoTask.IsCompleted)
return;
_debugService.Debug("thinks is not done yet, please wait");
_currInfoTask.Wait();
}
}

View File

@@ -0,0 +1,65 @@
using System.IO.Compression;
using System.Runtime.InteropServices;
using Nebula.Shared.FileApis;
using Nebula.Shared.FileApis.Interfaces;
using Nebula.Shared.Models;
using Robust.LoaderApi;
namespace Nebula.Shared.Services;
[ServiceRegister]
public class FileService
{
public static string RootPath = Path.Join(Environment.GetFolderPath(
Environment.SpecialFolder.ApplicationData), "./Datum/");
private readonly DebugService _debugService;
public readonly IReadWriteFileApi ContentFileApi;
public readonly IReadWriteFileApi EngineFileApi;
public readonly IReadWriteFileApi ManifestFileApi;
public readonly IReadWriteFileApi ConfigurationApi;
private HashApi? _hashApi;
public FileService(DebugService debugService)
{
_debugService = debugService;
ContentFileApi = CreateFileApi("content/");
EngineFileApi = CreateFileApi("engine/");
ManifestFileApi = CreateFileApi("manifest/");
ConfigurationApi = CreateFileApi("config/");
}
public List<RobustManifestItem> ManifestItems
{
set => _hashApi = new HashApi(value, ContentFileApi);
}
public HashApi HashApi
{
get
{
if (_hashApi is null) throw new Exception("Hash API is not initialized!");
return _hashApi;
}
set => _hashApi = value;
}
public IReadWriteFileApi CreateFileApi(string path)
{
return new FileApi(Path.Join(RootPath, path));
}
public ZipFileApi OpenZip(string path, IFileApi fileApi)
{
if (!fileApi.TryOpen(path, out var zipStream))
return null;
var zipArchive = new ZipArchive(zipStream, ZipArchiveMode.Read);
var prefix = "";
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) prefix = "Space Station 14.app/Contents/Resources/";
return new ZipFileApi(zipArchive, prefix);
}
}

View File

@@ -0,0 +1,57 @@
using Nebula.Shared.Models;
namespace Nebula.Shared.Services;
[ServiceRegister]
public class HubService
{
private readonly ConfigurationService _configurationService;
private readonly RestService _restService;
public Action<HubServerChangedEventArgs>? HubServerChangedEventArgs;
private bool _isUpdating = false;
public HubService(ConfigurationService configurationService, RestService restService)
{
_configurationService = configurationService;
_restService = restService;
UpdateHub();
}
public async void UpdateHub()
{
if(_isUpdating) return;
_isUpdating = true;
HubServerChangedEventArgs?.Invoke(new HubServerChangedEventArgs([], HubServerChangeAction.Clear));
foreach (var urlStr in _configurationService.GetConfigValue(CurrentConVar.Hub)!)
{
var servers = await _restService.GetAsyncDefault<List<ServerHubInfo>>(new Uri(urlStr), [], CancellationToken.None);
HubServerChangedEventArgs?.Invoke(new HubServerChangedEventArgs(servers, HubServerChangeAction.Add));
}
_isUpdating = false;
}
}
public class HubServerChangedEventArgs : EventArgs
{
public HubServerChangeAction Action;
public List<ServerHubInfo> Items;
public HubServerChangedEventArgs(List<ServerHubInfo> items, HubServerChangeAction action)
{
Items = items;
Action = action;
}
}
public enum HubServerChangeAction
{
Add, Remove, Clear,
}

View File

@@ -0,0 +1,13 @@
namespace Nebula.Shared.Services.Logging;
[ServiceRegister(typeof(ILogger))]
public class ConsoleLogger : ILogger
{
public void Log(LoggerCategory loggerCategory, string message)
{
Console.ForegroundColor = ConsoleColor.DarkCyan;
Console.Write($"[{Enum.GetName(loggerCategory)}] ");
Console.ResetColor();
Console.WriteLine(message);
}
}

View File

@@ -0,0 +1,6 @@
namespace Nebula.Shared.Services.Logging;
public interface ILogger
{
public void Log(LoggerCategory loggerCategory, string message);
}

View File

@@ -0,0 +1,15 @@
namespace Nebula.Shared.Services;
[ServiceRegister]
public class PopupMessageService
{
public Action<object?>? OnPopupRequired;
public void Popup(object obj)
{
OnPopupRequired?.Invoke(obj);
}
public void ClosePopup()
{
OnPopupRequired?.Invoke(null);
}
}

View File

@@ -0,0 +1,149 @@
using System.Globalization;
using System.Net;
using System.Text;
using System.Text.Json;
using Nebula.Shared.Utils;
namespace Nebula.Shared.Services;
[ServiceRegister]
public class RestService
{
private readonly HttpClient _client = new();
private readonly DebugService _debug;
private readonly JsonSerializerOptions _serializerOptions = new()
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
WriteIndented = true
};
public RestService(DebugService debug)
{
_debug = debug;
}
public async Task<RestResult<T>> GetAsync<T>(Uri uri, CancellationToken cancellationToken)
{
_debug.Debug("GET " + uri);
try
{
var response = await _client.GetAsync(uri, cancellationToken);
return await ReadResult<T>(response, cancellationToken);
}
catch (Exception ex)
{
_debug.Error("ERROR WHILE CONNECTION " + uri + ": " + ex.Message);
return new RestResult<T>(default, ex.Message, HttpStatusCode.RequestTimeout);
}
}
public async Task<T> GetAsyncDefault<T>(Uri uri, T defaultValue, CancellationToken cancellationToken)
{
var result = await GetAsync<T>(uri, cancellationToken);
return result.Value ?? defaultValue;
}
public async Task<RestResult<K>> PostAsync<K, T>(T information, Uri uri, CancellationToken cancellationToken)
{
_debug.Debug("POST " + uri);
try
{
var json = JsonSerializer.Serialize(information, _serializerOptions);
var content = new StringContent(json, Encoding.UTF8, "application/json");
var response = await _client.PostAsync(uri, content, cancellationToken);
return await ReadResult<K>(response, cancellationToken);
}
catch (Exception ex)
{
_debug.Debug("ERROR " + ex.Message);
return new RestResult<K>(default, ex.Message, HttpStatusCode.RequestTimeout);
}
}
public async Task<RestResult<T>> PostAsync<T>(Stream stream, Uri uri, CancellationToken cancellationToken)
{
_debug.Debug("POST " + uri);
try
{
using var multipartFormContent =
new MultipartFormDataContent("Upload----" + DateTime.Now.ToString(CultureInfo.InvariantCulture));
multipartFormContent.Add(new StreamContent(stream), "formFile", "image.png");
var response = await _client.PostAsync(uri, multipartFormContent, cancellationToken);
return await ReadResult<T>(response, cancellationToken);
}
catch (Exception ex)
{
_debug.Error("ERROR " + ex.Message);
if (ex.StackTrace != null) _debug.Error(ex.StackTrace);
return new RestResult<T>(default, ex.Message, HttpStatusCode.RequestTimeout);
}
}
public async Task<RestResult<T>> DeleteAsync<T>(Uri uri, CancellationToken cancellationToken)
{
_debug.Debug("DELETE " + uri);
try
{
var response = await _client.DeleteAsync(uri, cancellationToken);
return await ReadResult<T>(response, cancellationToken);
}
catch (Exception ex)
{
_debug.Debug("ERROR " + ex.Message);
return new RestResult<T>(default, ex.Message, HttpStatusCode.RequestTimeout);
}
}
private async Task<RestResult<T>> ReadResult<T>(HttpResponseMessage response, CancellationToken cancellationToken)
{
var content = await response.Content.ReadAsStringAsync(cancellationToken);
if (response.IsSuccessStatusCode)
{
_debug.Debug($"SUCCESSFUL GET CONTENT {typeof(T)}");
if (typeof(T) == typeof(RawResult))
return (new RestResult<RawResult>(new RawResult(content), null, response.StatusCode) as RestResult<T>)!;
return new RestResult<T>(await response.Content.AsJson<T>(), null,
response.StatusCode);
}
_debug.Error("ERROR " + response.StatusCode + " " + content);
return new RestResult<T>(default, "response code:" + response.StatusCode, response.StatusCode);
}
}
public class RestResult<T>
{
public string Message = "Ok";
public HttpStatusCode StatusCode;
public T? Value;
public RestResult(T? value, string? message, HttpStatusCode statusCode)
{
Value = value;
if (message != null) Message = message;
StatusCode = statusCode;
}
public static implicit operator T?(RestResult<T> result)
{
return result.Value;
}
}
public class RawResult
{
public string Result;
public RawResult(string result)
{
Result = result;
}
public static implicit operator string(RawResult result)
{
return result.Result;
}
}

View File

@@ -0,0 +1,126 @@
using Nebula.Shared.Models;
using Robust.LoaderApi;
namespace Nebula.Shared.Services;
[ServiceRegister]
public sealed class RunnerService(
ContentService contentService,
DebugService debugService,
ConfigurationService varService,
FileService fileService,
EngineService engineService,
AssemblyService assemblyService,
AuthService authService,
PopupMessageService popupMessageService,
CancellationService cancellationService)
: IRedialApi
{
public async Task PrepareRun(RobustUrl url)
{
var buildInfo = await contentService.GetBuildInfo(url, cancellationService.Token);
await PrepareRun(buildInfo, cancellationService.Token);
}
public async Task PrepareRun(RobustBuildInfo buildInfo, CancellationToken cancellationToken)
{
debugService.Log("Prepare Content!");
var engine = await engineService.EnsureEngine(buildInfo.BuildInfo.Build.EngineVersion);
if (engine is null)
throw new Exception("Engine version is not usable: " + buildInfo.BuildInfo.Build.EngineVersion);
await contentService.EnsureItems(buildInfo.RobustManifestInfo, cancellationToken);
await engineService.EnsureEngineModules("Robust.Client.WebView", buildInfo.BuildInfo.Build.EngineVersion);
}
public async Task Run(string[] runArgs, RobustBuildInfo buildInfo, IRedialApi redialApi,
CancellationToken cancellationToken)
{
debugService.Log("Start Content!");
var engine = await engineService.EnsureEngine(buildInfo.BuildInfo.Build.EngineVersion);
if (engine is null)
throw new Exception("Engine version is not usable: " + buildInfo.BuildInfo.Build.EngineVersion);
await contentService.EnsureItems(buildInfo.RobustManifestInfo, cancellationToken);
var extraMounts = new List<ApiMount>
{
new(fileService.HashApi, "/")
};
var module =
await engineService.EnsureEngineModules("Robust.Client.WebView", buildInfo.BuildInfo.Build.EngineVersion);
if (module is not null)
extraMounts.Add(new ApiMount(module, "/"));
var args = new MainArgs(runArgs, engine, redialApi, extraMounts);
if (!assemblyService.TryOpenAssembly(varService.GetConfigValue(CurrentConVar.RobustAssemblyName)!, engine, out var clientAssembly))
throw new Exception("Unable to locate Robust.Client.dll in engine build!");
if (!assemblyService.TryGetLoader(clientAssembly, out var loader))
return;
await Task.Run(() => loader.Main(args), cancellationToken);
}
public async Task RunGame(string urlraw)
{
var url = new RobustUrl(urlraw);
using var cancelTokenSource = new CancellationTokenSource();
var buildInfo = await contentService.GetBuildInfo(url, cancelTokenSource.Token);
var account = authService.SelectedAuth;
if (account is null)
{
popupMessageService.Popup("Error! Auth is required!");
return;
}
if (buildInfo.BuildInfo.Auth.Mode != "Disabled")
{
Environment.SetEnvironmentVariable("ROBUST_AUTH_TOKEN", account.Token.Token);
Environment.SetEnvironmentVariable("ROBUST_AUTH_USERID", account.UserId.ToString());
Environment.SetEnvironmentVariable("ROBUST_AUTH_PUBKEY", buildInfo.BuildInfo.Auth.PublicKey);
Environment.SetEnvironmentVariable("ROBUST_AUTH_SERVER", account.AuthLoginPassword.AuthServer);
}
var args = new List<string>
{
// Pass username to launched client.
// We don't load username from client_config.toml when launched via launcher.
"--username", account.AuthLoginPassword.Login,
// Tell game we are launcher
"--cvar", "launch.launcher=true"
};
var connectionString = url.ToString();
if (!string.IsNullOrEmpty(buildInfo.BuildInfo.ConnectAddress))
connectionString = buildInfo.BuildInfo.ConnectAddress;
// We are using the launcher. Don't show main menu etc..
// Note: --launcher also implied --connect.
// For this reason, content bundles do not set --launcher.
args.Add("--launcher");
args.Add("--connect-address");
args.Add(connectionString);
args.Add("--ss14-address");
args.Add(url.ToString());
debugService.Debug("Connect to " + url.ToString() + " " + account.AuthLoginPassword.AuthServer);
await Run(args.ToArray(), buildInfo, this, cancelTokenSource.Token);
}
public async void Redial(Uri uri, string text = "")
{
//await RunGame(uri.ToString());
}
}

View File

@@ -0,0 +1,138 @@
using System.Diagnostics;
namespace Nebula.Shared.Utils;
public sealed class BandwidthStream : Stream
{
private const int NumSeconds = 8;
private const int BucketDivisor = 2;
private const int BucketsPerSecond = 2 << BucketDivisor;
// TotalBuckets MUST be power of two!
private const int TotalBuckets = NumSeconds * BucketsPerSecond;
private readonly Stream _baseStream;
private readonly long[] _buckets;
private readonly Stopwatch _stopwatch;
private long _bucketIndex;
public BandwidthStream(Stream baseStream)
{
_stopwatch = Stopwatch.StartNew();
_baseStream = baseStream;
_buckets = new long[TotalBuckets];
}
public override bool CanRead => _baseStream.CanRead;
public override bool CanSeek => _baseStream.CanSeek;
public override bool CanWrite => _baseStream.CanWrite;
public override long Length => _baseStream.Length;
public override long Position
{
get => _baseStream.Position;
set => _baseStream.Position = value;
}
private void TrackBandwidth(long value)
{
const int bucketMask = TotalBuckets - 1;
var bucketIdx = CurBucketIdx();
// Increment to bucket idx, clearing along the way.
if (bucketIdx != _bucketIndex)
{
var diff = bucketIdx - _bucketIndex;
if (diff > TotalBuckets)
for (var i = _bucketIndex; i < bucketIdx; i++)
_buckets[i & bucketMask] = 0;
else
// We managed to skip so much time the whole buffer is empty.
Array.Clear(_buckets);
_bucketIndex = bucketIdx;
}
// Write value.
_buckets[bucketIdx & bucketMask] += value;
}
private long CurBucketIdx()
{
var elapsed = _stopwatch.Elapsed.TotalSeconds;
return (long)(elapsed / BucketsPerSecond);
}
public long CalcCurrentAvg()
{
var sum = 0L;
for (var i = 0; i < TotalBuckets; i++) sum += _buckets[i];
return sum >> BucketDivisor;
}
public override void Flush()
{
_baseStream.Flush();
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
return _baseStream.FlushAsync(cancellationToken);
}
protected override void Dispose(bool disposing)
{
if (disposing)
_baseStream.Dispose();
}
public override ValueTask DisposeAsync()
{
return _baseStream.DisposeAsync();
}
public override int Read(byte[] buffer, int offset, int count)
{
var read = _baseStream.Read(buffer, offset, count);
TrackBandwidth(read);
return read;
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
var read = await base.ReadAsync(buffer, cancellationToken);
TrackBandwidth(read);
return read;
}
public override long Seek(long offset, SeekOrigin origin)
{
return _baseStream.Seek(offset, origin);
}
public override void SetLength(long value)
{
_baseStream.SetLength(value);
}
public override void Write(byte[] buffer, int offset, int count)
{
_baseStream.Write(buffer, offset, count);
TrackBandwidth(count);
}
public override async ValueTask WriteAsync(
ReadOnlyMemory<byte> buffer,
CancellationToken cancellationToken = default)
{
await _baseStream.WriteAsync(buffer, cancellationToken);
TrackBandwidth(buffer.Length);
}
}

View File

@@ -0,0 +1,26 @@
using System.Windows.Input;
namespace Nebula.Shared.Utils;
public class DelegateCommand<T> : ICommand
{
private readonly Action<T> _func;
public readonly Ref<T> TRef = new();
public DelegateCommand(Action<T> func)
{
_func = func;
}
public bool CanExecute(object? parameter)
{
return true;
}
public void Execute(object? parameter)
{
_func(TRef.Value);
}
public event EventHandler? CanExecuteChanged;
}

View File

@@ -0,0 +1,37 @@
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text.Json;
namespace Nebula.Shared.Utils;
public static class Helper
{
public static readonly JsonSerializerOptions JsonWebOptions = new(JsonSerializerDefaults.Web);
public static void OpenBrowser(string url)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Process.Start(new ProcessStartInfo("cmd", $"/c start {url}"));
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
Process.Start("xdg-open", url);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
Process.Start("open", url);
}
else
{
}
}
public static async Task<T> AsJson<T>(this HttpContent content) where T : notnull
{
var str = await content.ReadAsStringAsync();
return JsonSerializer.Deserialize<T>(str, JsonWebOptions) ??
throw new JsonException("AsJson: did not expect null response");
}
}

View File

@@ -0,0 +1,118 @@
using System.Diagnostics.CodeAnalysis;
using System.Text;
using Nebula.Shared.Models;
namespace Nebula.Shared.Utils;
public class ManifestReader : StreamReader
{
public const int BufferSize = 128;
public ManifestReader(Stream stream) : base(stream)
{
ReadManifestVersion();
}
public ManifestReader(Stream stream, bool detectEncodingFromByteOrderMarks) : base(stream,
detectEncodingFromByteOrderMarks)
{
ReadManifestVersion();
}
public ManifestReader(Stream stream, Encoding encoding) : base(stream, encoding)
{
ReadManifestVersion();
}
public ManifestReader(Stream stream, Encoding encoding, bool detectEncodingFromByteOrderMarks) : base(stream,
encoding, detectEncodingFromByteOrderMarks)
{
ReadManifestVersion();
}
public ManifestReader(Stream stream, Encoding encoding, bool detectEncodingFromByteOrderMarks, int bufferSize) :
base(stream, encoding, detectEncodingFromByteOrderMarks, bufferSize)
{
ReadManifestVersion();
}
public ManifestReader(Stream stream, Encoding? encoding = null, bool detectEncodingFromByteOrderMarks = true,
int bufferSize = -1, bool leaveOpen = false) : base(stream, encoding, detectEncodingFromByteOrderMarks,
bufferSize, leaveOpen)
{
ReadManifestVersion();
}
public ManifestReader(string path) : base(path)
{
ReadManifestVersion();
}
public ManifestReader(string path, bool detectEncodingFromByteOrderMarks) : base(path,
detectEncodingFromByteOrderMarks)
{
ReadManifestVersion();
}
public ManifestReader(string path, FileStreamOptions options) : base(path, options)
{
ReadManifestVersion();
}
public ManifestReader(string path, Encoding encoding) : base(path, encoding)
{
ReadManifestVersion();
}
public ManifestReader(string path, Encoding encoding, bool detectEncodingFromByteOrderMarks) : base(path, encoding,
detectEncodingFromByteOrderMarks)
{
ReadManifestVersion();
}
public ManifestReader(string path, Encoding encoding, bool detectEncodingFromByteOrderMarks, int bufferSize) : base(
path, encoding, detectEncodingFromByteOrderMarks, bufferSize)
{
ReadManifestVersion();
}
public ManifestReader(string path, Encoding encoding, bool detectEncodingFromByteOrderMarks,
FileStreamOptions options) : base(path, encoding, detectEncodingFromByteOrderMarks, options)
{
ReadManifestVersion();
}
public string ManifestVersion { get; private set; } = "";
public int CurrentId { get; private set; }
private void ReadManifestVersion()
{
ManifestVersion = ReadLine();
}
public RobustManifestItem? ReadItem()
{
var line = ReadLine();
if (line == null) return null;
var splited = line.Split(" ");
return new RobustManifestItem(splited[0], line.Substring(splited[0].Length + 1), CurrentId++);
}
public bool TryReadItem([NotNullWhen(true)] out RobustManifestItem? item)
{
item = ReadItem();
return item != null;
}
protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
CurrentId = 0;
}
public new void DiscardBufferedData()
{
base.DiscardBufferedData();
CurrentId = 0;
}
}

View File

@@ -0,0 +1,6 @@
namespace Nebula.Shared.Utils;
public class Ref<T>
{
public T Value = default!;
}

View File

@@ -0,0 +1,118 @@
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Nebula.Shared.Utils;
public static class RidUtility
{
public static string? FindBestRid(ICollection<string> runtimes, string? currentRid = null)
{
var catalog = LoadRidCatalog();
if (currentRid == null)
{
var reportedRid = RuntimeInformation.RuntimeIdentifier;
if (!catalog.Runtimes.ContainsKey(reportedRid))
{
currentRid = GuessRid();
Console.WriteLine(".NET reported unknown RID: {0}, guessing: {1}", reportedRid, currentRid);
}
else
{
currentRid = reportedRid;
}
}
// Breadth-first search.
var q = new Queue<string>();
if (!catalog.Runtimes.TryGetValue(currentRid, out var root))
// RID doesn't exist in catalog???
return null;
root.Discovered = true;
q.Enqueue(currentRid);
while (q.TryDequeue(out var v))
{
if (runtimes.Contains(v)) return v;
foreach (var w in catalog.Runtimes[v].Imports)
{
var r = catalog.Runtimes[w];
if (!r.Discovered)
{
q.Enqueue(w);
r.Discovered = true;
}
}
}
return null;
}
private static string GuessRid()
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
return RuntimeInformation.ProcessArchitecture switch
{
Architecture.X86 => "linux-x86",
Architecture.X64 => "linux-x64",
Architecture.Arm => "linux-arm",
Architecture.Arm64 => "linux-arm64",
_ => "unknown"
};
if (RuntimeInformation.IsOSPlatform(OSPlatform.FreeBSD))
return RuntimeInformation.ProcessArchitecture switch
{
Architecture.X64 => "freebsd-x64",
_ => "unknown"
};
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return RuntimeInformation.ProcessArchitecture switch
{
Architecture.X86 => "win-x86",
Architecture.X64 => "win-x64",
Architecture.Arm => "win-arm",
Architecture.Arm64 => "win-arm64",
_ => "unknown"
};
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
return RuntimeInformation.ProcessArchitecture switch
{
Architecture.X64 => "osx-x64",
Architecture.Arm64 => "osx-arm64",
_ => "unknown"
};
return "unknown";
}
private static RidCatalog LoadRidCatalog()
{
using var stream = typeof(RidCatalog).Assembly.GetManifestResourceStream("Utility.runtime.json")!;
var ms = new MemoryStream();
stream.CopyTo(ms);
return JsonSerializer.Deserialize<RidCatalog>(ms.GetBuffer().AsSpan(0, (int)ms.Length))!;
}
#pragma warning disable 649
private sealed class RidCatalog
{
[JsonInclude] [JsonPropertyName("runtimes")]
public Dictionary<string, Runtime> Runtimes = default!;
public class Runtime
{
public bool Discovered;
[JsonInclude] [JsonPropertyName("#import")]
public string[] Imports = default!;
}
}
#pragma warning restore 649
}

View File

@@ -0,0 +1,54 @@
using System.Buffers;
namespace Nebula.Shared.Utils;
public static class StreamHelper
{
public static async ValueTask<byte[]> ReadExactAsync(this Stream stream, int amount, CancellationToken? cancel)
{
var data = new byte[amount];
await ReadExactAsync(stream, data, cancel);
return data;
}
public static async ValueTask ReadExactAsync(this Stream stream, Memory<byte> into, CancellationToken? cancel)
{
while (into.Length > 0)
{
var read = await stream.ReadAsync(into);
// Check EOF.
if (read == 0)
throw new EndOfStreamException();
into = into[read..];
}
}
public static async Task CopyAmountToAsync(
this Stream stream,
Stream to,
int amount,
int bufferSize,
CancellationToken cancel)
{
var buffer = ArrayPool<byte>.Shared.Rent(bufferSize);
while (amount > 0)
{
Memory<byte> readInto = buffer;
if (amount < readInto.Length)
readInto = readInto[..amount];
var read = await stream.ReadAsync(readInto, cancel);
if (read == 0)
throw new EndOfStreamException();
amount -= read;
readInto = readInto[..read];
await to.WriteAsync(readInto, cancel);
}
}
}

View File

@@ -0,0 +1,140 @@
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Web;
using Nebula.Shared.Models;
namespace Nebula.Shared.Utils;
public static class UriHelper
{
public const string SchemeSs14 = "ss14";
// ReSharper disable once InconsistentNaming
public const string SchemeSs14s = "ss14s";
/// <summary>
/// Parses an <c>ss14://</c> or <c>ss14s://</c> URI,
/// defaulting to <c>ss14://</c> if no scheme is specified.
/// </summary>
[Pure]
public static Uri ParseSs14Uri(string address)
{
if (!TryParseSs14Uri(address, out var uri)) throw new FormatException("Not a valid SS14 URI");
return uri;
}
[Pure]
public static bool TryParseSs14Uri(string address, [NotNullWhen(true)] out Uri? uri)
{
if (!address.Contains("://")) address = "ss14://" + address;
if (!Uri.TryCreate(address, UriKind.Absolute, out uri)) return false;
if (uri.Scheme != SchemeSs14 && uri.Scheme != SchemeSs14s) return false;
if (string.IsNullOrWhiteSpace(uri.Host))
return false;
return true;
}
/// <summary>
/// Gets the <c>http://</c> or <c>https://</c> API address for a server address.
/// </summary>
[Pure]
public static Uri GetServerApiAddress(Uri serverAddress)
{
var dataScheme = serverAddress.Scheme switch
{
"ss14" => Uri.UriSchemeHttp,
"ss14s" => Uri.UriSchemeHttps,
_ => throw new ArgumentException($"Wrong URI scheme: {serverAddress.Scheme}")
};
var builder = new UriBuilder(serverAddress)
{
Scheme = dataScheme
};
// No port specified.
// Default port for ss14:// is 1212, for ss14s:// it's 443 (HTTPS)
if (serverAddress.IsDefaultPort && serverAddress.Scheme == SchemeSs14) builder.Port = 1212;
if (!builder.Path.EndsWith('/')) builder.Path += "/";
return builder.Uri;
}
/// <summary>
/// Gets the <c>/status</c> HTTP address for a server address.
/// </summary>
[Pure]
public static Uri GetServerStatusAddress(string serverAddress)
{
return GetServerStatusAddress(ParseSs14Uri(serverAddress));
}
/// <summary>
/// Gets the <c>/status</c> HTTP address for an ss14 uri.
/// </summary>
[Pure]
public static Uri GetServerStatusAddress(Uri serverAddress)
{
return new Uri(GetServerApiAddress(serverAddress), "status");
}
/// <summary>
/// Gets the <c>/info</c> HTTP address for a server address.
/// </summary>
[Pure]
public static Uri GetServerInfoAddress(string serverAddress)
{
return GetServerInfoAddress(ParseSs14Uri(serverAddress));
}
/// <summary>
/// Gets the <c>/info</c> HTTP address for an ss14 uri.
/// </summary>
[Pure]
public static Uri GetServerInfoAddress(Uri serverAddress)
{
return new Uri(GetServerApiAddress(serverAddress), "info");
}
/// <summary>
/// Gets the <c>/client.zip</c> HTTP address for a server address.
/// This is not necessarily the actual client ZIP address.
/// </summary>
[Pure]
public static Uri GetServerSelfhostedClientZipAddress(string serverAddress)
{
return GetServerSelfhostedClientZipAddress(ParseSs14Uri(serverAddress));
}
/// <summary>
/// Gets the <c>/client.zip</c> HTTP address for an ss14 uri.
/// This is not necessarily the actual client ZIP address.
/// </summary>
[Pure]
public static Uri GetServerSelfhostedClientZipAddress(Uri serverAddress)
{
return new Uri(GetServerApiAddress(serverAddress), "client.zip");
}
[Pure]
public static Uri AddParameter(this Uri url, string paramName, string paramValue)
{
var uriBuilder = new UriBuilder(url);
var query = HttpUtility.ParseQueryString(uriBuilder.Query);
query[paramName] = paramValue;
uriBuilder.Query = query.ToString();
return uriBuilder.Uri;
}
public static RobustUrl ToRobustUrl(this string str)
{
return new RobustUrl(str);
}
}

490
Nebula.Shared/Utils/ZStd.cs Normal file
View File

@@ -0,0 +1,490 @@
using System.Buffers;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using SharpZstd.Interop;
using static SharpZstd.Interop.Zstd;
namespace Nebula.Shared.Utils;
public static class ZStd
{
public static int CompressBound(int length)
{
return (int)ZSTD_compressBound((nuint)length);
}
[ModuleInitializer]
public static void InitZStd()
{
return;
NativeLibrary.SetDllImportResolver(
typeof(Zstd).Assembly,
ResolveZstd
);
}
private static IntPtr ResolveZstd(string name, Assembly assembly, DllImportSearchPath? path)
{
if (name == "zstd" && OperatingSystem.IsLinux())
{
if (NativeLibrary.TryLoad("zstd.so", assembly, path, out var handle))
return handle;
// Try some extra paths too worst case.
if (NativeLibrary.TryLoad("libzstd.so.1", assembly, path, out handle))
return handle;
if (NativeLibrary.TryLoad("libzstd.so", assembly, path, out handle))
return handle;
}
return IntPtr.Zero;
}
}
public sealed unsafe class ZStdCCtx : IDisposable
{
public ZStdCCtx()
{
Context = ZSTD_createCCtx();
}
public ZSTD_CCtx* Context { get; private set; }
private bool Disposed => Context == null;
public void Dispose()
{
if (Disposed)
return;
ZSTD_freeCCtx(Context);
Context = null;
GC.SuppressFinalize(this);
}
public void SetParameter(ZSTD_cParameter parameter, int value)
{
CheckDisposed();
ZSTD_CCtx_setParameter(Context, parameter, value);
}
public int Compress(Span<byte> destination, Span<byte> source, int compressionLevel = ZSTD_CLEVEL_DEFAULT)
{
CheckDisposed();
fixed (byte* dst = destination)
fixed (byte* src = source)
{
var ret = ZSTD_compressCCtx(
Context,
dst, (nuint)destination.Length,
src, (nuint)source.Length,
compressionLevel);
ZStdException.ThrowIfError(ret);
return (int)ret;
}
}
~ZStdCCtx()
{
Dispose();
}
private void CheckDisposed()
{
if (Disposed)
throw new ObjectDisposedException(nameof(ZStdCCtx));
}
}
public sealed unsafe class ZStdDCtx : IDisposable
{
public ZStdDCtx()
{
Context = ZSTD_createDCtx();
}
public ZSTD_DCtx* Context { get; private set; }
private bool Disposed => Context == null;
public void Dispose()
{
if (Disposed)
return;
ZSTD_freeDCtx(Context);
Context = null;
GC.SuppressFinalize(this);
}
public void SetParameter(ZSTD_dParameter parameter, int value)
{
CheckDisposed();
ZSTD_DCtx_setParameter(Context, parameter, value);
}
public int Decompress(Span<byte> destination, Span<byte> source)
{
CheckDisposed();
fixed (byte* dst = destination)
fixed (byte* src = source)
{
var ret = ZSTD_decompressDCtx(Context, dst, (nuint)destination.Length, src, (nuint)source.Length);
ZStdException.ThrowIfError(ret);
return (int)ret;
}
}
~ZStdDCtx()
{
Dispose();
}
private void CheckDisposed()
{
if (Disposed)
throw new ObjectDisposedException(nameof(ZStdDCtx));
}
}
[Serializable]
public class ZStdException : Exception
{
public ZStdException()
{
}
public ZStdException(string message) : base(message)
{
}
public ZStdException(string message, Exception inner) : base(message, inner)
{
}
public static unsafe ZStdException FromCode(nuint code)
{
return new ZStdException(Marshal.PtrToStringUTF8((IntPtr)ZSTD_getErrorName(code))!);
}
public static void ThrowIfError(nuint code)
{
if (ZSTD_isError(code) != 0)
throw FromCode(code);
}
}
public sealed class ZStdDecompressStream : Stream
{
private readonly Stream _baseStream;
private readonly byte[] _buffer;
private readonly unsafe ZSTD_DCtx* _ctx;
private readonly bool _ownStream;
private int _bufferPos;
private int _bufferSize;
private bool _disposed;
public unsafe ZStdDecompressStream(Stream baseStream, bool ownStream = true)
{
_baseStream = baseStream;
_ownStream = ownStream;
_ctx = ZSTD_createDCtx();
_buffer = ArrayPool<byte>.Shared.Rent((int)ZSTD_DStreamInSize());
}
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length => throw new NotSupportedException();
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
protected override unsafe void Dispose(bool disposing)
{
if (_disposed)
return;
_disposed = true;
ZSTD_freeDCtx(_ctx);
if (disposing)
{
if (_ownStream)
_baseStream.Dispose();
ArrayPool<byte>.Shared.Return(_buffer);
}
}
public override void Flush()
{
ThrowIfDisposed();
_baseStream.Flush();
}
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}
public override int ReadByte()
{
Span<byte> buf = stackalloc byte[1];
return Read(buf) == 0 ? -1 : buf[0];
}
public override unsafe int Read(Span<byte> buffer)
{
ThrowIfDisposed();
do
{
if (_bufferSize == 0 || _bufferPos == _bufferSize)
{
_bufferPos = 0;
_bufferSize = _baseStream.Read(_buffer);
if (_bufferSize == 0)
return 0;
}
fixed (byte* inputPtr = _buffer)
fixed (byte* outputPtr = buffer)
{
var outputBuf = new ZSTD_outBuffer { dst = outputPtr, pos = 0, size = (nuint)buffer.Length };
var inputBuf = new ZSTD_inBuffer { src = inputPtr, pos = (nuint)_bufferPos, size = (nuint)_bufferSize };
var ret = ZSTD_decompressStream(_ctx, &outputBuf, &inputBuf);
_bufferPos = (int)inputBuf.pos;
ZStdException.ThrowIfError(ret);
if (outputBuf.pos > 0)
return (int)outputBuf.pos;
}
} while (true);
}
public override async ValueTask<int> ReadAsync(
Memory<byte> buffer,
CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
do
{
if (_bufferSize == 0 || _bufferPos == _bufferSize)
{
_bufferPos = 0;
_bufferSize = await _baseStream.ReadAsync(_buffer, cancellationToken);
if (_bufferSize == 0)
return 0;
}
var ret = DecompressChunk(this, buffer.Span);
if (ret > 0)
return (int)ret;
} while (true);
static unsafe nuint DecompressChunk(ZStdDecompressStream stream, Span<byte> buffer)
{
fixed (byte* inputPtr = stream._buffer)
fixed (byte* outputPtr = buffer)
{
ZSTD_outBuffer outputBuf = default;
outputBuf.dst = outputPtr;
outputBuf.pos = 0;
outputBuf.size = (nuint)buffer.Length;
ZSTD_inBuffer inputBuf = default;
inputBuf.src = inputPtr;
inputBuf.pos = (nuint)stream._bufferPos;
inputBuf.size = (nuint)stream._bufferSize;
var ret = ZSTD_decompressStream(stream._ctx, &outputBuf, &inputBuf);
stream._bufferPos = (int)inputBuf.pos;
ZStdException.ThrowIfError(ret);
return outputBuf.pos;
}
}
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
private void ThrowIfDisposed()
{
if (_disposed)
throw new ObjectDisposedException(nameof(ZStdDecompressStream));
}
}
public sealed class ZStdCompressStream : Stream
{
private readonly Stream _baseStream;
private readonly byte[] _buffer;
private readonly unsafe ZSTD_CCtx* _ctx;
private readonly bool _ownStream;
private int _bufferPos;
private bool _disposed;
public unsafe ZStdCompressStream(Stream baseStream, bool ownStream = true)
{
_ctx = ZSTD_createCCtx();
_baseStream = baseStream;
_ownStream = ownStream;
_buffer = ArrayPool<byte>.Shared.Rent((int)ZSTD_CStreamOutSize());
}
public override bool CanRead => false;
public override bool CanSeek => false;
public override bool CanWrite => true;
public override long Length => throw new NotSupportedException();
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public override void Flush()
{
FlushInternal(ZSTD_EndDirective.ZSTD_e_flush);
}
public void FlushEnd()
{
FlushInternal(ZSTD_EndDirective.ZSTD_e_end);
}
private unsafe void FlushInternal(ZSTD_EndDirective directive)
{
fixed (byte* outPtr = _buffer)
{
ZSTD_outBuffer outBuf = default;
outBuf.size = (nuint)_buffer.Length;
outBuf.pos = (nuint)_bufferPos;
outBuf.dst = outPtr;
ZSTD_inBuffer inBuf = default;
while (true)
{
var err = ZSTD_compressStream2(_ctx, &outBuf, &inBuf, directive);
ZStdException.ThrowIfError(err);
_bufferPos = (int)outBuf.pos;
_baseStream.Write(_buffer.AsSpan(0, (int)outBuf.pos));
_bufferPos = 0;
outBuf.pos = 0;
if (err == 0)
break;
}
}
_baseStream.Flush();
}
public override int Read(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
Write(buffer.AsSpan(offset, count));
}
public override unsafe void Write(ReadOnlySpan<byte> buffer)
{
ThrowIfDisposed();
fixed (byte* outPtr = _buffer)
fixed (byte* inPtr = buffer)
{
ZSTD_outBuffer outBuf = default;
outBuf.size = (nuint)_buffer.Length;
outBuf.pos = (nuint)_bufferPos;
outBuf.dst = outPtr;
ZSTD_inBuffer inBuf = default;
inBuf.pos = 0;
inBuf.size = (nuint)buffer.Length;
inBuf.src = inPtr;
while (true)
{
var err = ZSTD_compressStream2(_ctx, &outBuf, &inBuf, ZSTD_EndDirective.ZSTD_e_continue);
ZStdException.ThrowIfError(err);
_bufferPos = (int)outBuf.pos;
if (inBuf.pos >= inBuf.size)
break;
// Not all input data consumed. Flush output buffer and continue.
_baseStream.Write(_buffer.AsSpan(0, (int)outBuf.pos));
_bufferPos = 0;
outBuf.pos = 0;
}
}
}
protected override unsafe void Dispose(bool disposing)
{
base.Dispose(disposing);
if (_disposed)
return;
_disposed = true;
ZSTD_freeCCtx(_ctx);
if (disposing)
{
if (_ownStream)
_baseStream.Dispose();
ArrayPool<byte>.Shared.Return(_buffer);
}
}
private void ThrowIfDisposed()
{
if (_disposed)
throw new ObjectDisposedException(nameof(ZStdCompressStream));
}
}

File diff suppressed because it is too large Load Diff