Thursday, 10 June 2010

Taking the pain out of parameter validation

One of the biggest pains I find when writing API components is validating parameters. Now don’t get me wrong: I don’t mind validating a parameter and failing fast to ensure that a component works correctly; it’s the tediousness of the code that bothers me. Take the following method for example.

public void SomeMethod(SomeObject someObject, int maxValue)
{
    if (someObject == null)
    {
        throw new ArgumentNullException("someObject", "Parameter 'someObject' cannot be null.");
    }
    if (! someObject.SupportsSomeFunction())
    {
        throw new ArgumentException("Some object does not support some function.", "someObject");
    }
    if (maxValue < 1)
    {
        throw new ArgumentOutOfRangeException("maxValue", maxValue, "max value must be greater than zero.");
    }
}

As you can see, this is tedious and laborious: there is potential duplication with other methods that will have similar error messages, the if statements create a lot of noise, and, my pet hate, the parameter names have to be typed in as strings.

Existing solutions

One solution that I have seen is to use a static or extension method for each type of validation that you want to do.

My preferred implementation, and one that is very similar to my own, is the fluent interface. I prefer the fluent interface implementation not just because it is fluent but because it doesn’t use extension methods. That is not to say that extension methods don’t have their place, but once you introduce an extension method that can be used on any object, your IntelliSense soon becomes cluttered with irrelevant methods; plus, the syntax of Validate.Argument is much clearer. Finally, while both approaches are good at tackling the laboriousness of the code above, they still require you to pass in the name of the parameter.
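For comparison, a minimal sketch of the static-method and extension-method styles might look like the following. Guard, ObjectGuardExtensions, GuardExample and their methods are hypothetical names used only for illustration; they are not taken from either of the existing libraries.

```csharp
using System;

// Hypothetical static guard helper: one method per kind of check.
public static class Guard
{
    public static void NotNull(object value, string parameterName)
    {
        if (value == null)
        {
            throw new ArgumentNullException(parameterName);
        }
    }

    public static void GreaterThanZero(int value, string parameterName)
    {
        if (value < 1)
        {
            throw new ArgumentOutOfRangeException(parameterName, value, "Value must be greater than zero.");
        }
    }
}

// Hypothetical extension-method variant: because it extends object,
// IntelliSense offers it on every variable.
public static class ObjectGuardExtensions
{
    public static void ThrowIfNull(this object value, string parameterName)
    {
        if (value == null)
        {
            throw new ArgumentNullException(parameterName);
        }
    }
}

public class GuardExample
{
    public void SomeMethod(object someObject, int maxValue)
    {
        // Shorter than the if/throw blocks, but the names are still strings.
        Guard.NotNull(someObject, "someObject");
        Guard.GreaterThanZero(maxValue, "maxValue");
    }

    public void OtherMethod(object someObject)
    {
        // Extension methods can be called on a null reference; the check runs inside.
        someObject.ThrowIfNull("someObject");
    }
}
```

Either way the checks are shorter than the hand-written if statements, but "someObject" and "maxValue" still have to be typed in as strings, and a rename refactoring will not update them.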

Getting the parameter name

In order to progress we first need to find a way of getting the name of the parameter from the method. One idea I have seen is to use an expression.

using System;
using System.Linq.Expressions;
 
public void MyMethod(SomeObject someObject)
{
    ValidateArgumentIsNotNull(() => someObject);
}
 
public static void ValidateArgumentIsNotNull<T>(Expression<Func<T>> expr)
{
    // Compiling and invoking the expression yields the argument's value.
    T value = expr.Compile()();
 
    // Compare against null directly; calling Equals on a null value would throw.
    if (value != null) return;
 
    // The body of () => someObject is a member access on the compiler-generated
    // closure, and the member's name is the parameter's name.
    var param = (MemberExpression) expr.Body;
    throw new ArgumentNullException(param.Member.Name);
}

Leaving aside its rough edges (the cast assumes the lambda body is a plain member access), it does the job of pulling the parameter name out of the expression, but therein lies the problem. Because we are using an expression, we have to compile it before we can evaluate the argument, and that happens on every call. If our validation is littered throughout our code and our constructors and methods are constantly being called, what kind of performance hit will we see?
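One way to sidestep Compile(), which the post does not use and which is offered only as a sketch, is to read the value straight out of the closure with reflection. For a lambda such as () => someObject, the expression body is a MemberExpression whose Member is the FieldInfo of the captured parameter and whose Expression is a ConstantExpression holding the compiler-generated closure object. ExpressionGuard is a hypothetical name. The caller still builds a fresh expression tree on every call and FieldInfo.GetValue is itself reflection, so this is something to measure rather than a guaranteed win.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

public static class ExpressionGuard
{
    // Hypothetical alternative to ValidateArgumentIsNotNull: reads the captured
    // value through reflection instead of calling Compile().
    public static void IsNotNull<T>(Expression<Func<T>> expr)
    {
        // For () => someObject the body is closure.someObject:
        // a field access on a constant that holds the closure instance.
        var member = expr.Body as MemberExpression;
        var field = member != null ? member.Member as FieldInfo : null;
        var closure = member != null ? member.Expression as ConstantExpression : null;

        if (field == null || closure == null)
        {
            throw new ArgumentException("Expected a lambda of the form () => parameter.", "expr");
        }

        // Read the current value of the captured parameter from the closure object.
        object value = field.GetValue(closure.Value);

        if (value == null)
        {
            throw new ArgumentNullException(field.Name);
        }
    }
}
```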

To get around these potential performance hits we need to go lower, into the IL. Luckily for me, Rinat Abdullin had already made a start.

static Exception ValidateArgumentIsNotNull<TParameter>(Func<TParameter> argument)
{
    var il = argument.Method.GetMethodBody().GetILAsByteArray();
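    // Assumes the field token always starts at byte offset 2
    // (ldarg.0 and ldfld take one byte each before the 4-byte token).
    var fieldHandle = BitConverter.ToInt32(il, 2);
    var field = argument.Target.GetType().Module.ResolveField(fieldHandle);
    return new ArgumentNullException(field.Name, string.Format("Parameter of type '{0}' can't be null", typeof(TParameter)));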
}

This does the same job as the expression but is much faster; according to Rinat, it is of the order of 300 times faster. Unfortunately, this code cannot be used in production: it does not handle code built in release mode, where the field token sits at a different byte position, and it has trouble with generic types. So I needed to take it one step further.
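To see why a hard-coded offset is fragile, consider two hypothetical method bodies for () => someObject. Both load the closure (ldarg.0, 0x02) and then the captured field (ldfld, 0x7B, followed by a 4-byte little-endian metadata token); the second starts with an extra nop (0x00). The byte arrays and the token value 0x04000001 are made up for illustration, and FixedOffsetDemo is a hypothetical name; the IL the compiler actually emits depends on the compiler version and build settings.

```csharp
using System;

public static class FixedOffsetDemo
{
    // ldarg.0, ldfld <token 0x04000001>, ret
    public static readonly byte[] WithoutNop = { 0x02, 0x7B, 0x01, 0x00, 0x00, 0x04, 0x2A };

    // nop, ldarg.0, ldfld <token 0x04000001>, ret
    public static readonly byte[] WithNop = { 0x00, 0x02, 0x7B, 0x01, 0x00, 0x00, 0x04, 0x2A };

    // Reads a little-endian 32-bit value at byte offset 2, which is where the
    // snippet above assumes the field token always is.
    public static int TokenAtFixedOffset(byte[] il)
    {
        return il[2] | (il[3] << 8) | (il[4] << 16) | (il[5] << 24);
    }
}
```

With the extra nop, offset 2 lands on the ldfld opcode itself, so the value read is 0x0000017B, which is not a field token at all. Decoding the instructions one at a time finds the ldfld wherever it happens to sit.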

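using System;
using System.Collections.Generic;
using System.Reflection;
using System.Reflection.Emit;
 
internal class FieldInfoReader<TParameter>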
{
    private readonly Func<TParameter> arg;
 
    internal FieldInfoReader(Func<TParameter> arg)
    {
        this.arg = arg;
    }
 
    public FieldInfo GetFieldToken()
    {
        byte[] methodBodyIlByteArray = GetMethodBodyIlByteArray();
 
        int fieldToken = GetFieldToken(methodBodyIlByteArray);
 
        return GetFieldInfo(fieldToken);
    }
 
    private FieldInfo GetFieldInfo(int fieldToken)
    {
        FieldInfo fieldInfo = null;
 
        if (fieldToken > 0)
        {
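            // Target is null for a lambda that captures nothing, so fall back
            // to the type that declares the lambda's method.
            Type argType = arg.Target != null ? arg.Target.GetType() : arg.Method.DeclaringType;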
            Type[] genericTypeArguments = GetSubclassGenericTypes(argType);
            Type[] genericMethodArguments = arg.Method.GetGenericArguments();
 
            fieldInfo = argType.Module.ResolveField(fieldToken, genericTypeArguments, genericMethodArguments);
        }
 
        return fieldInfo;
    }
 
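    // Two-byte opcodes (ceq, ldarg, ldftn, initobj, ...) start with the 0xFE
    // prefix; this table is indexed by their second byte.
    private static readonly OpCode[] multiByteOpCodes = LoadMultiByteOpCodes();
 
    private static OpCode[] LoadMultiByteOpCodes()
    {
        var table = new OpCode[0x100];
 
        foreach (FieldInfo info in typeof(OpCodes).GetFields(BindingFlags.Public | BindingFlags.Static))
        {
            if (info.FieldType != typeof(OpCode))
            {
                continue;
            }
 
            var opCode = (OpCode)info.GetValue(null);
            var value = (ushort)opCode.Value;
 
            if ((value & 0xFF00) == 0xFE00)
            {
                table[value & 0xFF] = opCode;
            }
        }
 
        return table;
    }
 
    private static OpCode GetOpCode(byte[] methodBodyIlByteArray, ref int currentPosition)
    {
        byte value = methodBodyIlByteArray[currentPosition++];
 
        // Consume the second byte of a two-byte opcode as well; otherwise it
        // would be misread as the start of the next instruction.
        if (value == 0xFE)
        {
            return multiByteOpCodes[methodBodyIlByteArray[currentPosition++]];
        }
 
        return SingleByteOpCodes[value];
    }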
 
    private static int GetFieldToken(byte[] methodBodyIlByteArray)
    {
        int position = 0;
 
        while (position < methodBodyIlByteArray.Length)
        {
            OpCode code = GetOpCode(methodBodyIlByteArray, ref position);
 
            // The first instruction with a field operand (ldfld, ldsfld, ...) reads
            // the captured parameter; its operand is the field's metadata token.
            if (code.OperandType == OperandType.InlineField)
            {
                return ReadInt32(methodBodyIlByteArray, ref position);
            }
 
            position = MoveToNextPosition(methodBodyIlByteArray, position, code);
        }
 
        return 0;
    }
 
    // Skips over the operand of the current instruction.
    private static int MoveToNextPosition(byte[] il, int position, OpCode code)
    {
        switch (code.OperandType)
        {
            case OperandType.InlineNone:
                break;
 
            case OperandType.InlineI8:
            case OperandType.InlineR:
                position += 8;
                break;
 
            case OperandType.InlineField:
            case OperandType.InlineBrTarget:
            case OperandType.InlineMethod:
            case OperandType.InlineSig:
            case OperandType.InlineTok:
            case OperandType.InlineType:
            case OperandType.InlineI:
            case OperandType.InlineString:
            case OperandType.ShortInlineR:
                position += 4;
                break;
 
            case OperandType.InlineSwitch:
            {
                // A switch operand is a 4-byte target count followed by
                // that many 4-byte branch offsets.
                int targetCount = ReadInt32(il, ref position);
                position += targetCount * 4;
                break;
            }
 
            case OperandType.InlineVar:
                position += 2;
                break;
 
            case OperandType.ShortInlineBrTarget:
            case OperandType.ShortInlineI:
            case OperandType.ShortInlineVar:
                position++;
                break;
 
            default:
                throw new InvalidOperationException("Unknown operand type.");
        }
        return position;
    }
 
    private byte[] GetMethodBodyIlByteArray()
    {
        MethodBody methodBody = arg.Method.GetMethodBody();
 
        if (methodBody == null)
        {
            throw new InvalidOperationException("Unable to read the IL of the delegate's method body.");
        }
 
        return methodBody.GetILAsByteArray();
    }
 
    private static int ReadInt32(byte[] il, ref int position)
    {
        return ((il[position++] | (il[position++] << 8)) | (il[position++] << 0x10)) | (il[position++] << 0x18);
    }
 
    private static Type[] GetSubclassGenericTypes(Type toCheck)
    {
        var genericArgumentsTypes = new List<Type>();
 
        while (toCheck != null)
        {
            if (toCheck.IsGenericType)
            {
                genericArgumentsTypes.AddRange(toCheck.GetGenericArguments());
            }
 
            toCheck = toCheck.BaseType;
        }
 
        return genericArgumentsTypes.ToArray();
    }
 
    // Built once by the static initializer; filling a local table before it is
    // published means no thread can observe a half-filled array.
    private static readonly OpCode[] singleByteOpCodes = LoadOpCodes();
 
    public static OpCode[] SingleByteOpCodes
    {
        get { return singleByteOpCodes; }
    }
 
    private static OpCode[] LoadOpCodes()
    {
        var table = new OpCode[0x100];
 
        foreach (FieldInfo info in typeof(OpCodes).GetFields())
        {
            if (info.FieldType == typeof(OpCode))
            {
                var singleByteOpCode = (OpCode)info.GetValue(null);
 
                var singleByteOpcodeIndex = (ushort)singleByteOpCode.Value;
 
                // Two-byte opcodes have values of 0xFE00 and above, so only
                // single-byte opcodes are stored here.
                if (singleByteOpcodeIndex < 0x100)
                {
                    table[singleByteOpcodeIndex] = singleByteOpCode;
                }
            }
        }
 
        return table;
    }
}

I cannot take full credit for the above code, as it is based on some code I found trawling the web, which I have stripped down to do what I want. It may look overly complicated, but all FieldInfoReader does is walk the IL byte array of the Func<> delegate’s method body one instruction at a time until it reaches the first instruction with a field operand, then resolve that field token to the FieldInfo whose name is the parameter’s name.
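The field being resolved is one the C# compiler generates when a lambda captures a parameter or local: the variable is hoisted into a field of a hidden closure class, and that field carries the variable’s name. The sketch below uses a hypothetical helper, ClosureInspector, to list those fields directly with reflection, which is a quick way to see what the IL walk is looking for.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

public static class ClosureInspector
{
    // Hypothetical helper: lists the instance fields of the object a lambda
    // closes over. Captured parameters and locals appear as fields named
    // after the variable.
    public static List<string> CapturedFieldNames<T>(Func<T> lambda)
    {
        var names = new List<string>();

        // Target can be null for a lambda that captures nothing.
        if (lambda.Target == null)
        {
            return names;
        }

        const BindingFlags flags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;

        foreach (FieldInfo field in lambda.Target.GetType().GetFields(flags))
        {
            names.Add(field.Name);
        }

        return names;
    }
}
```

Reflection alone is not enough in general: lambdas in the same scope share one closure object, so it can hold several fields. Reading the token of the ldfld instruction, as the IL walk does, identifies exactly which field this particular lambda reads.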

Plugging it into the fluent interface

Now that we know how to get the parameter name, we need to plug it all together. As I said before, I prefer the static class approach with a fluent interface. The first step is to specify what it is we want to validate and make it clear to anyone reading the code what is under test.

using System;
using System.Diagnostics;
using System.Reflection;
 
[DebuggerStepThrough]
public static class Validate
{
    public static Argument<TParameter> Argument<TParameter>(Func<TParameter> arg)
    {
        if (arg == null)
        {
            throw new ArgumentNullException("arg");
        }
 
        var test = new FieldInfoReader<TParameter>(arg);
        
        FieldInfo fieldInfo = test.GetFieldToken();
 
        if (fieldInfo == null)
        {
            throw new ValidationException("No field info found in delegate");    
        }
 
        return new Argument<TParameter>(fieldInfo.Name, arg());
    }
}
 
[DebuggerStepThrough]
public class Argument<TParameterType>
{
    internal Argument(string parameterName, TParameterType parameter)
    {
        ParameterName = parameterName;
        ParameterValue = parameter;
    }
 
    internal string ParameterName { get; private set; }
 
    internal TParameterType ParameterValue { get; private set; }
}

The Validate.Argument method takes a delegate that returns the parameter, extracts the parameter’s name, and returns an Argument object containing both the parameter’s name and its value. The Argument object is the key to the validation process: extension methods for the various types of validation are defined on it, which gives us our fluent interface. An example is below.

[DebuggerStepThrough]
public static class ArgumentValidationExtensions
{
    public static ReferenceTypeArgument<TArgumentType> IsNotNull<TArgumentType>(this Argument<TArgumentType> argument) where TArgumentType : class
    {
        if (argument.ParameterValue == null)
        {
            throw new ArgumentNullException(argument.ParameterName);
        }
 
        return new ReferenceTypeArgument<TArgumentType>(argument);
    }
 
    public static ReferenceTypeArgument<string> IsNotEmpty(this ReferenceTypeArgument<string> argument)
    {
        if (argument.ParameterValue.Length == 0)
        {
            throw new ArgumentException("Parameter cannot be an empty string.", argument.ParameterName);
        }
 
        return argument;
    }
}
 
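public class ReferenceTypeArgument<TArgumentType> : Argument<TArgumentType> where TArgumentType : class
{
    // Wraps an argument that has already passed the null check, so that
    // reference-type-specific validations can only be chained after IsNotNull.
    internal ReferenceTypeArgument(Argument<TArgumentType> argument)
        : base(argument.ParameterName, argument.ParameterValue)
    {
    }
}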

The IsNotNull extension method takes an Argument whose parameter is a reference type and checks it for null. If the value is null, an ArgumentNullException is thrown; otherwise a ReferenceTypeArgument is returned. Because the string-specific IsNotEmpty method only accepts a ReferenceTypeArgument<string>, it can only be called after the null check, which means we can now do the following.

public void SomeMethod(string value)
{
    Validate.Argument(() => value).IsNotNull().IsNotEmpty();
}
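Following the same pattern, a range check for value types, such as the maxValue check from the first example, could be added as another extension method. To keep the sketch self-contained it uses a simplified, hypothetical NamedArgument<T> in place of Argument<T> and takes the parameter name explicitly rather than reading it from IL; IsGreaterThan is likewise a hypothetical name, not part of the post’s code.

```csharp
using System;

// Simplified, hypothetical stand-in for Argument<T>: the name is passed in
// directly instead of being read from the delegate's IL.
public class NamedArgument<T>
{
    public NamedArgument(string parameterName, T parameterValue)
    {
        ParameterName = parameterName;
        ParameterValue = parameterValue;
    }

    public string ParameterName { get; private set; }

    public T ParameterValue { get; private set; }
}

public static class ComparableArgumentExtensions
{
    // Hypothetical range check for value types; returns the same argument so
    // further checks can be chained, in the same style as IsNotEmpty.
    public static NamedArgument<T> IsGreaterThan<T>(this NamedArgument<T> argument, T lowerBound)
        where T : struct, IComparable<T>
    {
        if (argument.ParameterValue.CompareTo(lowerBound) <= 0)
        {
            throw new ArgumentOutOfRangeException(
                argument.ParameterName,
                argument.ParameterValue,
                string.Format("{0} must be greater than {1}.", argument.ParameterName, lowerBound));
        }

        return argument;
    }
}
```

Written against Argument<T> instead, the call site would presumably read Validate.Argument(() => maxValue).IsGreaterThan(0).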

Putting on the icing

Using the above pattern, the number of validation methods you can write is limited only by your imagination, but do you really want to create validation methods for obscure checks that are only going to be done in one or two places? It would be better to provide methods for the most common checks and give the consumer of the validation API a way to supply their own logic: a function that is run against the parameter and indicates whether the argument is valid.

public static Argument<TArgumentType> Satisfies<TArgumentType>(this Argument<TArgumentType> argument, Expression<Func<TArgumentType, bool>> expression)
{
    argument.ApplyValidation(
        expression.Compile(),
        () => string.Format("The parameter '{0}' failed the following validation '{1}'", argument.ParameterName, expression));
 
    return argument;
}
 
public static Argument<TArgumentType> Satisfies<TArgumentType>(this Argument<TArgumentType> argument, Func<TArgumentType, bool> function, string message)
{
    argument.ApplyValidation(function, () => message);
    
    return argument;
}
 
public static Argument<TArgumentType> Satisfies<TArgumentType>(this Argument<TArgumentType> argument, Func<TArgumentType, bool> function, string messageFormat, params object[] messageParameters)
{
    argument.ApplyValidation(function, () => string.Format(CultureInfo.InvariantCulture, messageFormat, messageParameters));
 
    return argument;
}
 
private static void ApplyValidation<TArgumentType>(this Argument<TArgumentType> argument, Func<TArgumentType, bool> testFunction, Func<string> messageFunction)
{
    if (!testFunction.Invoke(argument.ParameterValue))
    {
        throw new ArgumentException(messageFunction.Invoke(), argument.ParameterName);
    }
}

Here we provide two ways for a consumer to validate an argument: one overload takes an expression, and the others take a function plus a custom error message (either a plain string or a format string with parameters). The beauty of the expression is that it is self-documenting: when you output the expression x => x.CanDoSomething(), that is exactly what you get, so the message of the resulting ArgumentException contains the expression itself. The following code would produce something like:

public void SomeMethod(Stream streamA, Stream streamB)
{
    Validate.Argument(() => streamA).Satisfies(stream => stream.CanRead);
    Validate.Argument(() => streamB).Satisfies(stream => stream.CanWrite, "Cannot write to the stream.");
}
 
ArgumentException -> The parameter 'streamA' failed the following validation 'stream => stream.CanRead'. Parameter 'streamA'
ArgumentException -> Cannot write to the stream. Parameter 'streamB'

Both of these are acceptable but you may prefer one over the other depending on what you were trying to convey.
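The self-documenting message comes from the expression’s own ToString(). A standalone sketch of the expression-based check, with the parameter name passed in explicitly rather than read from IL (PredicateCheck is a hypothetical name), shows what ends up in the message:

```csharp
using System;
using System.Linq.Expressions;

public static class PredicateCheck
{
    // Hypothetical standalone version of the expression-based Satisfies overload:
    // the expression's ToString() becomes part of the error message.
    public static void Satisfies<T>(T value, string parameterName, Expression<Func<T, bool>> predicate)
    {
        // Compiled on every call, just like the overload in the post.
        if (!predicate.Compile()(value))
        {
            throw new ArgumentException(
                string.Format("The parameter '{0}' failed the following validation '{1}'", parameterName, predicate),
                parameterName);
        }
    }
}
```

Because the expression overload compiles the expression each time it runs, the function-plus-message overloads are the cheaper choice on hot paths.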

Happy validating.

2 comments:

  1. Hi Matt,

    Have you heard of code contracts designed by Microsoft Research: http://research.microsoft.com/en-us/projects/contracts/

    This essentially allows static code analysis of parameters against a pre-determined contract and also supports runtime code analysis, which makes the guard clauses redundant.

    You might also want to look at a simple Guard clause instead of a condition as this will make testing easier with better code coverage.

    Simon

  2. Hi Simon,

    Yes, I heard of Code Contracts after the event, too late to introduce them into our .NET 3.5 projects. We may consider using them in any .NET 4 projects.

    You may not have read my whole post, but this API is a response to the repetitive and unreadable guard clause; that is exactly what you want to get away from. Public API methods are always littered with "if null, throw ArgumentNullException; if that, throw ArgumentException". It is much easier to have a static helper with extension methods to help you do fluent validation.

    Surely it is much nicer to read:

    Validate.Argument(() => xyz)
    .IsNotNull()
    .Satisfies(param => param.ShouldHaveSomeValue);

    Than:

    if (xyz == null)
    {
        throw new ArgumentNullException("xyz", "Parameter xyz cannot be null");
    }

    if (!xyz.ShouldHaveSomeValue)
    {
        throw new ArgumentException("Parameter xyz should have some value", "xyz");
    }

    With regard to unit testing and code coverage, nothing changes, as the methods still throw ArgumentNullException, ArgumentException and ArgumentOutOfRangeException.

    Matt
