Use HWIDs for bans.

This commit is contained in:
Pieter-Jan Briers
2021-03-22 01:30:50 +01:00
parent 071362ed25
commit a321b4302e
19 changed files with 1589 additions and 181 deletions

View File

@@ -1,4 +1,5 @@
using System;
using System.Collections.Immutable;
using System.Net;
using Robust.Shared.Network;
@@ -7,6 +8,7 @@ namespace Content.Server.Database
public sealed class PlayerRecord
{
public NetUserId UserId { get; }
public ImmutableArray<byte>? HWId { get; }
public DateTimeOffset FirstSeenTime { get; }
public string LastSeenUserName { get; }
public DateTimeOffset LastSeenTime { get; }
@@ -17,13 +19,15 @@ namespace Content.Server.Database
DateTimeOffset firstSeenTime,
string lastSeenUserName,
DateTimeOffset lastSeenTime,
IPAddress lastSeenAddress)
IPAddress lastSeenAddress,
ImmutableArray<byte>? hwId)
{
UserId = userId;
FirstSeenTime = firstSeenTime;
LastSeenUserName = lastSeenUserName;
LastSeenTime = lastSeenTime;
LastSeenAddress = lastSeenAddress;
HWId = hwId;
}
}
}

View File

@@ -1,4 +1,5 @@
using System;
using System.Collections.Immutable;
using System.Net;
using Robust.Shared.Network;
@@ -11,6 +12,7 @@ namespace Content.Server.Database
public int? Id { get; }
public NetUserId? UserId { get; }
public (IPAddress address, int cidrMask)? Address { get; }
public ImmutableArray<byte>? HWId { get; }
public DateTimeOffset BanTime { get; }
public DateTimeOffset? ExpirationTime { get; }
@@ -22,15 +24,16 @@ namespace Content.Server.Database
int? id,
NetUserId? userId,
(IPAddress, int)? address,
ImmutableArray<byte>? hwId,
DateTimeOffset banTime,
DateTimeOffset? expirationTime,
string reason,
NetUserId? banningAdmin,
ServerUnbanDef? unban)
{
if (userId == null && address == null)
if (userId == null && address == null && hwId == null)
{
throw new ArgumentException("Must have a banned user, banned address, or both.");
throw new ArgumentException("Must have at least one of banned user, banned address or hardware ID");
}
if (address is {} addr && addr.Item1.IsIPv4MappedToIPv6)
@@ -43,6 +46,7 @@ namespace Content.Server.Database
Id = id;
UserId = userId;
Address = address;
HWId = hwId;
BanTime = banTime;
ExpirationTime = expirationTime;
Reason = reason;

View File

@@ -1,6 +1,7 @@
#nullable enable
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Net;
using System.Threading;
@@ -261,8 +262,12 @@ namespace Content.Server.Database
/// </summary>
/// <param name="address">The ip address of the user.</param>
/// <param name="userId">The id of the user.</param>
/// <param name="hwId">The HWId of the user.</param>
/// <returns>The user's latest received un-pardoned ban, or null if none exist.</returns>
public abstract Task<ServerBanDef?> GetServerBanAsync(IPAddress? address, NetUserId? userId);
public abstract Task<ServerBanDef?> GetServerBanAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId);
/// <summary>
/// Looks up an user's ban history.
@@ -271,8 +276,12 @@ namespace Content.Server.Database
/// </summary>
/// <param name="address">The ip address of the user.</param>
/// <param name="userId">The id of the user.</param>
/// <param name="hwId">The HWId of the user.</param>
/// <returns>The user's ban history.</returns>
public abstract Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address, NetUserId? userId);
public abstract Task<List<ServerBanDef>> GetServerBansAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId);
public abstract Task AddServerBanAsync(ServerBanDef serverBan);
public abstract Task AddServerUnbanAsync(ServerUnbanDef serverUnban);
@@ -280,14 +289,22 @@ namespace Content.Server.Database
/*
* PLAYER RECORDS
*/
public abstract Task UpdatePlayerRecord(NetUserId userId, string userName, IPAddress address);
public abstract Task UpdatePlayerRecord(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId);
public abstract Task<PlayerRecord?> GetPlayerRecordByUserName(string userName, CancellationToken cancel);
public abstract Task<PlayerRecord?> GetPlayerRecordByUserId(NetUserId userId, CancellationToken cancel);
/*
* CONNECTION LOG
*/
public abstract Task AddConnectionLogAsync(NetUserId userId, string userName, IPAddress address);
public abstract Task AddConnectionLogAsync(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId);
/*
* ADMIN STUFF

View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Net;
using System.Threading;
@@ -59,8 +60,12 @@ namespace Content.Server.Database
/// </summary>
/// <param name="address">The ip address of the user.</param>
/// <param name="userId">The id of the user.</param>
/// <param name="hwId">The hardware ID of the user.</param>
/// <returns>The user's latest received un-pardoned ban, or null if none exist.</returns>
Task<ServerBanDef?> GetServerBanAsync(IPAddress? address, NetUserId? userId);
Task<ServerBanDef?> GetServerBanAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId);
/// <summary>
/// Looks up an user's ban history.
@@ -69,19 +74,31 @@ namespace Content.Server.Database
/// </summary>
/// <param name="address">The ip address of the user.</param>
/// <param name="userId">The id of the user.</param>
/// <param name="hwId">The HWId of the user.</param>
/// <returns>The user's ban history.</returns>
Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address, NetUserId? userId);
Task<List<ServerBanDef>> GetServerBansAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId);
Task AddServerBanAsync(ServerBanDef serverBan);
Task AddServerUnbanAsync(ServerUnbanDef serverBan);
// Player records
Task UpdatePlayerRecordAsync(NetUserId userId, string userName, IPAddress address);
Task UpdatePlayerRecordAsync(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId);
Task<PlayerRecord?> GetPlayerRecordByUserName(string userName, CancellationToken cancel = default);
Task<PlayerRecord?> GetPlayerRecordByUserId(NetUserId userId, CancellationToken cancel = default);
// Connection log
Task AddConnectionLogAsync(NetUserId userId, string userName, IPAddress address);
Task AddConnectionLogAsync(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId);
// Admins
Task<Admin?> GetAdminDataForAsync(NetUserId userId, CancellationToken cancel = default);
@@ -179,14 +196,20 @@ namespace Content.Server.Database
return _db.GetServerBanAsync(id);
}
public Task<ServerBanDef?> GetServerBanAsync(IPAddress? address, NetUserId? userId)
public Task<ServerBanDef?> GetServerBanAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
return _db.GetServerBanAsync(address, userId);
return _db.GetServerBanAsync(address, userId, hwId);
}
public Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address, NetUserId? userId)
public Task<List<ServerBanDef>> GetServerBansAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
return _db.GetServerBansAsync(address, userId);
return _db.GetServerBansAsync(address, userId, hwId);
}
public Task AddServerBanAsync(ServerBanDef serverBan)
@@ -199,9 +222,13 @@ namespace Content.Server.Database
return _db.AddServerUnbanAsync(serverUnban);
}
public Task UpdatePlayerRecordAsync(NetUserId userId, string userName, IPAddress address)
public Task UpdatePlayerRecordAsync(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId)
{
return _db.UpdatePlayerRecord(userId, userName, address);
return _db.UpdatePlayerRecord(userId, userName, address, hwId);
}
public Task<PlayerRecord?> GetPlayerRecordByUserName(string userName, CancellationToken cancel = default)
@@ -214,9 +241,13 @@ namespace Content.Server.Database
return _db.GetPlayerRecordByUserId(userId, cancel);
}
public Task AddConnectionLogAsync(NetUserId userId, string userName, IPAddress address)
public Task AddConnectionLogAsync(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId)
{
return _db.AddConnectionLogAsync(userId, userName, address);
return _db.AddConnectionLogAsync(userId, userName, address, hwId);
}
public Task<Admin?> GetAdminDataForAsync(NetUserId userId, CancellationToken cancel = default)

View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Data;
using System.Linq;
using System.Net;
@@ -48,7 +49,10 @@ namespace Content.Server.Database
return ConvertBan(ban);
}
public override async Task<ServerBanDef?> GetServerBanAsync(IPAddress? address, NetUserId? userId)
public override async Task<ServerBanDef?> GetServerBanAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
if (address == null && userId == null)
{
@@ -57,73 +61,31 @@ namespace Content.Server.Database
await using var db = await GetDbImpl();
var query = db.PgDbContext.Ban
.Include(p => p.Unban)
.Where(p => p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.Now));
if (userId is { } uid)
{
if (address == null)
{
// Only have a user ID.
query = query.Where(p => p.UserId == uid.UserId);
}
else
{
// Have both user ID and IP address.
query = query.Where(p =>
(p.Address != null && EF.Functions.ContainsOrEqual(p.Address.Value, address))
|| p.UserId == uid.UserId);
}
}
else
{
// Only have a connecting address.
query = query.Where(
p => p.Address != null && EF.Functions.ContainsOrEqual(p.Address.Value, address));
}
var query = MakeBanLookupQuery(address, userId, hwId, db)
.Where(p => p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.Now))
.OrderByDescending(b => b.BanTime);
var ban = await query.FirstOrDefaultAsync();
return ConvertBan(ban);
}
public override async Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address, NetUserId? userId)
public override async Task<List<ServerBanDef>> GetServerBansAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
if (address == null && userId == null)
if (address == null && userId == null && hwId == null)
{
throw new ArgumentException("Address and userId cannot both be null");
}
await using var db = await GetDbImpl();
var query = db.PgDbContext.Ban
.Include(p => p.Unban).AsQueryable();
if (userId is { } uid)
{
if (address == null)
{
// Only have a user ID.
query = query.Where(p => p.UserId == uid.UserId);
}
else
{
// Have both user ID and IP address.
query = query.Where(p =>
(p.Address != null && EF.Functions.ContainsOrEqual(p.Address.Value, address))
|| p.UserId == uid.UserId);
}
}
else
{
// Only have a connecting address.
query = query.Where(
p => p.Address != null && EF.Functions.ContainsOrEqual(p.Address.Value, address));
}
var query = MakeBanLookupQuery(address, userId, hwId, db);
var queryBans = await query.ToArrayAsync();
var bans = new List<ServerBanDef>();
var bans = new List<ServerBanDef>(queryBans.Length);
foreach (var ban in queryBans)
{
@@ -138,6 +100,45 @@ namespace Content.Server.Database
return bans;
}
private static IQueryable<PostgresServerBan> MakeBanLookupQuery(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId,
DbGuardImpl db)
{
IQueryable<PostgresServerBan>? query = null;
if (userId is { } uid)
{
var newQ = db.PgDbContext.Ban
.Include(p => p.Unban)
.Where(b => b.UserId == uid.UserId);
query = query == null ? newQ : query.Union(newQ);
}
if (address != null)
{
var newQ = db.PgDbContext.Ban
.Include(p => p.Unban)
.Where(b => b.Address != null && EF.Functions.ContainsOrEqual(b.Address.Value, address));
query = query == null ? newQ : query.Union(newQ);
}
if (hwId != null)
{
var newQ = db.PgDbContext.Ban
.Include(p => p.Unban)
.Where(b => b.HWId!.SequenceEqual(hwId));
query = query == null ? newQ : query.Union(newQ);
}
query = query!.Distinct();
return query;
}
private static ServerBanDef? ConvertBan(PostgresServerBan? ban)
{
if (ban == null)
@@ -163,6 +164,7 @@ namespace Content.Server.Database
ban.Id,
uid,
ban.Address,
ban.HWId == null ? null : ImmutableArray.Create(ban.HWId),
ban.BanTime,
ban.ExpirationTime,
ban.Reason,
@@ -196,6 +198,7 @@ namespace Content.Server.Database
db.PgDbContext.Ban.Add(new PostgresServerBan
{
Address = serverBan.Address,
HWId = serverBan.HWId?.ToArray(),
Reason = serverBan.Reason,
BanningAdmin = serverBan.BanningAdmin?.UserId,
BanTime = serverBan.BanTime.UtcDateTime,
@@ -220,7 +223,11 @@ namespace Content.Server.Database
await db.PgDbContext.SaveChangesAsync();
}
public override async Task UpdatePlayerRecord(NetUserId userId, string userName, IPAddress address)
public override async Task UpdatePlayerRecord(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId)
{
await using var db = await GetDbImpl();
@@ -237,6 +244,7 @@ namespace Content.Server.Database
record.LastSeenTime = DateTime.UtcNow;
record.LastSeenAddress = address;
record.LastSeenUserName = userName;
record.LastSeenHWId = hwId.ToArray();
await db.PgDbContext.SaveChangesAsync();
}
@@ -277,10 +285,15 @@ namespace Content.Server.Database
new DateTimeOffset(record.FirstSeenTime),
record.LastSeenUserName,
new DateTimeOffset(record.LastSeenTime),
record.LastSeenAddress);
record.LastSeenAddress,
record.LastSeenHWId?.ToImmutableArray());
}
public override async Task AddConnectionLogAsync(NetUserId userId, string userName, IPAddress address)
public override async Task AddConnectionLogAsync(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId)
{
await using var db = await GetDbImpl();
@@ -289,7 +302,8 @@ namespace Content.Server.Database
Address = address,
Time = DateTime.UtcNow,
UserId = userId.UserId,
UserName = userName
UserName = userName,
HWId = hwId.ToArray()
});
await db.PgDbContext.SaveChangesAsync();

View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;
using System.Net;
@@ -58,7 +59,10 @@ namespace Content.Server.Database
return ConvertBan(ban);
}
public override async Task<ServerBanDef?> GetServerBanAsync(IPAddress? address, NetUserId? userId)
public override async Task<ServerBanDef?> GetServerBanAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
await using var db = await GetDbImpl();
@@ -69,23 +73,15 @@ namespace Content.Server.Database
.Where(p => p.Unban == null && (p.ExpirationTime == null || p.ExpirationTime.Value > DateTime.UtcNow))
.ToListAsync();
foreach (var ban in bans)
{
if (address != null && ban.Address != null && address.IsInSubnet(ban.Address))
{
return ConvertBan(ban);
}
if (userId is { } id && ban.UserId == id.UserId)
{
return ConvertBan(ban);
}
}
return null;
return bans.FirstOrDefault(b => BanMatches(b, address, userId, hwId)) is { } foundBan
? ConvertBan(foundBan)
: null;
}
public override async Task<List<ServerBanDef>> GetServerBansAsync(IPAddress? address, NetUserId? userId)
public override async Task<List<ServerBanDef>> GetServerBansAsync(
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
await using var db = await GetDbImpl();
@@ -95,30 +91,34 @@ namespace Content.Server.Database
.Include(p => p.Unban)
.ToListAsync();
var bans = new List<ServerBanDef>();
return queryBans
.Where(b => BanMatches(b, address, userId, hwId))
.Select(ConvertBan)
.ToList()!;
}
foreach (var ban in queryBans)
private static bool BanMatches(
SqliteServerBan ban,
IPAddress? address,
NetUserId? userId,
ImmutableArray<byte>? hwId)
{
if (address != null && ban.Address != null && address.IsInSubnet(ban.Address))
{
ServerBanDef? banDef = null;
if (address != null && ban.Address != null && address.IsInSubnet(ban.Address))
{
banDef = ConvertBan(ban);
}
else if (userId is { } id && ban.UserId == id.UserId)
{
banDef = ConvertBan(ban);
}
if (banDef == null)
{
continue;
}
bans.Add(banDef);
return true;
}
return bans;
if (userId is { } id && ban.UserId == id.UserId)
{
return true;
}
if (hwId is { } hwIdVar && hwIdVar.AsSpan().SequenceEqual(ban.HWId))
{
return true;
}
return false;
}
public override async Task AddServerBanAsync(ServerBanDef serverBan)
@@ -136,6 +136,7 @@ namespace Content.Server.Database
Address = addrStr,
Reason = serverBan.Reason,
BanningAdmin = serverBan.BanningAdmin?.UserId,
HWId = serverBan.HWId?.ToArray(),
BanTime = serverBan.BanTime.UtcDateTime,
ExpirationTime = serverBan.ExpirationTime?.UtcDateTime,
UserId = serverBan.UserId?.UserId
@@ -158,7 +159,11 @@ namespace Content.Server.Database
await db.SqliteDbContext.SaveChangesAsync();
}
public override async Task UpdatePlayerRecord(NetUserId userId, string userName, IPAddress address)
public override async Task UpdatePlayerRecord(
NetUserId userId,
string userName,
IPAddress address,
ImmutableArray<byte> hwId)
{
await using var db = await GetDbImpl();
@@ -175,6 +180,7 @@ namespace Content.Server.Database
record.LastSeenTime = DateTime.UtcNow;
record.LastSeenAddress = address.ToString();
record.LastSeenUserName = userName;
record.LastSeenHWId = hwId.ToArray();
await db.SqliteDbContext.SaveChangesAsync();
}
@@ -215,8 +221,10 @@ namespace Content.Server.Database
new DateTimeOffset(record.FirstSeenTime, TimeSpan.Zero),
record.LastSeenUserName,
new DateTimeOffset(record.LastSeenTime, TimeSpan.Zero),
IPAddress.Parse(record.LastSeenAddress));
IPAddress.Parse(record.LastSeenAddress),
record.LastSeenHWId?.ToImmutableArray());
}
private static ServerBanDef? ConvertBan(SqliteServerBan? ban)
{
if (ban == null)
@@ -225,13 +233,13 @@ namespace Content.Server.Database
}
NetUserId? uid = null;
if (ban.UserId is {} guid)
if (ban.UserId is { } guid)
{
uid = new NetUserId(guid);
}
NetUserId? aUid = null;
if (ban.BanningAdmin is {} aGuid)
if (ban.BanningAdmin is { } aGuid)
{
aUid = new NetUserId(aGuid);
}
@@ -250,6 +258,7 @@ namespace Content.Server.Database
ban.Id,
uid,
addrTuple,
ban.HWId == null ? null : ImmutableArray.Create(ban.HWId),
ban.BanTime,
ban.ExpirationTime,
ban.Reason,
@@ -265,7 +274,7 @@ namespace Content.Server.Database
}
NetUserId? aUid = null;
if (unban.UnbanningAdmin is {} aGuid)
if (unban.UnbanningAdmin is { } aGuid)
{
aUid = new NetUserId(aGuid);
}
@@ -276,7 +285,8 @@ namespace Content.Server.Database
unban.UnbanTime);
}
public override async Task AddConnectionLogAsync(NetUserId userId, string userName, IPAddress address)
public override async Task AddConnectionLogAsync(NetUserId userId, string userName, IPAddress address,
ImmutableArray<byte> hwId)
{
await using var db = await GetDbImpl();
@@ -285,7 +295,8 @@ namespace Content.Server.Database
Address = address.ToString(),
Time = DateTime.UtcNow,
UserId = userId.UserId,
UserName = userName
UserName = userName,
HWId = hwId.ToArray()
});
await db.SqliteDbContext.SaveChangesAsync();