From 0c798f899fc81634c76f3e4632370b90b6dd81cd Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 12 Aug 2019 09:04:53 -0600 Subject: [PATCH] Completed unit testing and fixes for SafePSID and SafePSIDArray --- PInvoke/Security/AdvApi32/PSID.cs | 32 +++- UnitTests/PInvoke/Security/AdvApi32/PSIDTests.cs | 181 ++++++++++++++++------- 2 files changed, 154 insertions(+), 59 deletions(-) diff --git a/PInvoke/Security/AdvApi32/PSID.cs b/PInvoke/Security/AdvApi32/PSID.cs index 1d67ee75..2b60092a 100644 --- a/PInvoke/Security/AdvApi32/PSID.cs +++ b/PInvoke/Security/AdvApi32/PSID.cs @@ -63,6 +63,10 @@ namespace Vanara.PInvoke /// true if this instance is a valid SID; otherwise, false. public bool IsValidSid => IsValidSid(this); + /// Gets the length, in bytes, of the SID. + /// The SID length, in bytes. + public int Length => IsValidSid ? GetLengthSid(this) : 0; + /// Copies the specified SID from a memory pointer to a instance. /// The SID pointer. This value remains the responsibility of the caller to release. /// A instance. @@ -151,7 +155,7 @@ namespace Vanara.PInvoke /// Indicates whether the current object is equal to another object of the same type. /// An object to compare with this object. /// true if the current object is equal to the parameter; otherwise, false. - public bool Equals(SafePSID other) => EqualSid(this, other); + public bool Equals(SafePSID other) => other != null && (ReferenceEquals(this, other) || EqualSid(this, other)); /// Indicates whether the current object is equal to another object of the same type. /// An object to compare with this object. @@ -214,14 +218,22 @@ namespace Vanara.PInvoke /// Initializes a new instance of the class and assigns an existing handle. /// An object that represents the pre-existing handle to use. + /// The count of PSID array values pointed to by . /// - /// to reliably release the handle during the finalization phase; otherwise, (not recommended). + /// to reliably release the handle during the finalization phase; otherwise, (not + /// recommended). If , the individually allocated values for each PSID will also be released. /// - public SafePSIDArray(IntPtr preexistingHandle, bool ownsHandle = true) : base(preexistingHandle, ownsHandle) { } + public SafePSIDArray(IntPtr preexistingHandle, int count, bool ownsHandle = true) : base(preexistingHandle, ownsHandle) + { + if (ownsHandle) + Count = count; + else + items = new List(handle.ToIEnum(count).Select(p => new SafePSID(p))); + } /// Initializes a new instance of the class. /// A list of instances. - public SafePSIDArray(IEnumerable pSIDs) : this(pSIDs.Select(p => (PSID)p)) + public SafePSIDArray(IEnumerable pSIDs) : this(pSIDs?.Select(p => (PSID)p)) { } @@ -229,8 +241,9 @@ namespace Vanara.PInvoke /// A list of instances. public SafePSIDArray(IEnumerable pSIDs) : base() { + if (pSIDs is null) throw new ArgumentNullException(nameof(pSIDs)); items = pSIDs.Select(p => new SafePSID(p)).ToList(); - SetHandle(items.Cast().MarshalToPtr(i => LocalAlloc(LMEM.LPTR, i).DangerousGetHandle(), out _)); + SetHandle(items.Select(p => (IntPtr)p).MarshalToPtr(i => LocalAlloc(LMEM.LPTR, i).DangerousGetHandle(), out _)); } /// Initializes a new instance of the class. @@ -243,8 +256,13 @@ namespace Vanara.PInvoke get => items?.Count ?? throw new InvalidOperationException("The length must be set before using this function."); set { - if (items != null) throw new InvalidOperationException("The length can only be set once."); - items = new List(handle.ToIEnum(value).Select(p => new SafePSID(p))); + if (items != null) throw new InvalidOperationException("The length can only be set for partially initialized arrays."); + items = new List(); + foreach (var psid in handle.ToIEnum(value)) + { + items.Add(new SafePSID(psid)); + LocalFree(psid); + } } } diff --git a/UnitTests/PInvoke/Security/AdvApi32/PSIDTests.cs b/UnitTests/PInvoke/Security/AdvApi32/PSIDTests.cs index fb1c738f..296eb04f 100644 --- a/UnitTests/PInvoke/Security/AdvApi32/PSIDTests.cs +++ b/UnitTests/PInvoke/Security/AdvApi32/PSIDTests.cs @@ -1,15 +1,88 @@ using NUnit.Framework; using System; using System.Linq; +using System.Runtime.InteropServices; using System.Security.Principal; -using Vanara.InteropServices; +using Vanara.Extensions; using static Vanara.PInvoke.AdvApi32; namespace Vanara.PInvoke.Tests { + public static class UtilExt + { + public static byte[] GetBytes(this SecurityIdentifier si) + { + if (si == null) return new byte[0]; + var sidLen = si.BinaryLength; + var bytes = new byte[sidLen]; + si.GetBinaryForm(bytes, 0); + return bytes; + } + } + [TestFixture()] public class PSIDTests { + public static SafePSID GetCurrentSid() => new SafePSID(WindowsIdentity.GetCurrent().User.GetBytes()); + + [Test()] + public void CloneTest() + { + var sid = GetCurrentSid(); + var sid2 = sid.Clone(); + Assert.That(sid2.IsValidSid); + Assert.That(sid, Is.EqualTo(sid2)); + } + + [Test()] + public void CopyTest() + { + var sid = GetCurrentSid(); + Assert.That(!sid.IsInvalid); + Assert.That(sid.IsValidSid); + Assert.That(sid.ToString(), Does.StartWith("S-1-5")); + } + + [Test] + public void EqualsTest() + { + var ssid = new SafePSID("S-1-1-0"); + var esid = SafePSID.Everyone; + var mesid = SafePSID.Current; + Assert.That(ssid == esid, Is.True); + Assert.That(ssid != mesid, Is.True); + Assert.That(ssid.Equals(null), Is.False); + Assert.That(ssid == null, Is.False); + Assert.That(ssid.Equals((PSID)esid), Is.True); + Assert.That(ssid.Equals((IntPtr)esid), Is.True); + Assert.That(ssid.Equals((object)esid), Is.True); + Assert.That(ssid.Equals((object)(PSID)esid), Is.True); + Assert.That(ssid.Equals((object)(IntPtr)esid), Is.True); + Assert.That(ssid.Equals((object)54), Is.False); + } + + [Test()] + public void GetBinaryForm() + { + var sid = new SafePSID("S-1-1-0"); + Assert.That(sid.GetBinaryForm(), Is.EquivalentTo(new SecurityIdentifier(WellKnownSidType.WorldSid, null).GetBytes())); + } + + [Test()] + public void InitTest() + { + var sid = GetCurrentSid(); + var sidStr = sid.ToString(); + Assert.That(sidStr, Does.StartWith("S-1-5-")); + var ssid = sid.ToString().Substring(6).Split('-').Select(int.Parse).ToArray(); + var i = ssid[0]; + var dest = new int[ssid.Length - 1]; + Array.Copy(ssid, 1, dest, 0, ssid.Length - 1); + var sid2 = SafePSID.Init(KnownSIDAuthority.SECURITY_NT_AUTHORITY, i, dest); + Assert.That(sid2.IsValidSid); + Assert.That(sid, Is.EqualTo(sid2)); + } + [Test()] public void PSIDTest() { @@ -38,65 +111,69 @@ namespace Vanara.PInvoke.Tests Assert.That(sid.Equals(sid3), Is.False); } - [Test()] - public void CopyTest() - { - var sid = GetCurrentSid(); - Assert.That(!sid.IsInvalid); - Assert.That(sid.IsValidSid); - Assert.That(sid.ToString(), Does.StartWith("S-1-5")); - } - - public static SafePSID GetCurrentSid() => new SafePSID(WindowsIdentity.GetCurrent().User.GetBytes()); - - [Test()] - public void InitTest() - { - var sid = GetCurrentSid(); - var sidStr = sid.ToString(); - Assert.That(sidStr, Does.StartWith("S-1-5-")); - var ssid = sid.ToString().Substring(6).Split('-').Select(int.Parse).ToArray(); - var i = ssid[0]; - var dest = new int[ssid.Length - 1]; - Array.Copy(ssid, 1, dest, 0, ssid.Length - 1); - var sid2 = SafePSID.Init(KnownSIDAuthority.SECURITY_NT_AUTHORITY, i, dest); - Assert.That(sid2.IsValidSid); - Assert.That(sid, Is.EqualTo(sid2)); - } - - [Test()] - public void CloneTest() - { - var sid = GetCurrentSid(); - var sid2 = sid.Clone(); - Assert.That(sid2.IsValidSid); - Assert.That(sid, Is.EqualTo(sid2)); - } - - [Test()] - public void GetBinaryForm() - { - var sid = new SafePSID("S-1-1-0"); - Assert.That(sid.GetBinaryForm(), Is.EquivalentTo(new SecurityIdentifier(WellKnownSidType.WorldSid, null).GetBytes())); - } - [Test()] public void ToStringTest() { var sid = SafePSID.Init(KnownSIDAuthority.SECURITY_WORLD_SID_AUTHORITY, KnownSIDRelativeID.SECURITY_WORLD_RID); Assert.That(sid.ToString(), Is.EqualTo("S-1-1-0")); } - } - public static class UtilExt - { - public static byte[] GetBytes(this SecurityIdentifier si) + [Test] + public void SafePSIDArrayCtorTest() { - if (si == null) return new byte[0]; - var sidLen = si.BinaryLength; - var bytes = new byte[sidLen]; - si.GetBinaryForm(bytes, 0); - return bytes; + var sids = new[] { SafePSID.Current, SafePSID.Everyone }; + SafePSIDArray safeArr = null; + Assert.That(() => safeArr = new SafePSIDArray((SafePSID[])null), Throws.ArgumentNullException); + Assert.That(() => safeArr = new SafePSIDArray(new SafePSID[0]), Throws.Nothing); + Assert.That(safeArr.Count, Is.Zero); + Assert.That(() => safeArr = new SafePSIDArray(sids), Throws.Nothing); + Assert.That(safeArr.Count, Is.EqualTo(sids.Length)); + Assert.That(() => safeArr = new SafePSIDArray(Array.ConvertAll(sids, s => (PSID)s)), Throws.Nothing); + Assert.That(safeArr.Count, Is.EqualTo(sids.Length)); + Assert.That(EqualSid(safeArr[0], SafePSID.Current), Is.True); + Assert.That(EqualSid(safeArr[1], SafePSID.Everyone), Is.True); + Assert.That(() => safeArr[2], Throws.Exception); + Assert.That(safeArr, Is.EquivalentTo(sids)); + } + + [Test] + public void SafePSIDArrayCtorTest2() + { + // Build in-memory SID array + var sids = new[] { SafePSID.Current, SafePSID.Everyone }; + + SafePSIDArray safeArr = null; + Assert.That(() => safeArr = new SafePSIDArray(IntPtr.Zero, 0), Throws.Nothing); + Assert.That(safeArr.Count, Is.Zero); + + // Unowned + var ptr = Build(); + Assert.That(() => safeArr = new SafePSIDArray(ptr, sids.Length, false), Throws.Nothing); + Assert.That(safeArr.Count, Is.EqualTo(sids.Length)); + foreach (var psid in ptr.ToIEnum(sids.Length)) + Kernel32.LocalFree(psid); + Kernel32.LocalFree(ptr); + safeArr.Dispose(); + + // Owned + ptr = Build(); + Assert.That(() => safeArr = new SafePSIDArray(ptr, sids.Length, true), Throws.Nothing); + Assert.That(safeArr.Count, Is.EqualTo(sids.Length)); + safeArr.Dispose(); + + IntPtr Build() + { + var len = sids.Length * IntPtr.Size + sids.Sum(p => p.Length); + var mem = Kernel32.LocalAlloc(Kernel32.LMEM.LPTR, sids.Length * IntPtr.Size); + for (var i = 0; i < sids.Length; i++) + { + var sid = sids[i]; + var psid = Kernel32.LocalAlloc(Kernel32.LMEM.LPTR, sid.Length); + Marshal.Copy(sid.GetBinaryForm(), 0, (IntPtr)psid, sid.Length); + Marshal.WriteIntPtr((IntPtr)mem, i * IntPtr.Size, (IntPtr)psid); + } + return (IntPtr)mem; + } } } } \ No newline at end of file