diff --git a/modules/core/src/main/scala/ceesvee/CsvParser.scala b/modules/core/src/main/scala/ceesvee/CsvParser.scala index 0ac1f39..90a6ac9 100644 --- a/modules/core/src/main/scala/ceesvee/CsvParser.scala +++ b/modules/core/src/main/scala/ceesvee/CsvParser.scala @@ -139,12 +139,14 @@ object CsvParser { case class State( leftover: String, - insideQuote: Boolean, + insideQuoteIndex: Int, + previousCarriageReturn: Boolean, ) object State { val initial: State = State( leftover = "", - insideQuote = false, + insideQuoteIndex = -1, + previousCarriageReturn = false, ) } @@ -158,7 +160,8 @@ object CsvParser { state: State, )(implicit f: Factory[String, C[String]]): (State, C[String]) = { val builder = f.newBuilder - var insideQuote = state.insideQuote + var insideQuoteIndex = state.insideQuoteIndex + var previousCarriageReturn = state.previousCarriageReturn var leftover = state.leftover val it = strings.iterator @@ -168,40 +171,60 @@ object CsvParser { val concat = leftover.concat(string) - // assume we have already processed `leftover`, - // reprocess the last character in case it was a '"' or '\r' - var i = (leftover.length - 1).max(0) + var insideQuote = false + var i = + if (insideQuoteIndex >= 0) insideQuoteIndex + else if (previousCarriageReturn) leftover.length - 1 + else leftover.length var sliceStart = 0 while (i < concat.length) { (concat(i): @switch) match { case '"' => - if (insideQuote && (i + 1) < concat.length && concat(i + 1) == '"') { // escaped quote - i += 2 + if (insideQuote) { + if ((i + 1) == concat.length) { // last char + i += 1 // not enough information + } else { + if (concat(i + 1) == '"') { // escaped quote + i += 2 + } else { + insideQuote = false + insideQuoteIndex = -1 + i += 1 + } + } } else { + insideQuote = true + insideQuoteIndex = i i += 1 - if (i < concat.length) { - insideQuote = !insideQuote - } } case '\n' => - if (!insideQuote) { - val _ = builder += concat.substring(sliceStart, i) + if (insideQuote) { i += 1 - sliceStart = i } else { + val sliceEnd = if (previousCarriageReturn) i - 1 else i + val _ = builder += concat.substring(sliceStart, sliceEnd) i += 1 + sliceStart = i } case '\r' => - if (!insideQuote && (i + 1) < concat.length && concat(i + 1) == '\n') { - val _ = builder += concat.substring(sliceStart, i) - i += 2 - sliceStart = i - } else { + if (insideQuote) { i += 1 + } else { + if ((i + 1) == concat.length) { // last char + i += 1 // previousCarriageReturn set later + } else { + if (concat(i + 1) == '\n') { + val _ = builder += concat.substring(sliceStart, i) + i += 2 + sliceStart = i + } else { + i += 1 + } + } } case _ => @@ -209,11 +232,13 @@ object CsvParser { } } + insideQuoteIndex = insideQuoteIndex - sliceStart + previousCarriageReturn = concat(i - 1) == '\r' leftover = concat.substring(sliceStart, concat.length) } } - (State(leftover, insideQuote = insideQuote), builder.result()) + (State(leftover, insideQuoteIndex = insideQuoteIndex, previousCarriageReturn = previousCarriageReturn), builder.result()) } /** diff --git a/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala b/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala index 82ba83d..39c0d76 100644 --- a/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala +++ b/modules/core/src/test/scala/ceesvee/CsvParserSpec.scala @@ -22,6 +22,40 @@ object CsvParserSpec extends ZIOSpecDefault with CsvParserParserSuite { assertTrue(lines == List("abc\rdef", "ghi", "jkl")) && assertTrue(state.leftover == "mno") }, + test("trailing double quotes") { + val strings = List( + "a,\"b\"", + ",c,\"d\"\"e\",\"", + "\"", + "\nfg\"", + ) + val (state, lines) = CsvParser.splitStrings(strings, CsvParser.State.initial) + val strings2 = List( + "\n\"\"\"", + "\n\"hi\"\"", + ) + val (state2, lines2) = CsvParser.splitStrings(strings2, state) + val strings3 = List( + "j\"", + "\nmno", + ) + val (state3, lines3) = CsvParser.splitStrings(strings3, state2) + assertTrue( + lines == List("""a,"b",c,"d""e","""""), + state.insideQuoteIndex == 2, + state.leftover == "fg\"", + ) && + assertTrue( + lines2 == List("fg\"\n\"\"\""), + state2.insideQuoteIndex == 0, + state2.leftover == "\"hi\"\"", + ) && + assertTrue( + lines3 == List("\"hi\"\"j\""), + state3.insideQuoteIndex == -9, + state3.leftover == "mno", + ) + }, test("quotes and new lines") { val strings = List( "a\"b\"c\n", diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala index 1dedd2a..888aad8 100644 --- a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala +++ b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala @@ -112,12 +112,12 @@ object ZioCsvParser { Ref.make(state).map { stateRef => (chunk: Option[Chunk[String]]) => chunk match { case None => - stateRef.getAndSet(State.initial).map { case State(leftover, _) => + stateRef.getAndSet(State.initial).map { case State(leftover, _, _) => if (leftover.isEmpty) Chunk.empty else Chunk(leftover) } case Some(strings) => - stateRef.get.flatMap { case State(leftover, _) => + stateRef.get.flatMap { case State(leftover, _, _) => ZIO.fail(Error.LineTooLong(options.maximumLineLength)) .when(leftover.length > options.maximumLineLength) } *> stateRef.modify(splitStrings(strings, _).swap)