diff --git a/ShittyLINQ/DefaultIfEmpty.cs b/ShittyLINQ/DefaultIfEmpty.cs new file mode 100644 index 0000000..f7fac4d --- /dev/null +++ b/ShittyLINQ/DefaultIfEmpty.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; + +namespace ShittyLINQ +{ + public static partial class Extensions + { + public static IEnumerable DefaultIfEmpty(this IEnumerable source) + { + return DefaultIfEmpty(source, default(T)); + } + + public static IEnumerable DefaultIfEmpty(this IEnumerable self, T defaultValue) + { + if (self == null) throw new ArgumentNullException(); + if (self.Count() <= 0) + { + yield return defaultValue; + } + else + { + var iterator = self.GetEnumerator(); + while (iterator.MoveNext()) + { + yield return iterator.Current; + } + } + } + } +} diff --git a/ShittyLinqTests/DefaultIfEmptyTests.cs b/ShittyLinqTests/DefaultIfEmptyTests.cs new file mode 100644 index 0000000..039356c --- /dev/null +++ b/ShittyLinqTests/DefaultIfEmptyTests.cs @@ -0,0 +1,40 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using ShittyLINQ; +using Enumerable = System.Linq.Enumerable; + +namespace ShittyTests +{ + [TestClass] + public class DefaultIfEmptyTests + { + [TestMethod] + public void Default_NotEmpty() + { + int[] source = { 1, 2, 3 }; + int[] expectedResult = { 1, 2, 3 }; + + var result = source.DefaultIfEmpty(0); + Assert.IsTrue(Enumerable.SequenceEqual(expectedResult, result)); + } + + [TestMethod] + public void Default_Empty() + { + int[] source = { }; + int[] expectedResult = { 0 }; + + var result = source.DefaultIfEmpty(0); + Assert.IsTrue(Enumerable.SequenceEqual(expectedResult, result)); + } + + [TestMethod] + public void Default_EnumerableNull() + { + int[] source = null; + int[] expectedResult = { 0 }; + + var result = source.DefaultIfEmpty(0); + Assert.IsTrue(Enumerable.SequenceEqual(expectedResult, result)); + } + } +}