diff --git a/src/lowered/codegen-base.mc b/src/lowered/codegen-base.mc index a199a12..76f73da 100644 --- a/src/lowered/codegen-base.mc +++ b/src/lowered/codegen-base.mc @@ -7,7 +7,7 @@ include "mexpr/lamlift.mc" lang ProbTimeCodegenBase = MExprAst + MExprLambdaLift + ProbTimeAst type PTCompileEnv = { ast : Expr, - llSolutions : Map Name LambdaLiftSolution, + llSolutions : Map Name FinalOrderedLamLiftSolution, aliases : Map Name PTType, consts : Map Name PTExpr, topVarEnv : Map String Name, diff --git a/src/lowered/codegen.mc b/src/lowered/codegen.mc index 8f0ab6e..b2443a9 100644 --- a/src/lowered/codegen.mc +++ b/src/lowered/codegen.mc @@ -493,17 +493,8 @@ lang ProbTimeCodegenSystem = sem getCapturedTopLevelVars : Info -> PTCompileEnv -> Name -> [Expr] sem getCapturedTopLevelVars info env = | id -> - match mapLookup id env.llSolutions with Some argMap then - let argIds = mapKeys argMap.vars in - map - (lam id. - let s = nameGetStr id in - match mapLookup s env.topVarEnv with Some topLevelId then - nvar_ topLevelId - else - errorSingle [info] - (concat "Could not find top-level binding of parameter " (nameGetStr id))) - argIds + match mapLookup id env.llSolutions with Some sol then + map (lam x. nvar_ x.0) sol.vars else errorSingle [info] (concat "Could not find lambda lifted arguments for task " (nameGetStr id)) diff --git a/src/rtppl.mc b/src/rtppl.mc index 1175043..1bbe27b 100644 --- a/src/rtppl.mc +++ b/src/rtppl.mc @@ -13,6 +13,7 @@ include "stdlib::mexpr/shallow-patterns.mc" include "stdlib::mexpr/type-check.mc" include "stdlib::ocaml/mcore.mc" include "stdlib::tuning/hole-cfa.mc" +include "stdlib::mexpr/generate-utest.mc" include "coreppl::dppl-arg.mc" include "coreppl::infer-method.mc" @@ -20,17 +21,12 @@ include "coreppl::parser.mc" include "coreppl::coreppl-to-mexpr/compile.mc" include "coreppl::coreppl-to-mexpr/runtimes.mc" -let _rts = lam. - use LoadRuntime in - let _bpf = BPF {particles = int_ 1} in - let _bpfRtEntry = loadRuntimeEntry _bpf "smc-bpf/runtime.mc" in - combineInferRuntimes default (mapFromSeq cmpInferMethod [(_bpf, _bpfRtEntry)]) - lang ProbTimeCompileLang = ProbTimeLower + ProbTimeSym + ProbTimePrettyPrint + ProbTimeValidate + ProbTimeCodegen + RtpplPrettyPrint + ProbTimeJson + - DPPLParser + MExprLowerNestedPatterns + MExprTypeCheck + MCoreCompileLang + CPPLLoader + MExprAst + StripUtestLoader + MExprLowerNestedPatterns + + MCoreCompileLang sem buildProbTime : RtpplOptions -> PTProgram -> CompileResult -> () sem buildProbTime options program = @@ -47,9 +43,16 @@ lang ProbTimeCompileLang = sem buildTaskDppl : RtpplOptions -> String -> Expr -> () sem buildTaskDppl options path = | taskAst -> - let runtimeData = _rts () in - let dpplOpts = {default with cps = "partial", extractSimplification = "inline"} in - let taskAst = mexprCompile dpplOpts runtimeData taskAst in + let loader = mkLoader symEnvDefault typcheckEnvDefault [StripUtestHook ()] in + let dpplOpts = {defaultArgs with cps = "partial", extractSimplification = "inline"} in + let loader = enableCPPLCompilation dpplOpts loader in + recursive let f = lam decls. lam ast. + match exprAsDecl ast with Some (decl, ast) + then f (snoc decls decl) ast + else snoc decls (decl_nulet_ (nameSym "") ast) in + match f [] taskAst with decls in + let loader = foldl _addDeclExn loader decls in + let taskAst = buildFullAst loader in buildTaskMExpr options path taskAst sem buildTaskMExpr : RtpplOptions -> String -> Expr -> () @@ -66,7 +69,6 @@ lang ProbTimeCompileLang = p.cleanup() in writeIntermediateMExprIf path taskAst options.debugCompileMExpr; - let taskAst = typeCheck taskAst in let taskAst = lowerAll taskAst in compileMCore taskAst (mkEmptyHooks compileOCaml)