Skip to content

Commit 7f76986

Browse files
committed
case insensitive header names in REST framework
1 parent 5660c9b commit 7f76986

File tree

4 files changed

+42
-18
lines changed

4 files changed

+42
-18
lines changed

commons-core/src/main/scala/com/avsystem/commons/meta/Mapping.scala

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.avsystem.commons
22
package meta
33

4+
import java.util.Locale
5+
46
import com.avsystem.commons.meta.Mapping.{ConcatIterable, KeyFilteredIterable}
57
import com.avsystem.commons.serialization.GenCodec
68

@@ -11,20 +13,25 @@ import scala.collection.{MapLike, mutable}
1113
* Simple immutable structure to collect named values while retaining their order and
1214
* providing fast, hashed lookup by name when necessary.
1315
* Intended to be used for [[multi]] raw parameters.
16+
* When `caseInsensitive = true`, fetching values by name will be case-insensitive, i.e. keys in internal
17+
* hashmap and those passed to `contains`, `isDefinedAt`, `apply` and `applyOrElse` will be lowercased.
1418
*/
15-
final class Mapping[+V](private val wrapped: IIterable[(String, V)])
19+
final class Mapping[+V](private val wrapped: IIterable[(String, V)], caseInsensitive: Boolean = false)
1620
extends IMap[String, V] with MapLike[String, V, Mapping[V]] {
1721

22+
private def normKey(key: String): String =
23+
if (caseInsensitive) key.toLowerCase(Locale.ENGLISH) else key
24+
1825
private[this] lazy val vector = {
1926
val keys = new mutable.HashSet[String]
20-
wrapped.iterator.filter({ case (k, _) => keys.add(k) }).toVector
27+
wrapped.iterator.filter({ case (k, _) => keys.add(normKey(k)) }).toVector
2128
}
2229
private[this] lazy val map =
23-
wrapped.toMap
30+
wrapped.iterator.map({ case (k, v) => (normKey(k), v) }).toMap
2431

2532
override def empty: Mapping[V] = Mapping.empty
2633
override protected[this] def newBuilder: mutable.Builder[(String, V), Mapping[V]] =
27-
Mapping.newBuilder[V]
34+
Mapping.newBuilder[V]()
2835

2936
override def size: Int =
3037
vector.size
@@ -33,17 +40,18 @@ final class Mapping[+V](private val wrapped: IIterable[(String, V)])
3340
override def valuesIterator: Iterator[V] =
3441
vector.iterator.map({ case (_, v) => v })
3542
override def keys: Iterable[String] =
36-
map.keys
43+
vector.map({ case (k, _) => k })
3744
override def keySet: ISet[String] =
38-
map.keySet
45+
vector.iterator.map({ case (k, _) => k }).toSet
3946
override def contains(key: String): Boolean =
40-
map.contains(key)
47+
map.contains(normKey(key))
4148
override def isDefinedAt(key: String): Boolean =
42-
map.isDefinedAt(key)
49+
map.isDefinedAt(normKey(key))
4350
override def applyOrElse[A1 <: String, B1 >: V](key: A1, default: A1 => B1): B1 =
44-
map.applyOrElse(key, default)
51+
if (!caseInsensitive) map.applyOrElse(key, default)
52+
else map.applyOrElse(normKey(key), (_: String) => default(key))
4553
override def apply(key: String): V =
46-
map.apply(key)
54+
map.apply(normKey(key))
4755

4856
def get(key: String): Option[V] = map.get(key)
4957
def -(key: String): Mapping[V] = Mapping(new KeyFilteredIterable(wrapped, _ == key))
@@ -65,10 +73,11 @@ final class Mapping[+V](private val wrapped: IIterable[(String, V)])
6573
object Mapping {
6674
def empty[V]: Mapping[V] = new Mapping(Nil)
6775
def apply[V](pairs: (String, V)*): Mapping[V] = new Mapping(pairs.toList)
68-
def apply[V](pairs: IIterable[(String, V)]): Mapping[V] = new Mapping(pairs)
76+
def apply[V](pairs: IIterable[(String, V)], caseInsensitive: Boolean = false): Mapping[V] =
77+
new Mapping(pairs, caseInsensitive)
6978

70-
def newBuilder[V]: mutable.Builder[(String, V), Mapping[V]] =
71-
new MListBuffer[(String, V)].mapResult(new Mapping(_))
79+
def newBuilder[V](caseInsensitive: Boolean = false): mutable.Builder[(String, V), Mapping[V]] =
80+
new MListBuffer[(String, V)].mapResult(new Mapping(_, caseInsensitive))
7281

7382
private class ConcatIterable[+V](first: IIterable[V], second: IIterable[V]) extends IIterable[V] {
7483
def iterator: Iterator[V] = first.iterator ++ second.iterator
@@ -82,8 +91,8 @@ object Mapping {
8291
}
8392

8493
private val reusableCBF = new CanBuildFrom[Nothing, (String, Any), Mapping[Any]] {
85-
def apply(from: Nothing): mutable.Builder[(String, Any), Mapping[Any]] = newBuilder[Any]
86-
def apply(): mutable.Builder[(String, Any), Mapping[Any]] = newBuilder[Any]
94+
def apply(from: Nothing): mutable.Builder[(String, Any), Mapping[Any]] = newBuilder[Any]()
95+
def apply(): mutable.Builder[(String, Any), Mapping[Any]] = newBuilder[Any]()
8796
}
8897

8998
implicit def canBuildFrom[V]: CanBuildFrom[Nothing, (String, V), Mapping[V]] =

commons-core/src/main/scala/com/avsystem/commons/rest/data.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ object QueryValue {
4949
}.mkString(FormKVPairSep)
5050

5151
def decode(queryString: String): Mapping[QueryValue] = {
52-
val builder = Mapping.newBuilder[QueryValue]
52+
val builder = Mapping.newBuilder[QueryValue]()
5353
queryString.split(FormKVPairSep).iterator.filter(_.nonEmpty).map(_.split(FormKVSep, 2)).foreach {
5454
case Array(name, value) => builder += UrlEncoding.decode(name) -> QueryValue(UrlEncoding.decode(value))
5555
case _ => throw new IllegalArgumentException(s"invalid query string $queryString")
@@ -166,7 +166,7 @@ object HttpBody {
166166
case HttpBody.Empty => Mapping.empty
167167
case _ =>
168168
val oi = new JsonStringInput(new JsonReader(body.readJson().value)).readObject()
169-
val builder = Mapping.newBuilder[JsonValue]
169+
val builder = Mapping.newBuilder[JsonValue]()
170170
while (oi.hasNext) {
171171
val fi = oi.nextField()
172172
builder += ((fi.fieldName, JsonValue(fi.readRawJson())))

commons-core/src/test/scala/com/avsystem/commons/rest/RawRestTest.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ trait UserApi {
3535
def autopost(bodyarg: String): Future[String]
3636
def singleBodyAutopost(@Body body: String): Future[String]
3737
@FormBody def formpost(@Query qarg: String, sarg: String, iarg: Int): Future[String]
38+
39+
def eatHeader(@Header("X-Stuff") stuff: String): Future[String]
3840
}
3941
object UserApi extends DefaultRestApiCompanion[UserApi]
4042

@@ -77,6 +79,7 @@ class RawRestTest extends FunSuite with ScalaFutures {
7779
def formpost(qarg: String, sarg: String, iarg: Int): Future[String] = Future.successful(s"$qarg-$sarg-$iarg")
7880
def fail: Future[Unit] = Future.failed(HttpErrorException(400, "zuo"))
7981
def failMore: Future[Unit] = throw HttpErrorException(400, "ZUO")
82+
def eatHeader(stuff: String): Future[String] = Future.successful(stuff.toLowerCase)
8083
}
8184

8285
var trafficLog: String = _
@@ -198,4 +201,16 @@ class RawRestTest extends FunSuite with ScalaFutures {
198201
serverHandle(request).apply(promise.complete)
199202
assert(promise.future.futureValue == response)
200203
}
204+
205+
test("header case insensitivity") {
206+
val params = RestParameters(
207+
path = List(PathValue("eatHeader")),
208+
headers = Mapping(List("x-sTuFf" -> HeaderValue("StUfF")), caseInsensitive = true)
209+
)
210+
val request = RestRequest(HttpMethod.POST, params, HttpBody.Empty)
211+
val response = RestResponse(200, Mapping.empty, HttpBody.json(JsonValue("\"stuff\"")))
212+
val promise = Promise[RestResponse]
213+
serverHandle(request).apply(promise.complete)
214+
assert(promise.future.futureValue == response)
215+
}
201216
}

commons-jetty/src/main/scala/com/avsystem/commons/jetty/rest/RestServlet.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ object RestServlet {
3030
val pathPrefix = request.getContextPath.orEmpty + request.getServletPath.orEmpty
3131
val path = PathValue.splitDecode(request.getRequestURI.stripPrefix(pathPrefix))
3232
val query = request.getQueryString.opt.map(QueryValue.decode).getOrElse(Mapping.empty)
33-
val headersBuilder = Mapping.newBuilder[HeaderValue]
33+
val headersBuilder = Mapping.newBuilder[HeaderValue](caseInsensitive = true)
3434
request.getHeaderNames.asScala.foreach { headerName =>
3535
headersBuilder += headerName -> HeaderValue(request.getHeader(headerName))
3636
}

0 commit comments

Comments
 (0)