Server ban exemption system (#15076)

This commit is contained in:
Pieter-Jan Briers
2023-04-03 02:24:55 +02:00
committed by GitHub
parent e037d12899
commit c8e90e561b
26 changed files with 8681 additions and 135 deletions

View File

@@ -332,6 +332,52 @@ namespace Content.Server.Database
public abstract Task AddServerBanAsync(ServerBanDef serverBan);
public abstract Task AddServerUnbanAsync(ServerUnbanDef serverUnban);
protected static async Task<ServerBanExemptFlags?> GetBanExemptionCore(DbGuard db, NetUserId? userId)
{
if (userId == null)
return null;
var exemption = await db.DbContext.BanExemption
.SingleOrDefaultAsync(e => e.UserId == userId.Value.UserId);
return exemption?.Flags;
}
public async Task UpdateBanExemption(NetUserId userId, ServerBanExemptFlags flags)
{
await using var db = await GetDb();
if (flags == 0)
{
// Delete whatever is there.
await db.DbContext.BanExemption.Where(u => u.UserId == userId.UserId).ExecuteDeleteAsync();
return;
}
var exemption = await db.DbContext.BanExemption.SingleOrDefaultAsync(u => u.UserId == userId.UserId);
if (exemption == null)
{
exemption = new ServerBanExemption
{
UserId = userId
};
db.DbContext.BanExemption.Add(exemption);
}
exemption.Flags = flags;
await db.DbContext.SaveChangesAsync();
}
public async Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId)
{
await using var db = await GetDb();
var flags = await GetBanExemptionCore(db, userId);
return flags ?? ServerBanExemptFlags.None;
}
#endregion
#region Role Bans
@@ -985,6 +1031,5 @@ namespace Content.Server.Database
public abstract ValueTask DisposeAsync();
}
}
}

View File

@@ -84,6 +84,23 @@ namespace Content.Server.Database
Task AddServerBanAsync(ServerBanDef serverBan);
Task AddServerUnbanAsync(ServerUnbanDef serverBan);
/// <summary>
/// Update ban exemption information for a player.
/// </summary>
/// <remarks>
/// Database rows are automatically created and removed when appropriate.
/// </remarks>
/// <param name="userId">The user to update</param>
/// <param name="flags">The new ban exemption flags.</param>
Task UpdateBanExemption(NetUserId userId, ServerBanExemptFlags flags);
/// <summary>
/// Get current ban exemption flags for a user
/// </summary>
/// <returns><see cref="ServerBanExemptFlags.None"/> if the user is not exempt from any bans.</returns>
Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId);
#endregion
#region Role Bans
@@ -353,6 +370,18 @@ namespace Content.Server.Database
return _db.AddServerUnbanAsync(serverUnban);
}
public Task UpdateBanExemption(NetUserId userId, ServerBanExemptFlags flags)
{
DbWriteOpsMetric.Inc();
return _db.UpdateBanExemption(userId, flags);
}
public Task<ServerBanExemptFlags> GetBanExemption(NetUserId userId)
{
DbReadOpsMetric.Inc();
return _db.GetBanExemption(userId);
}
#region Role Ban
public Task<ServerRoleBanDef?> GetServerRoleBanAsync(int id)
{
@@ -742,10 +771,10 @@ namespace Content.Server.Database
return true;
}
public IDisposable BeginScope<TState>(TState state)
public IDisposable? BeginScope<TState>(TState state) where TState : notnull
{
// TODO: this
return null!;
return null;
}
}
}

View File

@@ -6,6 +6,7 @@ using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Robust.Shared.Network;
using Robust.Shared.Utility;
namespace Content.Server.Database
{
@@ -58,7 +59,8 @@ namespace Content.Server.Database
await using var db = await GetDbImpl();
var query = MakeBanLookupQuery(address, userId, hwId, db, includeUnbanned: false)
var exempt = await GetBanExemptionCore(db, userId);
var query = MakeBanLookupQuery(address, userId, hwId, db, includeUnbanned: false, exempt)
.OrderByDescending(b => b.BanTime);
var ban = await query.FirstOrDefaultAsync();
@@ -77,7 +79,8 @@ namespace Content.Server.Database
await using var db = await GetDbImpl();
var query = MakeBanLookupQuery(address, userId, hwId, db, includeUnbanned);
var exempt = await GetBanExemptionCore(db, userId);
var query = MakeBanLookupQuery(address, userId, hwId, db, includeUnbanned, exempt);
var queryBans = await query.ToArrayAsync();
var bans = new List<ServerBanDef>(queryBans.Length);
@@ -100,8 +103,11 @@ namespace Content.Server.Database
NetUserId? userId,
ImmutableArray<byte>? hwId,
DbGuardImpl db,
bool includeUnbanned)
bool includeUnbanned,
ServerBanExemptFlags? exemptFlags)
{
DebugTools.Assert(!(address == null && userId == null && hwId == null));
IQueryable<ServerBan>? query = null;
if (userId is { } uid)
@@ -131,14 +137,22 @@ namespace Content.Server.Database
query = query == null ? newQ : query.Union(newQ);
}
DebugTools.Assert(
query != null,
"At least one filter item (IP/UserID/HWID) must have been given to make query not null.");
if (!includeUnbanned)
{
query = query?.Where(p =>
query = query.Where(p =>
p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.Now));
}
query = query!.Distinct();
return query;
if (exemptFlags is { } exempt)
{
query = query.Where(b => (b.ExemptFlags & exempt) == 0);
}
return query.Distinct();
}
private static ServerBanDef? ConvertBan(ServerBan? ban)

View File

@@ -68,9 +68,11 @@ namespace Content.Server.Database
{
await using var db = await GetDbImpl();
var exempt = await GetBanExemptionCore(db, userId);
// SQLite can't do the net masking stuff we need to match IP address ranges.
// So just pull down the whole list into memory.
var bans = await GetAllBans(db.SqliteDbContext, includeUnbanned: false);
var bans = await GetAllBans(db.SqliteDbContext, includeUnbanned: false, exempt);
return bans.FirstOrDefault(b => BanMatches(b, address, userId, hwId)) is { } foundBan
? ConvertBan(foundBan)
@@ -83,9 +85,11 @@ namespace Content.Server.Database
{
await using var db = await GetDbImpl();
var exempt = await GetBanExemptionCore(db, userId);
// SQLite can't do the net masking stuff we need to match IP address ranges.
// So just pull down the whole list into memory.
var queryBans = await GetAllBans(db.SqliteDbContext, includeUnbanned);
var queryBans = await GetAllBans(db.SqliteDbContext, includeUnbanned, exempt);
return queryBans
.Where(b => BanMatches(b, address, userId, hwId))
@@ -95,7 +99,8 @@ namespace Content.Server.Database
private static async Task<List<ServerBan>> GetAllBans(
SqliteServerDbContext db,
bool includeUnbanned)
bool includeUnbanned,
ServerBanExemptFlags? exemptFlags)
{
IQueryable<ServerBan> query = db.Ban.Include(p => p.Unban);
if (!includeUnbanned)
@@ -104,6 +109,11 @@ namespace Content.Server.Database
p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow));
}
if (exemptFlags is { } exempt)
{
query = query.Where(b => (b.ExemptFlags & exempt) == 0);
}
return await query.ToListAsync();
}