Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions cmd/rest/lua.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package main

import (
"errors"
"fmt"

"github.com/taybart/rest"
"github.com/taybart/rest/client"
restlua "github.com/taybart/rest/lua"
"github.com/taybart/rest/request"
lua "github.com/yuin/gopher-lua"
)

var rclient *client.Client
var exportsTable *lua.LTable

func syncExportsTable(l *lua.LState, f *rest.Rest) error {
// Get the "exports" field from the rest table
exportsValue := l.GetField(l.GetGlobal("rest"), "exports")
var ok bool
exportsTable, ok = exportsValue.(*lua.LTable)
if !ok {
return errors.New("rest.exports is not a table")
}
f.Parser.AddExportsCtx(restlua.LTableToMap(exportsTable))

return nil
}

func do(f *rest.Rest, req request.Request) (map[string]any, error) {
if req.Skip {
return nil, errors.New("request marked as skip = true")
}

_, exports, err := rclient.Do(req)
if err != nil {
return nil, err
}

// make sure to add the exports back into parsers ctx
f.Parser.AddExportsCtx(exports)
return exports, nil
}

func populateGlobalObject(l *lua.LState, f *rest.Rest) error {

if exportsTable == nil {
exportsTable = l.NewTable()
}

lDoFile := func(l *lua.LState) int {
ignoreFail := l.ToBool(1) /* get argument */
if err := syncExportsTable(l, f); err != nil {
panic(err)
}

if err := f.RunFile(ignoreFail); err != nil {
panic(err)
}
return 0 /* number of results */
}
lDoLabel := func(l *lua.LState) int {
label := l.ToString(1) /* get argument */
if err := syncExportsTable(l, f); err != nil {
panic(err)
}
req, err := f.Request(label)
if err != nil {
panic(err)
}
exports, err := do(f, req)
if err != nil {
panic(err)
}
// set all exports from running the request,
// this will reset on every run for the cli context not the client context
l.SetField(l.GetGlobal("rest"), "exports", restlua.MapToLTable(l, exports))
return 0 /* number of results */
}
lDoIndex := func(l *lua.LState) int {
idx := l.ToInt(1) /* get argument */
if err := syncExportsTable(l, f); err != nil {
panic(err)
}
req, err := f.RequestByIndex(idx)
if err != nil {
panic(err)
}
exports, err := do(f, req)
if err != nil {
panic(err)
}
// set all exports from running the request,
// this will reset on every run for the cli context not the client context
l.SetField(l.GetGlobal("rest"), "exports", restlua.MapToLTable(l, exports))
return 0 /* number of results */
}

l.SetGlobal("rest", restlua.MakeLTable(l, map[string]lua.LValue{
"file": l.NewFunction(lDoFile),
"label": l.NewFunction(lDoLabel),
"block": l.NewFunction(lDoIndex),
"exports": exportsTable,
}))

return nil
}

func execute(l *lua.LState, code string) error {
if err := l.DoString(code); err != nil {
return restlua.FmtError(code, err)
}
return nil
}

func runCLITool(f *rest.Rest) error {
cli, err := f.Parser.CLI()
if err != nil {
return err
}

// TODO: do something with flags

if f.Parser.Root.CLI.Loop == nil {
return errors.New("no loop defined")
}

rclient, err = client.New(f.Parser.Config)
if err != nil {
return err
}
l := lua.NewState()
defer l.Close()

if err := restlua.RegisterModules(l); err != nil {
return err
}
if err := populateGlobalObject(l, f); err != nil {
return err
}

if err := execute(l, fmt.Sprintf(`
%s
while true do
io.write('> ')
local input = io.read()
%s
end`, *cli.LoopSetup, *cli.Loop)); err != nil {
return err
}

return nil
}
16 changes: 7 additions & 9 deletions cmd/rest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,12 @@ func usage(u args.Usage) {
}

var usage strings.Builder
usage.WriteString(
fmt.Sprintf("%s\t\t=== Rest Easy ===\n%s",
log.BoldBlue, log.Reset))
usage.WriteString(
fmt.Sprintf("%sCLI:\n%s", log.BoldGreen, log.Reset))
fmt.Fprintf(&usage, "%s\t\t=== Rest Easy ===\n%s", log.BoldBlue, log.Reset)
fmt.Fprintf(&usage, "%sCLI:\n%s", log.BoldGreen, log.Reset)
u.BuildFlagString(&usage, cli)
usage.WriteString(
fmt.Sprintf("%sServer:\n%s", log.BoldGreen, log.Reset))
fmt.Fprintf(&usage, "%sServer:\n%s", log.BoldGreen, log.Reset)
u.BuildFlagString(&usage, server)
usage.WriteString(
fmt.Sprintf("%sClient:\n%s", log.BoldGreen, log.Reset))
fmt.Fprintf(&usage, "%sClient:\n%s", log.BoldGreen, log.Reset)
u.BuildFlagString(&usage, client)
fmt.Println(usage.String())
}
Expand Down Expand Up @@ -255,6 +250,9 @@ func run() error {
log.Debug("running request", c.Label, "on file", c.File)
return f.RunLabel(c.Label)
} else {
if f.Parser.Root.CLI != nil {
return runCLITool(f)
}
log.Debug("running file", c.File)
return f.RunFile(c.IgnoreFail)
}
Expand Down
14 changes: 14 additions & 0 deletions doc/SERVER.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@
}
```


### Routes

There are some build in routes to help with testing and development:

- `/__echo__` - returns the request headers (only ones that begin with `x-`) and body
- `/__ws__` - echo websocket messages back
- `/__quit__` - exit the server process

NOTE: these routes will be overridden if you provide a custom handler


### Examples

```sh
Expand Down Expand Up @@ -70,6 +82,8 @@ server {
spa = true
# if you need a more complicated test server you can add specific handlers
handler "GET" "/path" {
# override responses and just serve a websocket echo path
ws = true
# either use lua to create a more complex response
fn = "similar concept to the after hook in the client files (see hander fns below)"
# or use a response object to just have different responses per path
Expand Down
30 changes: 30 additions & 0 deletions file/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package file

import (
"encoding/base64"
"errors"
"fmt"
"net/url"
"os"
Expand Down Expand Up @@ -266,6 +267,35 @@ func makeNanoIDFunc() function.Function {
})
}

func makeTryExportsFunc(exports map[string]cty.Value) function.Function {
return function.New(&function.Spec{
VarParam: &function.Parameter{
Name: "args",
Type: cty.DynamicPseudoType,
},
Type: function.StaticReturnType(cty.String),
Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) {

if args[0].IsNull() {
return cty.NilVal, errors.New("exports key is required")
}

key := args[0].AsString()

_default := ""

// Handle optional alphabet argument
if len(args) > 1 && !args[1].IsNull() {
_default = args[1].AsString()
}
if val, ok := exports[key]; ok {
return val, nil
}
return cty.StringVal(_default), nil
},
})
}

func makeGoTemplateFunc() function.Function {
return function.New(&function.Spec{
Params: []function.Parameter{
Expand Down
86 changes: 74 additions & 12 deletions file/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,29 @@ type HCLRequest struct {
BlockIndex int
}

type CLIFlag struct {
Desc string `hcl:"desc"`
RequestID string `hcl:"request_id"`
Bool bool `hcl:"bool"`
// Key string
}

type CLI struct {
Loop *string `hcl:"loop,optional"`
LoopSetup *string `hcl:"loop_setup,optional"`
Flags map[string]CLIFlag
FlagsBody hcl.Body `hcl:",remain"`
}

type Root struct {
filename string

Imports *[]string `hcl:"imports"`
Exports *[]string `hcl:"exports"`

// TODO: allow for read() in this context
CLI *CLI `hcl:"cli,block"`

Locals []*struct {
Body hcl.Body `hcl:",remain"`
} `hcl:"locals,block"`
Expand Down Expand Up @@ -153,6 +170,49 @@ func (p *Parser) read(filename string, root *Root) error {
return nil
}

func (p *Parser) CLI() (CLI, error) {
if p.Root.CLI == nil {
return CLI{}, errors.New("cli block not found")
}

cli := p.Root.CLI
if cli.FlagsBody != nil {
cli.Flags = make(map[string]CLIFlag)
attrs, diags := cli.FlagsBody.JustAttributes()
if diags.HasErrors() {
p.writeDiags(diags)
return *cli, errors.New("error parsing cli flags")
}
flagsAttr, ok := attrs["flags"]
if !ok {
return *cli, nil // no flags defined
}

val, diags := flagsAttr.Expr.Value(p.Ctx)
if diags.HasErrors() {
p.writeDiags(diags)
return *cli, errors.New("error evaluating flags")
}

if val.Type().IsObjectType() || val.Type().IsMapType() {
for name, flagVal := range val.AsValueMap() {
var flag CLIFlag
if flagVal.Type().HasAttribute("desc") {
flag.Desc = flagVal.GetAttr("desc").AsString()
}
if flagVal.Type().HasAttribute("request_id") {
flag.RequestID = flagVal.GetAttr("request_id").AsString()
}
if flagVal.Type().HasAttribute("bool") {
flag.Bool = true
}
cli.Flags[name] = flag
}
}
}
return *cli, nil
}

func (p *Parser) Socket() (request.Socket, error) {
var sock request.Socket
if p.Root.Socket == nil {
Expand Down Expand Up @@ -327,6 +387,7 @@ func exportsToCty(exports map[string]any) map[string]cty.Value {
func (p *Parser) AddExportsCtx(exports map[string]any) {
p.Exports = exportsToCty(exports)
p.Ctx.Variables["exports"] = cty.ObjectVal(p.Exports)
p.Ctx.Functions["try_exports"] = makeTryExportsFunc(p.Exports)
}

func (p *Parser) makeContext() {
Expand All @@ -336,18 +397,19 @@ func (p *Parser) makeContext() {
"exports": cty.ObjectVal(p.Exports),
},
Functions: map[string]function.Function{
"b64_dec": makeBase64DecodeFunc(),
"b64_enc": makeBase64EncodeFunc(),
"btmpl": makeTemplateFunc(),
"env": makeEnvFunc(),
"form": makeFormFunc(),
"json_dec": makeJSONDecodeFunc(),
"json_enc": makeJSONEncodeFunc(),
"nanoid": makeNanoIDFunc(),
"read": makeFileReadFunc(),
"tmpl": makeGoTemplateFunc(),
"trim": makeTrimFunc(),
"uuid": makeUUIDFunc(),
"b64_dec": makeBase64DecodeFunc(),
"b64_enc": makeBase64EncodeFunc(),
"btmpl": makeTemplateFunc(),
"env": makeEnvFunc(),
"form": makeFormFunc(),
"json_dec": makeJSONDecodeFunc(),
"json_enc": makeJSONEncodeFunc(),
"nanoid": makeNanoIDFunc(),
"read": makeFileReadFunc(),
"tmpl": makeGoTemplateFunc(),
"try_exports": makeTryExportsFunc(p.Exports),
"trim": makeTrimFunc(),
"uuid": makeUUIDFunc(),
},
}
}
Expand Down
Loading