diff --git a/README.md b/README.md index 0dd095d..9798070 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ Converts Go structs with [go-validator](https://github.com/go-playground/validat Zen supports self-referential types and generic types. Other cyclic types (apart from self referential types) are not supported as they are not supported by zod itself. +Zen emits Zod v4 schemas by default. Use `zen.WithZodV3()` if you need the previous output style for snapshot compatibility or incremental migration. + ## Usage ```go @@ -58,6 +60,19 @@ c.AddType(PairMap[string, int, bool]{}) fmt.Print(c.Export()) ``` +Legacy v3-compatible output is still available: + +```go +fmt.Print(zen.StructToZodSchema(User{}, zen.WithZodV3())) +``` + +The main migration differences are: + +- string format tags such as `email`, `http_url`, `ipv4`, `uuid4`, and `md5` now use Zod v4 helpers like `z.email()`, `z.httpUrl()`, `z.ipv4()`, `z.uuid({ version: "v4" })`, and `z.hash("md5")` +- `ip` and `ip_addr` now emit `z.union([z.ipv4(), z.ipv6()])` +- embedded anonymous structs now expand through `.shape` spreads instead of `.merge(...)` +- enum-like map keys now emit `z.partialRecord(...)` + Outputs: ```typescript @@ -267,7 +282,8 @@ export const RequestSchema = z.object({ end: z.number().gt(0).optional(), }).refine((val) => !val.start || !val.end || val.start < val.end, 'Start should be less than end'), search: z.string().refine((val) => !val || /^[a-z0-9_]*$/.test(val), 'Invalid search identifier').optional(), -}).merge(SortParamsSchema.extend({field: z.enum(['title', 'address', 'age', 'dob'])})) + ...SortParamsSchema.extend({field: z.enum(['title', 'address', 'age', 'dob'])}).shape, +}) export type Request = z.infer<typeof RequestSchema> ``` diff --git a/zod.go b/zod.go index 84c627b..a0c7aae 100644 --- a/zod.go +++ b/zod.go @@ -49,6 +49,13 @@ func WithIgnoreTags(ignores ...string) Opt { } } +// WithZodV3 emits legacy Zod v3-compatible schemas instead of the default Zod v4 output. 
+func WithZodV3() Opt { + return func(c *Converter) { + c.zodV3 = true + } +} + // NewConverterWithOpts initializes and returns a new converter instance. func NewConverterWithOpts(opts ...Opt) *Converter { c := &Converter{ @@ -159,11 +166,26 @@ type meta struct { selfRef bool } +type stringSchemaParts struct { + base string + chain string + enumLike bool + isIPUnion bool +} + +type stringSchemaChunk struct { + kind string + text string + v4Base string + legacyChain string +} + type Converter struct { prefix string customTypes map[string]CustomFn customTags map[string]CustomFn ignoreTags []string + zodV3 bool structs int outputs map[string]entry stack []meta @@ -288,12 +310,10 @@ func (c *Converter) getStructShape(input reflect.Type, indent int) string { optional := isOptional(field) nullable := isNullable(field) - line, shouldMerge := c.convertField(field, indent+1, optional, nullable) - - if !shouldMerge { - output.WriteString(line) + if field.Anonymous { + output.WriteString(c.convertEmbeddedFieldSpread(field, indent+1)) } else { - output.WriteString(fmt.Sprintf("%s...%s.shape,\n", indentation(indent+1), schemaName(c.prefix, typeName(field.Type)))) + output.WriteString(c.convertNamedField(field, indent+1, optional, nullable)) } } @@ -310,6 +330,7 @@ func (c *Converter) convertStruct(input reflect.Type, indent int) string { `) merges := []string{} + embeddedFields := []string{} fields := input.NumField() for i := 0; i < fields; i++ { @@ -317,12 +338,25 @@ func (c *Converter) convertStruct(input reflect.Type, indent int) string { optional := isOptional(field) nullable := isNullable(field) - line, shouldMerge := c.convertField(field, indent+1, optional, nullable) + if field.Anonymous { + if c.zodV3 { + line, shouldMerge := c.convertEmbeddedFieldMerge(field, indent+1) + if shouldMerge { + merges = append(merges, line) + } else { + output.WriteString(line) + } + } else { + embeddedFields = append(embeddedFields, c.convertEmbeddedFieldSpread(field, indent+1)) + } + } 
else { + output.WriteString(c.convertNamedField(field, indent+1, optional, nullable)) + } + } - if !shouldMerge { + if !c.zodV3 { + for _, line := range embeddedFields { output.WriteString(line) - } else { - merges = append(merges, line) } } @@ -490,9 +524,16 @@ func (c *Converter) ConvertType(t reflect.Type, validate string, indent int) str if validate != "" { switch zodType { case "string": - validateStr = c.validateString(validate) - if strings.Contains(validateStr, ".enum(") { - return "z" + validateStr + stringParts := c.validateString(validate) + switch { + case stringParts.enumLike: + return stringParts.base + stringParts.chain + case stringParts.isIPUnion: + return stringParts.base + case stringParts.base != "": + return stringParts.base + stringParts.chain + default: + return "z.string()" + stringParts.chain } case "number": validateStr = c.validateNumber(validate) @@ -538,12 +579,12 @@ func (c *Converter) getType(t reflect.Type, indent int) string { return zodType } -func (c *Converter) convertField(f reflect.StructField, indent int, optional, nullable bool) (string, bool) { +func (c *Converter) convertNamedField(f reflect.StructField, indent int, optional, nullable bool) string { name := fieldName(f) // fields named `-` are not exported to JSON so don't export zod types if name == "-" { - return "", false + return "" } // because nullability is processed before custom types, this makes sure @@ -561,24 +602,37 @@ func (c *Converter) convertField(f reflect.StructField, indent int, optional, nu } t := c.ConvertType(f.Type, f.Tag.Get("validate"), indent) - if !f.Anonymous { - return fmt.Sprintf( - "%s%s: %s%s%s,\n", - indentation(indent), - name, - t, - optionalCall, - nullableCall), false - } else { - typeName := typeName(f.Type) - entry, ok := c.outputs[typeName] - if ok && entry.selfRef { - // Since we are spreading shape, we won't be able to support any validation tags on the embedded field - return fmt.Sprintf("%s...%s,\n", indentation(indent), 
shapeName(c.prefix, typeName)), false - } + return fmt.Sprintf( + "%s%s: %s%s%s,\n", + indentation(indent), + name, + t, + optionalCall, + nullableCall) +} + +func (c *Converter) convertEmbeddedFieldMerge(f reflect.StructField, indent int) (string, bool) { + t := c.ConvertType(f.Type, f.Tag.Get("validate"), indent) + typeName := typeName(f.Type) + entry, ok := c.outputs[typeName] + if ok && entry.selfRef { + // Since we are spreading shape, we won't be able to support any validation tags on the embedded field + return fmt.Sprintf("%s...%s,\n", indentation(indent), shapeName(c.prefix, typeName)), false + } + + return fmt.Sprintf(".merge(%s)", t), true +} - return fmt.Sprintf(".merge(%s)", t), true +func (c *Converter) convertEmbeddedFieldSpread(f reflect.StructField, indent int) string { + t := c.ConvertType(f.Type, f.Tag.Get("validate"), indent) + typeName := typeName(f.Type) + entry, ok := c.outputs[typeName] + if ok && entry.selfRef { + // Since we are spreading shape, we won't be able to support any validation tags on the embedded field + return fmt.Sprintf("%s...%s,\n", indentation(indent), shapeName(c.prefix, typeName)) } + + return fmt.Sprintf("%s...%s.shape,\n", indentation(indent), t) } func (c *Converter) getTypeField(f reflect.StructField, indent int, optional, nullable bool) (string, bool) { @@ -716,9 +770,16 @@ func (c *Converter) convertKeyType(t reflect.Type, validate string) string { if validate != "" { switch zodType { case "string": - validateStr = c.validateString(validate) - if strings.Contains(validateStr, ".enum(") { - return "z" + validateStr + stringParts := c.validateString(validate) + switch { + case stringParts.enumLike: + return stringParts.base + stringParts.chain + case stringParts.isIPUnion: + return stringParts.base + case stringParts.base != "": + return stringParts.base + stringParts.chain + default: + return "z.string()" + stringParts.chain } case "number": validateStr = c.validateNumber(validate) @@ -787,8 +848,15 @@ forParts: 
validateStr.WriteString(refine) } - return fmt.Sprintf(`z.record(%s, %s)%s`, - c.convertKeyType(t.Key(), getValidateKeys(validate)), + keySchema := c.convertKeyType(t.Key(), getValidateKeys(validate)) + recordFn := "z.record" + if !c.zodV3 && isPartialRecordKeySchema(keySchema) { + recordFn = "z.partialRecord" + } + + return fmt.Sprintf(`%s(%s, %s)%s`, + recordFn, + keySchema, c.ConvertType(t.Elem(), getValidateValues(validate), indent), validateStr.String()) } @@ -923,29 +991,41 @@ func (c *Converter) validateNumber(validate string) string { return validateStr.String() } -func (c *Converter) validateString(validate string) string { - var validateStr strings.Builder +func (c *Converter) validateString(validate string) stringSchemaParts { + var chunks []stringSchemaChunk var refines []string parts := strings.Split(validate, ",") - for _, part := range parts { - valName, valValue, done := c.preprocessValidationTagPart(part, &refines, &validateStr) - if done { + for _, rawPart := range parts { + valName, valValue, skip := c.parseValidationTagPart(rawPart) + if skip { + continue + } + + if h, ok := c.customTags[valName]; ok { + v := h(c, reflect.TypeOf(0), valValue, 0) + if strings.HasPrefix(v, ".refine") { + refines = append(refines, v) + } else { + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: v}) + } continue } if valValue != "" { switch valName { case "oneof": - vals := splitParamsRegex.FindAllString(part[6:], -1) + vals := splitParamsRegex.FindAllString(rawPart[6:], -1) for i := 0; i < len(vals); i++ { vals[i] = strings.Replace(vals[i], "'", "", -1) } if len(vals) == 0 { panic("oneof= must be followed by a list of values") } - // const FishEnum = z.enum(["Salmon", "Tuna", "Trout"]); - validateStr.WriteString(fmt.Sprintf(".enum([\"%s\"] as const)", strings.Join(vals, "\", \""))) + chunks = append(chunks, stringSchemaChunk{ + kind: "enum", + text: fmt.Sprintf("z.enum([\"%s\"] as const)", strings.Join(vals, "\", \"")), + }) case "len": refines = 
append(refines, fmt.Sprintf(".refine((val) => [...val].length === %s, 'String must contain %s character(s)')", valValue, valValue)) case "min": @@ -969,137 +1049,217 @@ func (c *Converter) validateString(validate string) string { case "lte": refines = append(refines, fmt.Sprintf(".refine((val) => [...val].length <= %s, 'String must contain at most %s character(s)')", valValue, valValue)) case "contains": - validateStr.WriteString(fmt.Sprintf(".includes(\"%s\")", valValue)) + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".includes(\"%s\")", valValue)}) case "endswith": - validateStr.WriteString(fmt.Sprintf(".endsWith(\"%s\")", valValue)) + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".endsWith(\"%s\")", valValue)}) case "startswith": - validateStr.WriteString(fmt.Sprintf(".startsWith(\"%s\")", valValue)) + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".startsWith(\"%s\")", valValue)}) case "eq": refines = append(refines, fmt.Sprintf(".refine((val) => val === \"%s\")", valValue)) case "ne": refines = append(refines, fmt.Sprintf(".refine((val) => val !== \"%s\")", valValue)) default: - panic(fmt.Sprintf("unknown validation: %s", part)) + // Any key=value tag not handled above is unsupported, so fail loudly instead of ignoring it. + panic(fmt.Sprintf("unknown validation: %s", rawPart)) } - } else { - switch part { - case "omitempty": - case "required": - validateStr.WriteString(".min(1)") - case "email": - // email is more readable than copying the regex in regexes.go but could be incompatible - // Also there is an open issue https://github.com/go-playground/validator/issues/517 - // https://github.com/puellanivis/pedantic-regexps/blob/master/email.go - // solution is there in the comments but not implemented yet - validateStr.WriteString(".email()") - case "url": - // url is more readable than copying the regex in regexes.go but could be incompatible - validateStr.WriteString(".url()") - case "ipv4": - validateStr.WriteString(".ip({ version: \"v4\" })") - case 
"ip4_addr": - validateStr.WriteString(".ip({ version: \"v4\" })") - case "ipv6": - validateStr.WriteString(".ip({ version: \"v6\" })") - case "ip6_addr": - validateStr.WriteString(".ip({ version: \"v6\" })") - case "ip": - validateStr.WriteString(".ip()") - case "ip_addr": - validateStr.WriteString(".ip()") - case "http_url": - // url is more readable than copying the regex in regexes.go but could be incompatible - validateStr.WriteString(".url()") - case "url_encoded": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uRLEncodedRegexString)) - case "alpha": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", alphaRegexString)) - case "alphanum": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", alphaNumericRegexString)) - case "alphanumunicode": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", alphaUnicodeNumericRegexString)) - case "alphaunicode": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", alphaUnicodeRegexString)) - case "ascii": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", aSCIIRegexString)) - case "boolean": - validateStr.WriteString(".enum(['true', 'false'])") - case "lowercase": - refines = append(refines, ".refine((val) => val === val.toLowerCase())") - case "number": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", numberRegexString)) - case "numeric": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", numericRegexString)) - case "uppercase": - refines = append(refines, ".refine((val) => val === val.toUpperCase())") - case "base64": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", base64RegexString)) - case "mongodb": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", mongodbRegexString)) - case "datetime": - validateStr.WriteString(".datetime()") - case "hexadecimal": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", hexadecimalRegexString)) - case "json": - // TODO: Better error messages with this - // const literalSchema = z.union([z.string(), z.number(), z.boolean(), z.null()]); - //type Literal = 
z.infer; - //type Json = Literal | { [key: string]: Json } | Json[]; - //const jsonSchema: z.ZodType = z.lazy(() => - // z.union([literalSchema, z.array(jsonSchema), z.record(jsonSchema)]) - //); - // - //jsonSchema.parse(data); - - refines = append(refines, ".refine((val) => { try { JSON.parse(val); return true } catch { return false } })") - case "jwt": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", jWTRegexString)) - case "latitude": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", latitudeRegexString)) - case "longitude": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", longitudeRegexString)) - case "uuid": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUIDRegexString)) - case "uuid3": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUID3RegexString)) - case "uuid3_rfc4122": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUID3RFC4122RegexString)) - case "uuid4": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUID4RegexString)) - case "uuid4_rfc4122": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUID4RFC4122RegexString)) - case "uuid5": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUID5RegexString)) - case "uuid5_rfc4122": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUID5RFC4122RegexString)) - case "uuid_rfc4122": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", uUIDRFC4122RegexString)) - case "md4": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", md4RegexString)) - case "md5": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", md5RegexString)) - case "sha256": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", sha256RegexString)) - case "sha384": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", sha384RegexString)) - case "sha512": - validateStr.WriteString(fmt.Sprintf(".regex(/%s/)", sha512RegexString)) + continue + } - default: - panic(fmt.Sprintf("unknown validation: %s", part)) - } + switch valName { + case "omitempty": + case "required": + chunks = append(chunks, 
stringSchemaChunk{kind: "chain", text: ".min(1)"}) + case "email": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.email()", legacyChain: ".email()"}) + case "url": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.url()", legacyChain: ".url()"}) + case "ipv4", "ip4_addr": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.ipv4()", legacyChain: `.ip({ version: "v4" })`}) + case "ipv6", "ip6_addr": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.ipv6()", legacyChain: `.ip({ version: "v6" })`}) + case "ip", "ip_addr": + chunks = append(chunks, stringSchemaChunk{kind: "ip", legacyChain: ".ip()"}) + case "http_url": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.httpUrl()", legacyChain: ".url()"}) + case "url_encoded": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", uRLEncodedRegexString)}) + case "alpha": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", alphaRegexString)}) + case "alphanum": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", alphaNumericRegexString)}) + case "alphanumunicode": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", alphaUnicodeNumericRegexString)}) + case "alphaunicode": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", alphaUnicodeRegexString)}) + case "ascii": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", aSCIIRegexString)}) + case "boolean": + chunks = append(chunks, stringSchemaChunk{kind: "enum", text: "z.enum(['true', 'false'])"}) + case "lowercase": + refines = append(refines, ".refine((val) => val === val.toLowerCase())") + case "number": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", numberRegexString)}) + case 
"numeric": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", numericRegexString)}) + case "uppercase": + refines = append(refines, ".refine((val) => val === val.toUpperCase())") + case "base64": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.base64()", legacyChain: fmt.Sprintf(".regex(/%s/)", base64RegexString)}) + case "mongodb": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", mongodbRegexString)}) + case "datetime": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.iso.datetime()", legacyChain: ".datetime()"}) + case "hexadecimal": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.hex()", legacyChain: fmt.Sprintf(".regex(/%s/)", hexadecimalRegexString)}) + case "json": + refines = append(refines, ".refine((val) => { try { JSON.parse(val); return true } catch { return false } })") + case "jwt": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.jwt()", legacyChain: fmt.Sprintf(".regex(/%s/)", jWTRegexString)}) + case "latitude": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", latitudeRegexString)}) + case "longitude": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", longitudeRegexString)}) + case "uuid": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.uuid()", legacyChain: fmt.Sprintf(".regex(/%s/)", uUIDRegexString)}) + case "uuid3": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.uuid({ version: "v3" })`, legacyChain: fmt.Sprintf(".regex(/%s/)", uUID3RegexString)}) + case "uuid3_rfc4122": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.uuid({ version: "v3" })`, legacyChain: fmt.Sprintf(".regex(/%s/)", uUID3RFC4122RegexString)}) + case "uuid4": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.uuid({ version: "v4" 
})`, legacyChain: fmt.Sprintf(".regex(/%s/)", uUID4RegexString)}) + case "uuid4_rfc4122": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.uuid({ version: "v4" })`, legacyChain: fmt.Sprintf(".regex(/%s/)", uUID4RFC4122RegexString)}) + case "uuid5": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.uuid({ version: "v5" })`, legacyChain: fmt.Sprintf(".regex(/%s/)", uUID5RegexString)}) + case "uuid5_rfc4122": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.uuid({ version: "v5" })`, legacyChain: fmt.Sprintf(".regex(/%s/)", uUID5RFC4122RegexString)}) + case "uuid_rfc4122": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: "z.uuid()", legacyChain: fmt.Sprintf(".regex(/%s/)", uUIDRFC4122RegexString)}) + case "md4": + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: fmt.Sprintf(".regex(/%s/)", md4RegexString)}) + case "md5": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.hash("md5")`, legacyChain: fmt.Sprintf(".regex(/%s/)", md5RegexString)}) + case "sha256": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.hash("sha256")`, legacyChain: fmt.Sprintf(".regex(/%s/)", sha256RegexString)}) + case "sha384": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.hash("sha384")`, legacyChain: fmt.Sprintf(".regex(/%s/)", sha384RegexString)}) + case "sha512": + chunks = append(chunks, stringSchemaChunk{kind: "format", v4Base: `z.hash("sha512")`, legacyChain: fmt.Sprintf(".regex(/%s/)", sha512RegexString)}) + default: + panic(fmt.Sprintf("unknown validation: %s", rawPart)) } } for _, refine := range refines { - validateStr.WriteString(refine) + chunks = append(chunks, stringSchemaChunk{kind: "chain", text: refine}) } - return validateStr.String() + return c.lowerStringSchemaChunks(chunks) } -func (c *Converter) preprocessValidationTagPart(part string, refines *[]string, validateStr *strings.Builder) (string, string, bool) { 
+func (c *Converter) lowerStringSchemaChunks(chunks []stringSchemaChunk) stringSchemaParts { + schemaParts := stringSchemaParts{} + enumIdx := -1 + firstFormatIdx := -1 + firstIPIdx := -1 + hasNonIPFormat := false + + for i, chunk := range chunks { + switch chunk.kind { + case "enum": + if enumIdx == -1 { + enumIdx = i + } + case "format": + if firstFormatIdx == -1 { + firstFormatIdx = i + } + hasNonIPFormat = true + case "ip": + if firstIPIdx == -1 { + firstIPIdx = i + } + } + } + + if enumIdx != -1 { + schemaParts.base = chunks[enumIdx].text + schemaParts.enumLike = true + for i := enumIdx + 1; i < len(chunks); i++ { + if chunks[i].kind == "chain" && strings.HasPrefix(chunks[i].text, ".refine") { + schemaParts.chain += chunks[i].text + } + } + return schemaParts + } + + if c.zodV3 { + for _, chunk := range chunks { + schemaParts.chain += legacyStringSchemaChunk(chunk) + } + return schemaParts + } + + if firstIPIdx != -1 { + if hasNonIPFormat || hasChainBeforeStringSchemaChunk(chunks, firstIPIdx) { + for _, chunk := range chunks { + schemaParts.chain += legacyStringSchemaChunk(chunk) + } + return schemaParts + } + + armChain := "" + for _, chunk := range chunks { + if chunk.kind == "chain" { + armChain += chunk.text + } + } + schemaParts.base = fmt.Sprintf("z.union([z.ipv4()%s, z.ipv6()%s])", armChain, armChain) + schemaParts.isIPUnion = true + return schemaParts + } + + if firstFormatIdx == -1 || hasChainBeforeStringSchemaChunk(chunks, firstFormatIdx) { + for _, chunk := range chunks { + schemaParts.chain += legacyStringSchemaChunk(chunk) + } + return schemaParts + } + + schemaParts.base = chunks[firstFormatIdx].v4Base + for i := firstFormatIdx + 1; i < len(chunks); i++ { + schemaParts.chain += legacyStringSchemaChunk(chunks[i]) + } + return schemaParts +} + +func legacyStringSchemaChunk(chunk stringSchemaChunk) string { + switch chunk.kind { + case "chain": + return chunk.text + case "format", "ip": + return chunk.legacyChain + default: + return "" + } +} + 
+func hasChainBeforeStringSchemaChunk(chunks []stringSchemaChunk, idx int) bool { + for i := 0; i < idx; i++ { + if chunks[i].kind == "chain" { + return true + } + } + return false +} + +func isPartialRecordKeySchema(schema string) bool { + schema = strings.TrimSpace(schema) + return strings.HasPrefix(schema, "z.enum(") || strings.HasPrefix(schema, "z.literal(") +} + +func (c *Converter) parseValidationTagPart(part string) (string, string, bool) { part = strings.TrimSpace(part) if part == "" { return "", "", true @@ -1123,6 +1281,15 @@ func (c *Converter) preprocessValidationTagPart(part string, refines *[]string, return "", "", true } + return valName, valValue, false +} + +func (c *Converter) preprocessValidationTagPart(part string, refines *[]string, validateStr *strings.Builder) (string, string, bool) { + valName, valValue, done := c.parseValidationTagPart(part) + if done { + return "", "", true + } + if h, ok := c.customTags[valName]; ok { v := h(c, reflect.TypeOf(0), valValue, 0) if strings.HasPrefix(v, ".refine") { diff --git a/zod_test.go b/zod_test.go index dea0857..cdc85ba 100644 --- a/zod_test.go +++ b/zod_test.go @@ -94,7 +94,7 @@ func TestStructSimpleWithOmittedField(t *testing.T) { export type User = z.infer `, - StructToZodSchema(User{})) + StructToZodSchema(User{}, WithZodV3())) } func TestStructSimplePrefix(t *testing.T) { @@ -144,7 +144,7 @@ export const UserSchema = z.object({ export type User = z.infer `, - StructToZodSchema(User{})) + StructToZodSchema(User{}, WithZodV3())) } func TestStringArray(t *testing.T) { @@ -643,7 +643,7 @@ export type Required = z.infer export type Email = z.infer `, - StructToZodSchema(Email{})) + StructToZodSchema(Email{}, WithZodV3())) type URL struct { Name string `validate:"url"` @@ -655,7 +655,7 @@ export type Email = z.infer export type URL = z.infer `, - StructToZodSchema(URL{})) + StructToZodSchema(URL{}, WithZodV3())) type IPv4 struct { Name string `validate:"ipv4"` @@ -667,7 +667,7 @@ export type URL = 
z.infer export type IPv4 = z.infer `, - StructToZodSchema(IPv4{})) + StructToZodSchema(IPv4{}, WithZodV3())) type IPv6 struct { Name string `validate:"ipv6"` @@ -679,7 +679,7 @@ export type IPv4 = z.infer export type IPv6 = z.infer `, - StructToZodSchema(IPv6{})) + StructToZodSchema(IPv6{}, WithZodV3())) type IP4Addr struct { Name string `validate:"ip4_addr"` @@ -691,7 +691,7 @@ export type IPv6 = z.infer export type IP4Addr = z.infer `, - StructToZodSchema(IP4Addr{})) + StructToZodSchema(IP4Addr{}, WithZodV3())) type IP6Addr struct { Name string `validate:"ip6_addr"` @@ -703,7 +703,7 @@ export type IP4Addr = z.infer export type IP6Addr = z.infer `, - StructToZodSchema(IP6Addr{})) + StructToZodSchema(IP6Addr{}, WithZodV3())) type IP struct { Name string `validate:"ip"` @@ -715,7 +715,7 @@ export type IP6Addr = z.infer export type IP = z.infer `, - StructToZodSchema(IP{})) + StructToZodSchema(IP{}, WithZodV3())) type IPAddr struct { Name string `validate:"ip_addr"` @@ -727,7 +727,7 @@ export type IP = z.infer export type IPAddr = z.infer `, - StructToZodSchema(IPAddr{})) + StructToZodSchema(IPAddr{}, WithZodV3())) type HttpURL struct { Name string `validate:"http_url"` @@ -739,7 +739,7 @@ export type IPAddr = z.infer export type HttpURL = z.infer `, - StructToZodSchema(HttpURL{})) + StructToZodSchema(HttpURL{}, WithZodV3())) type URLEncoded struct { Name string `validate:"url_encoded"` @@ -883,7 +883,7 @@ export type Uppercase = z.infer export type Base64 = z.infer `, base64RegexString), - StructToZodSchema(Base64{})) + StructToZodSchema(Base64{}, WithZodV3())) type mongodb struct { Name string `validate:"mongodb"` @@ -907,7 +907,7 @@ export type mongodb = z.infer export type datetime = z.infer `, - StructToZodSchema(datetime{})) + StructToZodSchema(datetime{}, WithZodV3())) type Hexadecimal struct { Name string `validate:"hexadecimal"` @@ -919,7 +919,7 @@ export type datetime = z.infer export type Hexadecimal = z.infer `, hexadecimalRegexString), - 
StructToZodSchema(Hexadecimal{})) + StructToZodSchema(Hexadecimal{}, WithZodV3())) type json struct { Name string `validate:"json"` @@ -967,7 +967,7 @@ export type Longitude = z.infer export type UUID = z.infer `, uUIDRegexString), - StructToZodSchema(UUID{})) + StructToZodSchema(UUID{}, WithZodV3())) type UUID3 struct { Name string `validate:"uuid3"` @@ -979,7 +979,7 @@ export type UUID = z.infer export type UUID3 = z.infer `, uUID3RegexString), - StructToZodSchema(UUID3{})) + StructToZodSchema(UUID3{}, WithZodV3())) type UUID3RFC4122 struct { Name string `validate:"uuid3_rfc4122"` @@ -991,7 +991,7 @@ export type UUID3 = z.infer export type UUID3RFC4122 = z.infer `, uUID3RFC4122RegexString), - StructToZodSchema(UUID3RFC4122{})) + StructToZodSchema(UUID3RFC4122{}, WithZodV3())) type UUID4 struct { Name string `validate:"uuid4"` @@ -1003,7 +1003,7 @@ export type UUID3RFC4122 = z.infer export type UUID4 = z.infer `, uUID4RegexString), - StructToZodSchema(UUID4{})) + StructToZodSchema(UUID4{}, WithZodV3())) type UUID4RFC4122 struct { Name string `validate:"uuid4_rfc4122"` @@ -1015,7 +1015,7 @@ export type UUID4 = z.infer export type UUID4RFC4122 = z.infer `, uUID4RFC4122RegexString), - StructToZodSchema(UUID4RFC4122{})) + StructToZodSchema(UUID4RFC4122{}, WithZodV3())) type UUID5 struct { Name string `validate:"uuid5"` @@ -1027,7 +1027,7 @@ export type UUID4RFC4122 = z.infer export type UUID5 = z.infer `, uUID5RegexString), - StructToZodSchema(UUID5{})) + StructToZodSchema(UUID5{}, WithZodV3())) type UUID5RFC4122 struct { Name string `validate:"uuid5_rfc4122"` @@ -1039,7 +1039,7 @@ export type UUID5 = z.infer export type UUID5RFC4122 = z.infer `, uUID5RFC4122RegexString), - StructToZodSchema(UUID5RFC4122{})) + StructToZodSchema(UUID5RFC4122{}, WithZodV3())) type UUIDRFC4122 struct { Name string `validate:"uuid_rfc4122"` @@ -1051,7 +1051,7 @@ export type UUID5RFC4122 = z.infer export type UUIDRFC4122 = z.infer `, uUIDRFC4122RegexString), - 
StructToZodSchema(UUIDRFC4122{})) + StructToZodSchema(UUIDRFC4122{}, WithZodV3())) type MD4 struct { Name string `validate:"md4"` @@ -1075,7 +1075,7 @@ export type MD4 = z.infer export type MD5 = z.infer `, md5RegexString), - StructToZodSchema(MD5{})) + StructToZodSchema(MD5{}, WithZodV3())) type SHA256 struct { Name string `validate:"sha256"` @@ -1087,7 +1087,7 @@ export type MD5 = z.infer export type SHA256 = z.infer `, sha256RegexString), - StructToZodSchema(SHA256{})) + StructToZodSchema(SHA256{}, WithZodV3())) type SHA384 struct { Name string `validate:"sha384"` @@ -1099,7 +1099,7 @@ export type SHA256 = z.infer export type SHA384 = z.infer `, sha384RegexString), - StructToZodSchema(SHA384{})) + StructToZodSchema(SHA384{}, WithZodV3())) type SHA512 struct { Name string `validate:"sha512"` @@ -1111,7 +1111,7 @@ export type SHA384 = z.infer export type SHA512 = z.infer `, sha512RegexString), - StructToZodSchema(SHA512{})) + StructToZodSchema(SHA512{}, WithZodV3())) type Bad2 struct { Name string `validate:"bad2"` @@ -1121,6 +1121,198 @@ export type SHA512 = z.infer }) } +func TestZodV4Defaults(t *testing.T) { + t.Run("embedded structs use shape spreads", func(t *testing.T) { + type HasID struct { + ID string + } + type HasName struct { + Name string `json:"name"` + } + type User struct { + HasID + HasName + Tags []string + } + + assert.Equal(t, `export const HasIDSchema = z.object({ + ID: z.string(), +}) +export type HasID = z.infer + +export const HasNameSchema = z.object({ + name: z.string(), +}) +export type HasName = z.infer + +export const UserSchema = z.object({ + Tags: z.string().array().nullable(), + ...HasIDSchema.shape, + ...HasNameSchema.shape, +}) +export type User = z.infer + +`, StructToZodSchema(User{})) + }) + + t.Run("string formats use zod v4 builders", func(t *testing.T) { + type Payload struct { + Email string `validate:"email"` + Link string `validate:"http_url"` + Base64 string `validate:"base64"` + ID string `validate:"uuid4"` + Checksum 
string `validate:"md5"` + } + + assert.Equal(t, `export const PayloadSchema = z.object({ + Email: z.email(), + Link: z.httpUrl(), + Base64: z.base64(), + ID: z.uuid({ version: "v4" }), + Checksum: z.hash("md5"), +}) +export type Payload = z.infer + +`, StructToZodSchema(Payload{})) + }) + + t.Run("string tag order is preserved around v4 format helpers", func(t *testing.T) { + type Payload struct { + TrimmedThenEmail string `validate:"trim,email"` + EmailThenTrimmed string `validate:"email,trim"` + } + + customTagHandlers := map[string]CustomFn{ + "trim": func(c *Converter, t reflect.Type, validate string, i int) string { + return ".trim()" + }, + } + + assert.Equal(t, `export const PayloadSchema = z.object({ + TrimmedThenEmail: z.string().trim().email(), + EmailThenTrimmed: z.email().trim(), +}) +export type Payload = z.infer + +`, NewConverterWithOpts(WithCustomTags(customTagHandlers)).Convert(Payload{})) + }) + + t.Run("ip unions inherit generic string constraints", func(t *testing.T) { + type Payload struct { + Address string `validate:"ip,required,max=45"` + } + + assert.Equal(t, `export const PayloadSchema = z.object({ + Address: z.union([z.ipv4().min(1).refine((val) => [...val].length <= 45, 'String must contain at most 45 character(s)'), z.ipv6().min(1).refine((val) => [...val].length <= 45, 'String must contain at most 45 character(s)')]), +}) +export type Payload = z.infer + +`, StructToZodSchema(Payload{})) + }) + + t.Run("oneof takes precedence over ip specialization", func(t *testing.T) { + type Payload struct { + Address string `validate:"oneof='127.0.0.1' '::1',ip"` + } + + assert.Equal(t, `export const PayloadSchema = z.object({ + Address: z.enum(["127.0.0.1", "::1"] as const), +}) +export type Payload = z.infer + +`, StructToZodSchema(Payload{})) + }) + + t.Run("ip mixed with another format falls back to legacy chain semantics", func(t *testing.T) { + type Payload struct { + Address string `validate:"email,ip"` + } + + assert.Equal(t, `export const 
PayloadSchema = z.object({ + Address: z.string().email().ip(), +}) +export type Payload = z.infer + +`, StructToZodSchema(Payload{})) + }) + + t.Run("enum keyed maps become partial records", func(t *testing.T) { + type Payload struct { + Metadata map[string]string `validate:"dive,keys,oneof=draft published,endkeys"` + } + + assert.Equal(t, `export const PayloadSchema = z.object({ + Metadata: z.partialRecord(z.enum(["draft", "published"] as const), z.string()).nullable(), +}) +export type Payload = z.infer + +`, StructToZodSchema(Payload{})) + }) + + t.Run("recursive embedded shapes preserve encounter order for duplicate keys", func(t *testing.T) { + type Base struct { + ID string `json:"id"` + } + + type Node struct { + Base + ID int `json:"id"` + Next *Node `json:"next"` + } + + assert.Equal(t, `export const BaseSchema = z.object({ + id: z.string(), +}) +export type Base = z.infer + +export type Node = Base & { + id: number, + next: Node | null, +} +const NodeSchemaShape = { + ...BaseSchema.shape, + id: z.number(), + next: z.lazy(() => NodeSchema).nullable(), +} +export const NodeSchema: z.ZodType = z.object(NodeSchemaShape) + +`, StructToZodSchema(Node{})) + }) + + t.Run("recursive embedded shapes keep named fields before spreads", func(t *testing.T) { + type TreeNode struct { + Value string + CreatedAt time.Time + Children *[]TreeNode + } + + type Tree struct { + TreeNode + UpdatedAt time.Time + } + + assert.Equal(t, `export type TreeNode = { + Value: string, + CreatedAt: Date, + Children: TreeNode[] | null, +} +const TreeNodeSchemaShape = { + Value: z.string(), + CreatedAt: z.coerce.date(), + Children: z.lazy(() => TreeNodeSchema).array().nullable(), +} +export const TreeNodeSchema: z.ZodType = z.object(TreeNodeSchemaShape) + +export const TreeSchema = z.object({ + UpdatedAt: z.coerce.date(), + ...TreeNodeSchemaShape, +}) +export type Tree = z.infer + +`, StructToZodSchema(Tree{})) + }) +} + func TestNumberValidations(t *testing.T) { type User1 struct { Age int 
`validate:"gte=18,lte=60"` @@ -2182,7 +2374,7 @@ export const RequestSchema = z.object({ }).merge(SortParamsSchema.extend({field: z.enum(['title', 'address', 'age', 'dob'])})) export type Request = z.infer -`, NewConverterWithOpts(WithCustomTags(customTagHandlers)).Convert(Request{})) +`, NewConverterWithOpts(WithCustomTags(customTagHandlers), WithZodV3()).Convert(Request{})) } func TestRecursiveEmbeddedStruct(t *testing.T) { @@ -2213,7 +2405,7 @@ func TestRecursiveEmbeddedStruct(t *testing.T) { ItemE } - c := NewConverterWithOpts() + c := NewConverterWithOpts(WithZodV3()) c.AddType(ItemA{}) c.AddType(ItemB{}) c.AddType(ItemC{}) @@ -2294,7 +2486,7 @@ export const TreeSchema = z.object({ }) export type Tree = z.infer -`, StructToZodSchema(Tree{})) +`, StructToZodSchema(Tree{}, WithZodV3())) }) t.Run("embedded struct with pointer to self and date", func(t *testing.T) { @@ -2327,6 +2519,6 @@ export const ArticleSchema = z.object({ }) export type Article = z.infer -`, StructToZodSchema(Article{})) +`, StructToZodSchema(Article{}, WithZodV3())) }) }