Skip to content

Commit c8a3ca8

Browse files
Refactor SjsonnetMainBase to allow for more customization
1 parent f8bde9f commit c8a3ca8

File tree

1 file changed

+105
-105
lines changed

1 file changed

+105
-105
lines changed

sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala

Lines changed: 105 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package sjsonnet
22

3+
import upickle.core.SimpleVisitor
4+
35
import java.io.{
46
BufferedOutputStream,
57
InputStream,
@@ -14,53 +16,53 @@ import scala.annotation.unused
1416
import scala.util.Try
1517

1618
object SjsonnetMainBase {
17-
def resolveImport(
19+
class SimpleImporter(
1820
searchRoots0: Seq[Path], // Evaluated in order, first occurrence wins
1921
allowedInputs: Option[Set[os.Path]] = None,
20-
debugImporter: Boolean = false): Importer =
21-
new Importer {
22-
def resolve(docBase: Path, importName: String): Option[Path] =
23-
(docBase +: searchRoots0)
24-
.flatMap(base =>
25-
os.FilePath(importName) match {
26-
case r: os.SubPath => Some(base.asInstanceOf[OsPath].p / r)
27-
case r: os.RelPath =>
28-
if (r.ups > base.segmentCount()) None
29-
else Some(base.asInstanceOf[OsPath].p / r)
30-
case a: os.Path => Some(a)
31-
}
32-
)
33-
.filter(p => {
34-
val allowed = allowedInputs.fold(true)(_(p))
35-
if (debugImporter) {
36-
if (allowed) System.err.println(s"[import $importName] candidate $p")
37-
else
38-
System.err.println(
39-
s"[import $importName] excluded $p because it's not in $allowedInputs"
40-
)
41-
}
42-
allowed
43-
})
44-
.find(f => os.exists(f) && !os.isDir(f))
45-
.orElse({
46-
if (debugImporter) {
47-
System.err.println(s"[import $importName] none of the candidates exist")
48-
}
49-
None
50-
})
51-
.flatMap(p => {
52-
if (debugImporter) {
22+
debugImporter: Boolean = false)
23+
extends Importer {
24+
def resolve(docBase: Path, importName: String): Option[Path] =
25+
(docBase +: searchRoots0)
26+
.flatMap(base =>
27+
os.FilePath(importName) match {
28+
case r: os.SubPath => Some(base.asInstanceOf[OsPath].p / r)
29+
case r: os.RelPath =>
30+
if (r.ups > base.segmentCount()) None
31+
else Some(base.asInstanceOf[OsPath].p / r)
32+
case a: os.Path => Some(a)
33+
}
34+
)
35+
.filter(p => {
36+
val allowed = allowedInputs.fold(true)(_(p))
37+
if (debugImporter) {
38+
if (allowed) System.err.println(s"[import $importName] candidate $p")
39+
else
5340
System.err.println(
54-
s"[import $importName] $p is selected as it exists and is not a directory"
41+
s"[import $importName] excluded $p because it's not in $allowedInputs"
5542
)
56-
}
57-
Some(OsPath(p))
58-
})
43+
}
44+
allowed
45+
})
46+
.find(f => os.exists(f) && !os.isDir(f))
47+
.orElse({
48+
if (debugImporter) {
49+
System.err.println(s"[import $importName] none of the candidates exist")
50+
}
51+
None
52+
})
53+
.flatMap(p => {
54+
if (debugImporter) {
55+
System.err.println(
56+
s"[import $importName] $p is selected as it exists and is not a directory"
57+
)
58+
}
59+
Some(OsPath(p))
60+
})
5961

60-
def read(path: Path, binaryData: Boolean): Option[ResolvedFile] = {
61-
readPath(path, binaryData, debugImporter)
62-
}
62+
def read(path: Path, binaryData: Boolean): Option[ResolvedFile] = {
63+
readPath(path, binaryData, debugImporter)
6364
}
65+
}
6466

6567
def main0(
6668
args: Array[String],
@@ -70,7 +72,7 @@ object SjsonnetMainBase {
7072
stderr: PrintStream,
7173
wd: os.Path,
7274
allowedInputs: Option[Set[os.Path]] = None,
73-
importer: Option[(Path, String) => Option[os.Path]] = None,
75+
importer: Option[Importer] = None,
7476
std: Val.Obj = sjsonnet.stdlib.StdLibModule.Default.module): Int = {
7577

7678
var hasWarnings = false
@@ -94,7 +96,28 @@ object SjsonnetMainBase {
9496
autoPrintHelpAndExit = None
9597
)
9698
file <- Right(config.file)
97-
outputStr <- mainConfigured(file, config, parseCache, wd, allowedInputs, importer, warn, std)
99+
outputStr <- mainConfigured(
100+
file,
101+
config,
102+
new Settings(
103+
preserveOrder = config.preserveOrder.value,
104+
strict = config.strict.value,
105+
throwErrorForInvalidSets = config.throwErrorForInvalidSets.value,
106+
maxParserRecursionDepth = config.maxParserRecursionDepth,
107+
brokenAssertionLogic = config.brokenAssertionLogic.value
108+
),
109+
parseCache,
110+
wd,
111+
importer.getOrElse {
112+
new SimpleImporter(
113+
config.getOrderedJpaths.map(p => OsPath(os.Path(p, wd))),
114+
allowedInputs,
115+
debugImporter = config.debugImporter.value
116+
)
117+
},
118+
warn,
119+
std
120+
)
98121
res <- {
99122
if (hasWarnings && config.fatalWarnings.value) Left("")
100123
else Right(outputStr)
@@ -112,7 +135,6 @@ object SjsonnetMainBase {
112135
case Some(f) => os.write.over(os.Path(f, wd), str)
113136
}
114137
}
115-
116138
0
117139
}
118140
}
@@ -124,6 +146,14 @@ object SjsonnetMainBase {
124146
indent = config.indent,
125147
getCurrentPosition = getCurrentPosition
126148
)
149+
else if (config.expectString.value)
150+
new SimpleVisitor[Writer, Writer] {
151+
val expectedMsg = "expected string result"
152+
override def visitString(s: CharSequence, index: Int): Writer = {
153+
wr.write(s.toString)
154+
wr
155+
}
156+
}
127157
else new Renderer(wr, indent = config.indent)
128158

129159
private def handleWriteFile[T](f: => T): Either[String, T] =
@@ -141,7 +171,6 @@ object SjsonnetMainBase {
141171
case None =>
142172
val sw = new StringWriter
143173
materialize(sw).map(_ => sw.toString)
144-
145174
case Some(f) =>
146175
handleWriteFile(
147176
os.write.over.outputStream(os.Path(f, wd), createFolders = config.createDirs.value)
@@ -157,48 +186,34 @@ object SjsonnetMainBase {
157186
}
158187
}
159188

160-
private def expectString(v: ujson.Value) = v match {
161-
case ujson.Str(s) => Right(s)
162-
case _ => Left("expected string result, got: " + v.getClass)
163-
}
164-
165189
private def renderNormal(
166190
config: Config,
167191
interp: Interpreter,
168192
jsonnetCode: String,
169193
path: os.Path,
170194
wd: os.Path,
171195
getCurrentPosition: () => Position) = {
172-
writeToFile(config, wd)(writer =>
173-
if (config.expectString.value) {
174-
val res = interp.interpret(jsonnetCode, OsPath(path)).flatMap(expectString)
175-
res match {
176-
case Right(s) => writer.write(s)
177-
case _ =>
178-
}
179-
res
180-
} else {
181-
val renderer = rendererForConfig(writer, config, getCurrentPosition)
182-
val res = interp.interpret0(jsonnetCode, OsPath(path), renderer)
183-
if (config.yamlOut.value) writer.write('\n')
184-
res
185-
}
186-
)
196+
writeToFile(config, wd) { writer =>
197+
val renderer = rendererForConfig(writer, config, getCurrentPosition)
198+
val res = interp.interpret0(jsonnetCode, OsPath(path), renderer)
199+
if (config.yamlOut.value) writer.write('\n')
200+
res
201+
}
187202
}
188203

189204
private def isScalar(v: ujson.Value) = !v.isInstanceOf[ujson.Arr] && !v.isInstanceOf[ujson.Obj]
190205

191-
private def parseBindings(
206+
def parseBindings(
192207
strs: Seq[String],
193208
strFiles: Seq[String],
194209
codes: Seq[String],
195210
codeFiles: Seq[String],
196-
wd: os.Path) = {
211+
wd: os.Path): Map[String, String] = {
197212

198213
def split(s: String) = s.split("=", 2) match {
199214
case Array(x) => (x, System.getenv(x))
200215
case Array(x, v) => (x, v)
201-
case _ => ???
216+
case _ => throw new IllegalArgumentException("invalid binding: " + s)
202217
}
203218

204219
def splitMap(s: Seq[String], f: String => String) =
@@ -223,15 +238,16 @@ object SjsonnetMainBase {
223238
* Right(str) if there's some string that needs to be printed to stdout or --output-file,
224239
* Left(err) if there is an error to be reported
225240
*/
226-
private def mainConfigured(
241+
def mainConfigured(
227242
file: String,
228243
config: Config,
244+
settings: Settings,
229245
parseCache: ParseCache,
230246
wd: os.Path,
231-
allowedInputs: Option[Set[os.Path]],
232-
importer: Option[(Path, String) => Option[os.Path]],
233-
warnLogger: (Boolean, String) => Unit,
234-
std: Val.Obj): Either[String, String] = {
247+
importer: Importer,
248+
warnLogger: Evaluator.Logger,
249+
std: Val.Obj,
250+
evaluatorOverride: Option[Evaluator] = None): Either[String, String] = {
235251

236252
val (jsonnetCode, path) =
237253
if (config.exec.value) (file, wd / Util.wrapInLessThanGreaterThan("exec"))
@@ -265,35 +281,23 @@ object SjsonnetMainBase {
265281
queryExtVar = (key: String) => extBinding.get(key).map(ExternalVariable.code),
266282
queryTlaVar = (key: String) => tlaBinding.get(key).map(ExternalVariable.code),
267283
OsPath(wd),
268-
importer = importer match {
269-
case Some(i) =>
270-
new Importer {
271-
def resolve(docBase: Path, importName: String): Option[Path] =
272-
i(docBase, importName).map(OsPath.apply)
273-
def read(path: Path, binaryData: Boolean): Option[ResolvedFile] = {
274-
readPath(path, binaryData)
275-
}
276-
}
277-
case None =>
278-
resolveImport(
279-
config.getOrderedJpaths.map(os.Path(_, wd)).map(OsPath.apply),
280-
allowedInputs,
281-
config.debugImporter.value
282-
)
283-
},
284+
importer = importer,
284285
parseCache,
285-
settings = new Settings(
286-
preserveOrder = config.preserveOrder.value,
287-
strict = config.strict.value,
288-
throwErrorForInvalidSets = config.throwErrorForInvalidSets.value,
289-
maxParserRecursionDepth = config.maxParserRecursionDepth,
290-
brokenAssertionLogic = config.brokenAssertionLogic.value
291-
),
286+
settings = settings,
292287
storePos = (position: Position) => if (config.yamlDebug.value) currentPos = position else (),
293288
logger = warnLogger,
294289
std = std,
295290
variableResolver = _ => None
296-
)
291+
) {
292+
override def createEvaluator(
293+
resolver: CachedResolver,
294+
extVars: String => Option[Expr],
295+
wd: Path,
296+
settings: Settings): Evaluator =
297+
evaluatorOverride.getOrElse(
298+
super.createEvaluator(resolver, extVars, wd, settings)
299+
)
300+
}
297301

298302
(config.multi, config.yamlStream.value) match {
299303
case (Some(multiPath), _) =>
@@ -303,14 +307,10 @@ object SjsonnetMainBase {
303307
obj.value.toSeq.map { case (f, v) =>
304308
for {
305309
rendered <- {
306-
if (config.expectString.value) {
307-
expectString(v)
308-
} else {
309-
val writer = new StringWriter()
310-
val renderer = rendererForConfig(writer, config, () => currentPos)
311-
ujson.transform(v, renderer)
312-
Right(writer.toString)
313-
}
310+
val writer = new StringWriter()
311+
val renderer = rendererForConfig(writer, config, () => currentPos)
312+
ujson.transform(v, renderer)
313+
Right(writer.toString)
314314
}
315315
relPath = (os.FilePath(multiPath) / os.RelPath(f)).asInstanceOf[os.FilePath]
316316
_ <- writeFile(config, relPath.resolveFrom(wd), rendered)

0 commit comments

Comments
 (0)