diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala index 888aad8..f767895 100644 --- a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala +++ b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvParser.scala @@ -1,17 +1,11 @@ package ceesvee.zio -import _root_.zio.Cause import _root_.zio.Chunk -import _root_.zio.NonEmptyChunk import _root_.zio.Ref -import _root_.zio.Scope import _root_.zio.Trace import _root_.zio.ZIO import _root_.zio.stream.ZPipeline -import _root_.zio.stream.ZSink -import _root_.zio.stream.ZStream import ceesvee.CsvParser -import ceesvee.CsvReader object ZioCsvParser { import CsvParser.Error @@ -20,65 +14,6 @@ object ZioCsvParser { import CsvParser.parseLine import CsvParser.splitStrings - /** - * Turns a stream of strings into a stream of CSV records extracting the first - * record. - */ - def parseWithHeader[R, E]( - stream: ZStream[R, E, String], - options: CsvReader.Options, - )(implicit - trace: Trace, - ): ZIO[Scope & R, Either[E, Error], (Chunk[String], ZStream[Any, Either[E, Error], Chunk[String]])] = { - stream.mapError(Left(_)).peel { - extractFirstLine(options).mapError(Right(_)) - }.map { case ((headers, state, records), s) => - (headers, ZStream.fromChunk(records) ++ (s >>> _parse(state, options).mapError(Right(_)))) - } - } - - private def extractFirstLine(options: CsvReader.Options)(implicit trace: Trace) = { - - val initial: Chunk[Chunk[String]] = Chunk.empty - - @SuppressWarnings(Array("org.wartremover.warts.IterableOps")) - def done(state: State, records: Chunk[Chunk[String]]) = { - NonEmptyChunk.fromChunk(records).map { rs => - Push.emit((rs.head, state, rs.tail), Chunk.empty) - } - } - - val push = Ref.make((State.initial, initial)).map { stateRef => (chunk: Option[Chunk[String]]) => - chunk match { - case None => stateRef.get.flatMap { case (state, lines) => - done(state, lines).getOrElse(Push.emit((Chunk.empty, state, lines), Chunk.empty)) - } - - case Some(strings) => - stateRef.get.flatMap { case (state, records) => - if (state.leftover.length > options.maximumLineLength) { - Push.fail(Error.LineTooLong(options.maximumLineLength), Chunk.empty) - } else { - val (newState, lines) = splitStrings(strings, state) - val moreRecords = lines.filter(str => !ignoreLine(str, options)).map(parseLine[Chunk](_, options)) - val _records = records ++ moreRecords - done(newState, _records).getOrElse(stateRef.set((newState, _records)) *> Push.more) - } - } - } - } - - ZSink.fromPush(push) - } - - private object Push { - val more: ZIO[Any, Nothing, Unit] = ZIO.unit - def emit[I, Z](z: Z, leftover: Chunk[I]): ZIO[Any, (Right[Nothing, Z], Chunk[I]), Nothing] = - ZIO.refailCause(Cause.fail((Right(z), leftover))) - def fail[I, E](e: E, leftover: Chunk[I]): ZIO[Any, (Left[E, Nothing], Chunk[I]), Nothing] = - ZIO.fail((Left(e), leftover)) - } - /** * Turns a stream of strings into a stream of CSV records. */ diff --git a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala index 5beaeca..908caeb 100644 --- a/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala +++ b/modules/zio/src/main/scala/ceesvee/zio/ZioCsvReader.scala @@ -1,9 +1,7 @@ package ceesvee.zio import _root_.zio.Cause -import _root_.zio.Scope import _root_.zio.Trace -import _root_.zio.ZIO import _root_.zio.stream.ZPipeline import _root_.zio.stream.ZStream import ceesvee.CsvHeader @@ -11,31 +9,54 @@ import ceesvee.CsvParser import ceesvee.CsvReader import ceesvee.CsvRecordDecoder +import scala.util.control.NoStackTrace + object ZioCsvReader { - import CsvParser.Error + + // replace with a union instead of redefining when removing Scala 2 support + sealed trait Error + object Error { + final case class LineTooLong(maximum: Int) + extends RuntimeException(s"CSV line exceeded maximum length of ${maximum.toString}") + with Error + + final case class MissingHeaders(missing: ::[String]) + extends RuntimeException(s"Missing headers: ${missing.mkString(", ")}") + with NoStackTrace + with Error + } /** * Turns a stream of strings into a stream of decoded CSV records. * * CSV lines are reordered based on the given headers. */ + @SuppressWarnings(Array("org.wartremover.warts.Null", "org.wartremover.warts.TryPartial", "org.wartremover.warts.Var")) def decodeWithHeader[R, E, T]( stream: ZStream[R, E, String], header: CsvHeader[T], options: CsvReader.Options, )(implicit trace: Trace, - ): ZIO[Scope & R, Either[Either[E, Error], CsvHeader.MissingHeaders], ZStream[R, Either[E, Error], Either[CsvHeader.Errors, T]]] = { - for { - tuple <- ZioCsvParser.parseWithHeader(stream, options).mapError(Left(_)) - (headerFields, s) = tuple - decoder <- header.create(headerFields) match { - case Left(error) => ZIO.refailCause(Cause.fail(error)).mapError(Right(_)) - case Right(decoder) => ZIO.succeed(decoder) + ): ZStream[R, Either[E, Error], Either[CsvHeader.Errors, T]] = ZStream.suspend { + var decoder: CsvHeader.Decoder[T] = null + + stream.mapError(Left(_)).via { + ZioCsvParser.parse(options).mapError { + case CsvParser.Error.LineTooLong(maximum) => Right(Error.LineTooLong(maximum)) + } + }.map { fields => + if (decoder eq null) { + decoder = header.create(fields).left.map { + case CsvHeader.MissingHeaders(missing) => Error.MissingHeaders(missing) + }.toTry.get + null + } else { + decoder.decode(fields) } - } yield { - s.map(decoder.decode(_)) - } + }.filter(_ ne null).catchSomeCause { + case Cause.Die(e: Error, _) => ZStream.fail(Right(e)) + } ++ (if (decoder eq null) ZStream.fail(Right(Error.MissingHeaders(header.headers))) else ZStream.empty) } /** @@ -46,7 +67,7 @@ object ZioCsvReader { )(implicit D: CsvRecordDecoder[T], trace: Trace, - ): ZPipeline[Any, Error, String, Either[CsvRecordDecoder.Errors, T]] = { + ): ZPipeline[Any, CsvParser.Error, String, Either[CsvRecordDecoder.Errors, T]] = { ZioCsvParser.parse(options) >>> ZPipeline.map(D.decode(_)) } } diff --git a/modules/zio/src/test/scala/ceesvee/zio/CsvParserSpec.scala b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserSpec.scala similarity index 77% rename from modules/zio/src/test/scala/ceesvee/zio/CsvParserSpec.scala rename to modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserSpec.scala index 63794ce..0dddc4a 100644 --- a/modules/zio/src/test/scala/ceesvee/zio/CsvParserSpec.scala +++ b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvParserSpec.scala @@ -4,9 +4,9 @@ import ceesvee.CsvParser import zio.stream.ZStream import zio.test.ZIOSpecDefault -object CsvParserSpec extends ZIOSpecDefault with ceesvee.CsvParserParserSuite { +object ZioCsvParserSpec extends ZIOSpecDefault with ceesvee.CsvParserParserSuite { - override val spec = suite("CsvParser")( + override val spec = suite("ZioCsvParser")( parserSuite, ) diff --git a/modules/zio/src/test/scala/ceesvee/zio/ZioCsvReaderSpec.scala b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvReaderSpec.scala new file mode 100644 index 0000000..aced014 --- /dev/null +++ b/modules/zio/src/test/scala/ceesvee/zio/ZioCsvReaderSpec.scala @@ -0,0 +1,66 @@ +package ceesvee.zio + +import ceesvee.CsvHeader +import ceesvee.CsvReader +import ceesvee.CsvRecordDecoder +import zio.Chunk +import zio.stream.ZStream +import zio.test.ZIOSpecDefault +import zio.test.assertTrue + +object ZioCsvReaderSpec extends ZIOSpecDefault { + + private val options = CsvReader.Options.Defaults + + private val decodeWithHeaderSuite = suite("decode with header")( + test("no rows") { + ZioCsvReader.decodeWithHeader(ZStream.empty, Test.header, options).runCollect.either.map { result => + assertTrue(result == Left(Right(ZioCsvReader.Error.MissingHeaders(::("a", List("b", "c")))))) + } + }, + test("only header row") { + val stream = ZStream.succeed("a,b,c") + ZioCsvReader.decodeWithHeader(stream, Test.header, options).runCollect.either.map { result => + assertTrue(result == Right(Chunk.empty)) + } + }, + test("invalid header row") { + val stream = ZStream.succeed("a,b,d") + ZioCsvReader.decodeWithHeader(stream, Test.header, options).runCollect.either.map { result => + assertTrue(result == Left(Right(ZioCsvReader.Error.MissingHeaders(::("c", Nil))))) + } + }, + test("valid") { + val stream = ZStream.succeed("a,b,c\ns,1,true") + ZioCsvReader.decodeWithHeader(stream, Test.header, options).runCollect.map { result => + assertTrue(result == Chunk(Right(Test("s", 1, true)))) + } + }, + test("can be run multiple times") { + val stream = ZStream.succeed("a,b,c\ns,1,true") + val decode = ZioCsvReader.decodeWithHeader(stream, Test.header, options) + for { + result1 <- decode.runCollect + result2 <- decode.runCollect + } yield { + assertTrue(result1 == Chunk(Right(Test("s", 1, true)))) && + assertTrue(result1 == result2) + } + }, + ) + + override val spec = suite("ZioCsvReader")( + decodeWithHeaderSuite, + ) + + case class Test( + a: String, + b: Int, + c: Boolean, + ) + object Test { + implicit val decoder: CsvRecordDecoder[Test] = CsvRecordDecoder.derived + + val header: CsvHeader[Test] = CsvHeader.create(::("a", List("b", "c")))(decoder) + } +} diff --git a/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala b/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala index 709569b..75acc88 100644 --- a/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala +++ b/tests/src/test/scala/ceesvee/tests/RealWorldCsvSpec.scala @@ -69,13 +69,11 @@ object RealWorldCsvSpec extends ZIOSpecDefault { }, test("zio") { val stream = readFileZio(path) - ZIO.scoped[Any] { - ZioCsvReader.decodeWithHeader(stream, UkCausewayCoast.csvHeader, options).flatMap { s => - s.runCollect.mapError(Left(_)) + ZioCsvReader.decodeWithHeader(stream, UkCausewayCoast.csvHeader, options) + .runCollect + .map { result => + assertResult(result) } - }.map { result => - assertResult(result) - } }, ) }*), @@ -139,13 +137,12 @@ object RealWorldCsvSpec extends ZIOSpecDefault { }, test("zio") { val stream = readFileZio(path) - ZIO.scoped[Any] { - ZioCsvReader.decodeWithHeader(stream, header, options).flatMap { s => - s.collectRight.runCount.mapError(Left(_)) + ZioCsvReader.decodeWithHeader(stream, header, options) + .collectRight + .runCount + .map { count => + assertTrue(count == total) } - }.map { count => - assertTrue(count == total) - } }, ) }