diff --git a/Core/InteropServices/CorrespondingTypeAttribute.cs b/Core/InteropServices/CorrespondingTypeAttribute.cs
index b5380255..c731f902 100644
--- a/Core/InteropServices/CorrespondingTypeAttribute.cs
+++ b/Core/InteropServices/CorrespondingTypeAttribute.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Generic;
using System.Linq;
using Vanara.Extensions;
// ReSharper disable MemberCanBePrivate.Global
@@ -27,13 +28,13 @@ namespace Vanara.InteropServices
/// value to determine the type to get or set.
///
///
- [AttributeUsage(AttributeTargets.Field)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Class, AllowMultiple = true)]
public class CorrespondingTypeAttribute : Attribute
{
/// Initializes a new instance of the class.
/// The type that corresponds to this enumeration value.
/// The actions allowed for the type.
- public CorrespondingTypeAttribute(Type typeRef, CorrepsondingAction action = CorrepsondingAction.Get | CorrepsondingAction.Set)
+ public CorrespondingTypeAttribute(Type typeRef, CorrepsondingAction action = CorrepsondingAction.GetSet)
{
TypeRef = typeRef;
Action = action;
@@ -54,39 +55,69 @@ namespace Vanara.InteropServices
/// The type that corresponds to this enumeration value.
public Type TypeRef { get; }
- /// Determines whether this instance can get the type for the specified enum value.
- /// The enumeration value.
+ /// Determines whether this instance can get the type for the specified enum value or class.
+ /// The enumeration value or class instance.
/// The type supplied by the user to validate.
/// true if this instance can get the specified type; otherwise, false.
public static bool CanGet(object value, Type typeRef)
{
- var attr = GetAttrForEnum(value);
- return attr.Action.IsFlagSet(CorrepsondingAction.Get) && attr.TypeRef == typeRef;
+ return GetAttrForObj(value).Any(a => a.Action.IsFlagSet(CorrepsondingAction.Get) && a.TypeRef == typeRef);
}
- /// Determines whether this instance can set the type for the specified enum value.
- /// The enumeration value.
+ /// Determines whether this type can get the specified reference type.
+ /// The class type.
+ /// The type supplied by the user to validate.
+ /// true if this type can get the specified reference type; otherwise, false.
+ public static bool CanGet(Type type, Type typeRef)
+ {
+ return GetAttrForType(type).Any(a => a.Action.IsFlagSet(CorrepsondingAction.Get) && a.TypeRef == typeRef);
+ }
+
+ /// Determines whether this instance can set the type for the specified enum value or class.
+ /// The enumeration value or class instance.
/// The type supplied by the user to validate.
/// true if this instance can set the specified type; otherwise, false.
public static bool CanSet(object value, Type typeRef)
{
- var attr = GetAttrForEnum(value);
- return attr.Action.IsFlagSet(CorrepsondingAction.Set) && attr.TypeRef == typeRef;
+ return GetAttrForObj(value).Any(a => a.Action.IsFlagSet(CorrepsondingAction.Set) && a.TypeRef == typeRef);
}
- /// Gets the corresponding type for the supplied enumeration value.
- /// The enumeration value.
- /// The type defined by the attribute.
- public static Type GetCorrespondingType(object enumValue) => GetAttrForEnum(enumValue).TypeRef;
+ /// Determines whether this type can set the specified reference type.
+ /// The class type.
+ /// The type supplied by the user to validate.
+ /// true if this type can set the specified reference type; otherwise, false.
+ public static bool CanSet(Type type, Type typeRef)
+ {
+ return GetAttrForType(type).Any(a => a.Action.IsFlagSet(CorrepsondingAction.Set) && a.TypeRef == typeRef);
+ }
- private static CorrespondingTypeAttribute GetAttrForEnum(object value)
+ /// Gets the corresponding types for the supplied enumeration value.
+ /// The enumeration value or class.
+ /// The types defined by the attribute.
+ public static IEnumerable GetCorrespondingTypes(object enumValue) => GetAttrForObj(enumValue).Select(a => a.TypeRef);
+
+ /// Gets the corresponding types for the supplied enumeration value.
+ /// The enumeration value or class.
+ /// The types defined by the attribute.
+ public static IEnumerable GetCorrespondingTypes(Type type) => GetAttrForType(type).Select(a => a.TypeRef);
+
+ private static IEnumerable GetAttrForObj(object value)
{
if (value == null) throw new ArgumentNullException(nameof(value));
var valueType = value.GetType();
- if (!valueType.IsEnum) throw new ArgumentException("Value must be an enumeration value.", nameof(value));
- var attr = valueType.GetField(value.ToString()).GetCustomAttributes(typeof(CorrespondingTypeAttribute), false).Cast().FirstOrDefault();
- if (attr == null) throw new InvalidOperationException("Value must have the CorrespondingTypeAttribute defined.");
- if (attr.Action == CorrepsondingAction.Exception) throw new Exception();
+ if (!valueType.IsEnum && !valueType.IsClass) throw new ArgumentException("Value must be an enumeration or class value.", nameof(value));
+ var attr = (valueType.IsEnum ? valueType.GetField(value.ToString()).GetCustomAttributes(typeof(CorrespondingTypeAttribute), false) : valueType.GetCustomAttributes(typeof(CorrespondingTypeAttribute), false)).Cast().ToArray();
+ if (attr == null || attr.Length == 0) throw new InvalidOperationException("Value must have the CorrespondingTypeAttribute defined.");
+ if (attr.Any(a => a.Action == CorrepsondingAction.Exception)) throw new Exception();
+ return attr;
+ }
+
+ private static IEnumerable GetAttrForType(Type type)
+ {
+ if (type == null) throw new ArgumentNullException(nameof(type));
+ var attr = type.GetCustomAttributes(typeof(CorrespondingTypeAttribute), false).Cast().ToArray();
+ if (attr == null || attr.Length == 0) throw new InvalidOperationException("Type must have the CorrespondingTypeAttribute defined.");
+ if (attr.Any(a => a.Action == CorrepsondingAction.Exception)) throw new Exception();
return attr;
}
}
diff --git a/PInvoke/Security/AdvApi32/WinBase.cs b/PInvoke/Security/AdvApi32/WinBase.cs
index 23cfc18a..dce19923 100644
--- a/PInvoke/Security/AdvApi32/WinBase.cs
+++ b/PInvoke/Security/AdvApi32/WinBase.cs
@@ -1,5 +1,6 @@
using System;
using System.ComponentModel;
+using System.Linq;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
using System.Text;
@@ -926,7 +927,7 @@ namespace Vanara.PInvoke
///
public T GetInfo(TOKEN_INFORMATION_CLASS tokenInfoClass)
{
- if (CorrespondingTypeAttribute.GetCorrespondingType(tokenInfoClass) != typeof(T))
+ if (CorrespondingTypeAttribute.GetCorrespondingTypes(tokenInfoClass).FirstOrDefault() != typeof(T))
throw new InvalidCastException();
using (var pType = GetInfo(tokenInfoClass))
{
diff --git a/PInvoke/WinINet/WinINet.cs b/PInvoke/WinINet/WinINet.cs
index 8da3caa7..d086ee67 100644
--- a/PInvoke/WinINet/WinINet.cs
+++ b/PInvoke/WinINet/WinINet.cs
@@ -1,4 +1,5 @@
using System;
+using System.Linq;
using System.Runtime.InteropServices;
using Vanara.InteropServices;
using FILETIME = System.Runtime.InteropServices.ComTypes.FILETIME;
@@ -1399,7 +1400,7 @@ namespace Vanara.PInvoke
/// Internet option to be set. This can be one of the Option Flags values.
public static void InternetSetOption(this SafeInternetHandle hInternet, InternetOptionFlags option)
{
- if (CorrespondingTypeAttribute.GetCorrespondingType(option) != null) throw new ArgumentException($"{option} cannot be used to set options that do not require a value.");
+ if (CorrespondingTypeAttribute.GetCorrespondingTypes(option).FirstOrDefault() != null) throw new ArgumentException($"{option} cannot be used to set options that do not require a value.");
var res = InternetSetOption(hInternet, option, IntPtr.Zero, 0);
if (!res) Win32Error.ThrowLastError();
}