Skip to content

Commit 278b546

Browse files
authored
Merge pull request #120 from foodszhang/ref_resolver
ref resolver
2 parents ab8e170 + 9972ae5 commit 278b546

File tree

9 files changed

+101
-73
lines changed

9 files changed

+101
-73
lines changed

swagger_py_codegen/jsonschema.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from inspect import getsource
66

77
from .base import Code, CodeGenerator
8-
from .parser import schema_var_name
8+
from .parser import RefNode
99

1010

1111
class Schema(Code):
@@ -89,10 +89,8 @@ def build_data(swagger):
8989
scopes[(endpoint, method)] = list(security.values()).pop()
9090
break
9191

92-
schemas = OrderedDict([(schema_var_name(path), swagger.get(path)) for path in swagger.definitions])
93-
9492
data = dict(
95-
schemas=schemas,
93+
definitions={'definitions':swagger.origin_data.get('definitions', {})},
9694
validators=validators,
9795
filters=filters,
9896
scopes=scopes,
@@ -109,7 +107,7 @@ def _process(self):
109107
yield Schema(build_data(self.swagger))
110108

111109

112-
def merge_default(schema, value, get_first=True):
110+
def merge_default(schema, value, get_first=True, resolver=None):
113111
# TODO: more types support
114112
type_defaults = {
115113
'integer': 9573,
@@ -119,17 +117,17 @@ def merge_default(schema, value, get_first=True):
119117
'boolean': False
120118
}
121119

122-
results = normalize(schema, value, type_defaults)
120+
results = normalize(schema, value, type_defaults, resolver=resolver)
123121
if get_first:
124122
return results[0]
125123
return results
126124

127125

128-
def build_default(schema):
129-
return merge_default(schema, None)
126+
def build_default(schema, resolver=None):
127+
return merge_default(schema, None, resolver=resolver)
130128

131129

132-
def normalize(schema, data, required_defaults=None):
130+
def normalize(schema, data, required_defaults=None, resolver=None):
133131
if required_defaults is None:
134132
required_defaults = {}
135133
errors = []
@@ -217,7 +215,7 @@ def _normalize_dict(schema, data):
217215

218216
def _normalize_list(schema, data):
219217
result = []
220-
if hasattr(data, '__iter__') and not isinstance(data, dict):
218+
if hasattr(data, '__iter__') and not isinstance(data, (dict, RefNode)):
221219
for item in data:
222220
result.append(_normalize(schema.get('items'), item))
223221
elif 'default' in schema:
@@ -230,6 +228,15 @@ def _normalize_default(schema, data):
230228
else:
231229
return data
232230

231+
def _normalize_ref(schema, data):
232+
if resolver == None:
233+
raise TypeError("resolver must be provided")
234+
ref = schema.get(u"$ref")
235+
scope, resolved = resolver.resolve(ref)
236+
return _normalize(resolved, data)
237+
238+
239+
233240
def _normalize(schema, data):
234241
if schema is True or schema == {}:
235242
return data
@@ -239,10 +246,13 @@ def _normalize(schema, data):
239246
'object': _normalize_dict,
240247
'array': _normalize_list,
241248
'default': _normalize_default,
249+
'ref': _normalize_ref
242250
}
243251
type_ = schema.get('type', 'object')
244252
if type_ not in funcs:
245253
type_ = 'default'
254+
if schema.get(u'$ref', None):
255+
type_ = 'ref'
246256

247257
return funcs[type_](schema, data)
248258

swagger_py_codegen/parser.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,36 @@ def schema_var_name(path):
1313
return ''.join(map(str.capitalize, map(str, path)))
1414

1515

16-
class RefNode(dict):
16+
class RefNode(object):
1717

1818
def __init__(self, data, ref):
1919
self.ref = ref
20-
super(RefNode, self).__init__(data)
20+
self._data = data
21+
22+
23+
def __getitem__(self, key):
24+
return self._data.__getitem__(key)
25+
26+
def __setitem__(self, key, value):
27+
return self._data.__setitem__(key, value)
28+
29+
def __getattr__(self, key):
30+
return self._data.__getattribute__(key)
31+
32+
def __iter__(self):
33+
return self._data.__iter__()
2134

2235
def __repr__(self):
23-
return schema_var_name(self.ref)
36+
return repr({'$ref':self.ref})
37+
38+
def __eq__(self, other):
39+
if isinstance(other, RefNode):
40+
return self._data == other._data and self.ref == other.ref
41+
else:
42+
return object.__eq__(other)
2443

44+
def copy(self):
45+
return RefNode(self._data, self.ref)
2546

2647
class Swagger(object):
2748

@@ -40,14 +61,10 @@ def _process_ref(self):
4061
"""
4162
resolve all references util no reference exists
4263
"""
43-
while 1:
44-
li = list(self.search(['**', '$ref']))
45-
if not li:
46-
break
47-
for path, ref in li:
48-
data = resolve(self.data, ref)
49-
path = path[:-1]
50-
self.set(path, data)
64+
for path, ref in self.search(['**', '$ref']):
65+
data = resolve(self.data, ref)
66+
path = path[:-1]
67+
self.set(path, RefNode(data, ref))
5168

5269
def _resolve_definitions(self):
5370
"""
@@ -76,17 +93,19 @@ def get_definition_refs():
7693
while definition_refs:
7794
ready = {
7895
definition for definition, refs
79-
in six.iteritems(definition_refs) if not refs
96+
in six.iteritems(definition_refs)
8097
}
8198
if not ready:
82-
msg = '$ref circular references found!\n'
83-
raise ValueError(msg)
99+
continue
100+
#msg = '$ref circular references found!\n'
101+
#raise ValueError(msg)
84102
for definition in ready:
85103
del definition_refs[definition]
86104
for refs in six.itervalues(definition_refs):
87105
refs.difference_update(ready)
88106

89107
self._definitions += ready
108+
self._definitions.sort(key=lambda x :x[1])
90109

91110
def search(self, path):
92111
for p, d in dpath.util.search(

swagger_py_codegen/templates/falcon/validators.tpl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ from werkzeug.datastructures import MultiDict, Headers
1414
from jsonschema import Draft4Validator
1515

1616
from .schemas import (
17-
validators, filters, scopes, security, base_path, normalize)
17+
validators, filters, scopes, resolver, security, base_path, normalize)
1818

1919

2020
if six.PY3:
@@ -44,7 +44,7 @@ class JSONEncoder(json.JSONEncoder):
4444
class FalconValidatorAdaptor(object):
4545

4646
def __init__(self, schema):
47-
self.validator = Draft4Validator(schema)
47+
self.validator = Draft4Validator(schema, resolver=resolver)
4848

4949
def validate_number(self, type_, value):
5050
try:
@@ -87,7 +87,7 @@ class FalconValidatorAdaptor(object):
8787
def validate(self, value):
8888
value = self.type_convert(value)
8989
errors = {e.path[0]: e.message for e in self.validator.iter_errors(value)}
90-
return normalize(self.validator.schema, value)[0], errors
90+
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
9191

9292

9393
def request_validate(req, resp, resource, params):
@@ -154,15 +154,15 @@ def response_filter(req, resp, resource):
154154
'Not defined',
155155
description='`%d` is not a defined status code.' % status)
156156

157-
_resp, errors = normalize(schemas['schema'], req.context['result'])
157+
_resp, errors = normalize(schemas['schema'], req.context['result'], resolver=resolver)
158158
if schemas['headers']:
159159
headers, header_errors = normalize(
160-
{'properties': schemas['headers']}, headers)
160+
{'properties': schemas['headers']}, headers, resolver=resolver)
161161
errors.extend(header_errors)
162162
if errors:
163163
raise falcon.HTTPInternalServerError(title='Expectation Failed',
164164
description=errors)
165165

166166
if 'result' not in req.context:
167167
return
168-
resp.body = json.dumps(_resp)
168+
resp.body = json.dumps(_resp)

swagger_py_codegen/templates/flask/validators.tpl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ from flask_restful.utils import unpack
1515
from jsonschema import Draft4Validator
1616

1717
from .schemas import (
18-
validators, filters, scopes, security, merge_default, normalize)
18+
validators, filters, scopes, resolver, security, merge_default, normalize)
1919

2020

2121
class JSONEncoder(json.JSONEncoder):
@@ -29,7 +29,7 @@ class JSONEncoder(json.JSONEncoder):
2929
class FlaskValidatorAdaptor(object):
3030

3131
def __init__(self, schema):
32-
self.validator = Draft4Validator(schema)
32+
self.validator = Draft4Validator(schema, resolver=resolver)
3333

3434
def validate_number(self, type_, value):
3535
try:
@@ -72,7 +72,7 @@ class FlaskValidatorAdaptor(object):
7272
def validate(self, value):
7373
value = self.type_convert(value)
7474
errors = list(e.message for e in self.validator.iter_errors(value))
75-
return normalize(self.validator.schema, value)[0], errors
75+
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
7676

7777

7878
def request_validate(view):
@@ -136,10 +136,10 @@ def response_filter(view):
136136
# return resp, status, headers
137137
abort(500, message='`%d` is not a defined status code.' % status)
138138

139-
resp, errors = normalize(schemas['schema'], resp)
139+
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
140140
if schemas['headers']:
141141
headers, header_errors = normalize(
142-
{'properties': schemas['headers']}, headers)
142+
{'properties': schemas['headers']}, headers, resolver=resolver)
143143
errors.extend(header_errors)
144144
if errors:
145145
abort(500, message='Expectation Failed', errors=errors)

swagger_py_codegen/templates/jsonschema/schemas.tpl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
# -*- coding: utf-8 -*-
22

33
import six
4+
from jsonschema import RefResolver
5+
from swagger_py_codegen.parser import RefNode
46

57
# TODO: datetime support
68

9+
710
{% include '_do_not_change.tpl' %}
811

912
base_path = '{{base_path}}'
1013

11-
{% for name, value in schemas.items() %}
12-
{{ name }} = {{ value }}
13-
{%- endfor %}
14+
definitions = {{ definitions }}
1415

1516
validators = {
1617
{%- for name, value in validators.items() %}
@@ -30,6 +31,7 @@ scopes = {
3031
{%- endfor %}
3132
}
3233

34+
resolver = RefResolver.from_schema(definitions)
3335

3436
class Security(object):
3537

swagger_py_codegen/templates/sanic/validators.tpl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from sanic.request import RequestParameters
1717
from jsonschema import Draft4Validator
1818

1919
from .schemas import (
20-
validators, filters, scopes, security, base_path, normalize, current)
20+
validators, filters, scopes, security, resolver, base_path, normalize, current)
2121

2222

2323
def unpack(value):
@@ -63,7 +63,7 @@ class JSONEncoder(json.JSONEncoder):
6363
class SanicValidatorAdaptor(object):
6464

6565
def __init__(self, schema):
66-
self.validator = Draft4Validator(schema)
66+
self.validator = Draft4Validator(schema, resolver=resolver)
6767

6868
def validate_number(self, type_, value):
6969
try:
@@ -106,7 +106,7 @@ class SanicValidatorAdaptor(object):
106106
def validate(self, value):
107107
value = self.type_convert(value)
108108
errors = list(e.message for e in self.validator.iter_errors(value))
109-
return normalize(self.validator.schema, value)[0], errors
109+
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
110110

111111

112112
def request_validate(view):
@@ -175,10 +175,10 @@ def response_filter(view):
175175
# return resp, status, headers
176176
raise ServerError('`%d` is not a defined status code.' % status, 500)
177177

178-
resp, errors = normalize(schemas['schema'], resp)
178+
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
179179
if schemas['headers']:
180180
headers, header_errors = normalize(
181-
{'properties': schemas['headers']}, headers)
181+
{'properties': schemas['headers']}, headers, resolver=resolver)
182182
errors.extend(header_errors)
183183
if errors:
184184
raise ServerError('Expectation Failed', 500)

swagger_py_codegen/templates/tornado/validators.tpl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import six
1111
from functools import wraps
1212
from jsonschema import Draft4Validator
1313

14-
from .schemas import validators, scopes, normalize, filters
14+
from .schemas import validators, scopes, resolver, normalize, filters
1515

1616

1717
class ValidatorAdaptor(object):
1818

1919
def __init__(self, schema):
20-
self.validator = Draft4Validator(schema)
20+
self.validator = Draft4Validator(schema, resolver=resolver)
2121

2222
def validate_number(self, type_, value):
2323
try:
@@ -66,7 +66,7 @@ class ValidatorAdaptor(object):
6666
def validate(self, value):
6767
value = self.type_convert(value)
6868
errors = list(e.message for e in self.validator.iter_errors(value))
69-
return normalize(self.validator.schema, value)[0], errors
69+
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
7070

7171
def request_validate(obj):
7272
def _request_validate(view):
@@ -134,10 +134,10 @@ def response_filter(obj):
134134
raise tornado.web.HTTPError(
135135
500, message='`%d` is not a defined status code.' % status)
136136

137-
resp, errors = normalize(schemas['schema'], resp)
137+
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
138138
if schemas['headers']:
139139
headers, header_errors = normalize(
140-
{'properties': schemas['headers']}, headers)
140+
{'properties': schemas['headers']}, headers, resolver=resolver)
141141
errors.extend(header_errors)
142142
if errors:
143143
raise tornado.web.HTTPError(
@@ -167,4 +167,4 @@ def unpack(value):
167167
except ValueError:
168168
pass
169169

170-
return value, 200, {}
170+
return value, 200, {}

0 commit comments

Comments
 (0)