Skip to content

Commit f501981

Browse files
authored
Merge pull request #122 from sctjgz/tjgz_swagger_impove
add one feature and fix some bugs
2 parents e7df322 + 996345f commit f501981

File tree

5 files changed

+103
-45
lines changed

5 files changed

+103
-45
lines changed

swagger_py_codegen/command.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import absolute_import
2-
from os import path
2+
from os import path, listdir
33
import codecs
4+
5+
from swagger_py_codegen.jsonschema import Schema
6+
47
try:
58
import simplejson as json
69
except ImportError:
710
import json
811
from os import makedirs
9-
from os.path import join, exists, dirname
12+
from os.path import join, exists, dirname, isdir
1013

1114
import six
1215
import yaml
@@ -29,8 +32,7 @@ def get_ref_filepath(filename, ref_file):
2932
return ref_file
3033

3134

32-
def spec_load(filename):
33-
spec_data = {}
35+
def get_loader(filename):
3436
if filename.endswith('.json'):
3537
loader = json.load
3638
elif filename.endswith('.yml') or filename.endswith('.yaml'):
@@ -43,20 +45,67 @@ def spec_load(filename):
4345
loader = json.load
4446
else:
4547
loader = yaml.load
48+
return loader
49+
50+
51+
def load_file(filename, spec_data):
52+
loader = get_loader(filename)
4653
with codecs.open(filename, 'r', 'utf-8') as f:
4754
data = loader(f)
55+
# modify_spec_data(spec_data, data)
4856
spec_data.update(data)
4957
for field, values in six.iteritems(data):
5058
if field not in ['definitions', 'parameters', 'paths']:
5159
continue
5260
if not isinstance(values, dict):
5361
continue
5462
for _field, value in six.iteritems(values):
63+
# _field is the endpoint for paths when the api of paths contains $ref
5564
if _field == '$ref' and value.endswith('.yml'):
5665
_filepath = get_ref_filepath(filename, value)
5766
field_data = spec_load(_filepath)
67+
# modify_spec_data(field, spec_data, field_data)
5868
spec_data[field] = field_data
59-
return spec_data
69+
elif '$ref' in value:
70+
v = value.pop('$ref', '')
71+
if not v:
72+
continue
73+
_filepath = get_ref_filepath(filename, v)
74+
field_data = spec_load(_filepath)
75+
modify_spec_data(field, spec_data, field_data)
76+
# spec_data[field][_field] = field_data.values()
77+
78+
79+
def dump_file(filename, data):
80+
if not filename.endswith('.yml'):
81+
return None
82+
if not exists(filename):
83+
dirs = filename.rsplit('/', 1)[0]
84+
# dirs, fn = '/'.join(paths[:-1]), paths[-1]
85+
if not exists(dirs):
86+
makedirs(dirs)
87+
with codecs.open(filename, 'w', 'utf-8') as f:
88+
yaml.dump(data, f, default_flow_style=False, allow_unicode=True)
89+
90+
91+
def modify_spec_data(field, spec_data, data):
92+
if not isinstance(spec_data, dict) or not isinstance(data, dict):
93+
return None
94+
for k, v in data.items():
95+
if k in spec_data[field]:
96+
spec_data[field][k].update(v)
97+
else:
98+
spec_data[field][k] = v
99+
100+
101+
def spec_load(filename):
102+
spec_data = {}
103+
files = listdir(filename) if isdir(filename) else [filename]
104+
for f in files:
105+
if f != filename:
106+
f = filename + '/' + f
107+
load_file(f, spec_data)
108+
return spec_data
60109

61110

62111
def write(dist, content):
@@ -124,6 +173,7 @@ def generate(destination, swagger_doc, force=False, package=None,
124173
click.echo("Validation passed")
125174
except ValidationError as e:
126175
raise click.ClickException(str(e))
176+
#print 'data ',data
127177
swagger = Swagger(data)
128178
if templates == 'tornado':
129179
generator = TornadoGenerator(swagger)

swagger_py_codegen/flask.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,59 +10,50 @@
1010

1111

1212
class Router(Code):
13-
1413
template = 'flask/routers.tpl'
1514
dest_template = '%(package)s/%(module)s/routes.py'
1615
override = True
1716

1817

1918
class View(Code):
20-
2119
template = 'flask/view.tpl'
2220
dest_template = '%(package)s/%(module)s/api/%(view)s.py'
2321
override = False
2422

2523

2624
class Specification(Code):
27-
2825
template = 'flask/specification.tpl'
2926
dest_template = '%(package)s/static/%(module)s/swagger.json'
3027
override = True
3128

3229

3330
class Validator(Code):
34-
3531
template = 'flask/validators.tpl'
3632
dest_template = '%(package)s/%(module)s/validators.py'
3733
override = True
3834

3935

4036
class Api(Code):
41-
4237
template = 'flask/api.tpl'
4338
dest_template = '%(package)s/%(module)s/api/__init__.py'
4439

4540

4641
class Blueprint(Code):
47-
4842
template = 'flask/blueprint.tpl'
4943
dest_template = '%(package)s/%(module)s/__init__.py'
5044

5145

5246
class App(Code):
53-
5447
template = 'flask/app.tpl'
5548
dest_template = '%(package)s/__init__.py'
5649

5750

5851
class Requirements(Code):
59-
6052
template = 'flask/requirements.tpl'
6153
dest_template = 'requirements.txt'
6254

6355

6456
class UIIndex(Code):
65-
6657
template = 'ui/index.html'
6758
dest_template = '%(package)s/static/swagger-ui/index.html'
6859

@@ -86,13 +77,14 @@ def _type(parameters):
8677
if t in types:
8778
yield '<%s>' % p['name'], '<%s:%s>' % (types[t], p['name'])
8879

89-
for old, new in _type(node.get('parameters', [])):
90-
url = url.replace(old, new)
80+
for method, param in six.iteritems(node):
81+
for old, new in _type(param.get('parameters', [])):
82+
url = url.replace(old, new)
9183

92-
for k in SUPPORT_METHODS:
93-
if k in node:
94-
for old, new in _type(node[k].get('parameters', [])):
95-
url = url.replace(old, new)
84+
for k in SUPPORT_METHODS:
85+
if k in param:
86+
for old, new in _type(param[k].get('parameters', [])):
87+
url = url.replace(old, new)
9688

9789
return url, params
9890

@@ -126,7 +118,6 @@ def _location(swagger_location):
126118

127119

128120
class FlaskGenerator(CodeGenerator):
129-
130121
dependencies = [SchemaGenerator]
131122

132123
def __init__(self, swagger):

swagger_py_codegen/jsonschema.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import absolute_import
22

3+
import copy
4+
35
import six
46
from collections import OrderedDict
57
from inspect import getsource
@@ -9,7 +11,6 @@
911

1012

1113
class Schema(Code):
12-
1314
template = 'jsonschema/schemas.tpl'
1415
dest_template = '%(package)s/%(module)s/schemas.py'
1516
override = True
@@ -27,8 +28,9 @@ def _parameters_to_schemas(params):
2728
# schema is required `in` is `body`
2829
yield location, param['schema']
2930
continue
30-
31-
prop = param.copy()
31+
# prop = param.copy()
32+
# If the parameter is referanced more than once,it would be format only once.
33+
prop = copy.deepcopy(param)
3234
prop.pop('in')
3335
if param.get('required'):
3436
required.append(param['name'])
@@ -41,7 +43,6 @@ def _parameters_to_schemas(params):
4143

4244

4345
def build_data(swagger):
44-
4546
validators = OrderedDict() # (endpoint, method) = {'body': schema_name or schema, 'query': schema_name, ..}
4647
filters = OrderedDict() # (endpoint, method) = {'200': {'schema':, 'headers':, 'examples':}, 'default': ..}
4748
scopes = OrderedDict() # (endpoint, method) = [scope_a, scope_b]
@@ -56,19 +57,21 @@ def build_data(swagger):
5657

5758
# methods
5859
for p, data in swagger.search(path + ('*',)):
60+
5961
if p[-1] not in ['get', 'post', 'put', 'delete', 'patch', 'options', 'head']:
6062
continue
6163
method_param = []
64+
6265
try:
6366
method_param = swagger.get(p + ('parameters',))
6467
except KeyError:
6568
pass
6669

6770
endpoint = p[1] # p: ('paths', '/some/path', 'method')
6871
method = p[-1].upper()
69-
7072
# parameters as schema
7173
validator = dict(_parameters_to_schemas(path_param + method_param))
74+
#print 'parameters:::::::::::::', path_param, endpoint, method, validator, method_param
7275
if validator:
7376
validators[(endpoint, method)] = validator
7477

@@ -88,9 +91,9 @@ def build_data(swagger):
8891
for security in data.get('security', []):
8992
scopes[(endpoint, method)] = list(security.values()).pop()
9093
break
91-
9294
data = dict(
93-
definitions={'definitions':swagger.origin_data.get('definitions', {})},
95+
definitions={'definitions': swagger.origin_data.get('definitions', {}),
96+
'parameters': swagger.origin_data.get('parameters', {})},
9497
validators=validators,
9598
filters=filters,
9699
scopes=scopes,
@@ -194,7 +197,7 @@ def _normalize_dict(schema, data):
194197

195198
# get value
196199
value, has_key = data.get_check(key)
197-
if has_key:
200+
if has_key or '$ref' in _schema:
198201
result[key] = _normalize(_schema, value)
199202
elif 'default' in _schema:
200203
result[key] = _schema['default']
@@ -233,10 +236,10 @@ def _normalize_ref(schema, data):
233236
raise TypeError("resolver must be provided")
234237
ref = schema.get(u"$ref")
235238
scope, resolved = resolver.resolve(ref)
239+
if resolved.get('nullable', False) and not data:
240+
return {}
236241
return _normalize(resolved, data)
237242

238-
239-
240243
def _normalize(schema, data):
241244
if schema is True or schema == {}:
242245
return data

swagger_py_codegen/templates/flask/validators.tpl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ class FlaskValidatorAdaptor(object):
6161
for k, values in obj.lists():
6262
prop = self.validator.schema['properties'].get(k, {})
6363
type_ = prop.get('type')
64+
if type_ is None and '$ref' in prop:
65+
ref = prop.get('$ref')
66+
if not ref:
67+
continue
68+
type_ = self.validator.resolver.resolve(prop.get('$ref'))[1].get('type')
69+
if not type_:
70+
continue
6471
fun = convert_funs.get(type_, lambda v: v[0])
6572
if type_ == 'array':
6673
item_type = prop.get('items', {}).get('type')

tests/test_flask.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@
88
FlaskGenerator
99
)
1010

11+
1112
def test_swagger_to_flask_url():
1213
cases = [
1314
{
1415
'url': '/users/{id}',
1516
'data': {
16-
'parameters': [{
17-
'name': 'id',
18-
'in': 'path',
19-
'type': 'integer'
20-
}],
2117
'get': {
2218
'parameters': [{
2319
'name': 'limit',
2420
'in': 'query',
2521
'type': 'integer'
26-
}]
22+
},
23+
{
24+
'name': 'id',
25+
'in': 'path',
26+
'type': 'integer'
27+
}
28+
]
2729
},
2830
'post': {
2931
'parameters': [{
@@ -34,7 +36,13 @@ def test_swagger_to_flask_url():
3436
'name': {'type': 'string'}
3537
}
3638
}
37-
}]
39+
},
40+
{
41+
'name': 'id',
42+
'in': 'path',
43+
'type': 'integer'
44+
}
45+
]
3846
}
3947
},
4048
'expect': (
@@ -58,13 +66,12 @@ def test_swagger_to_flask_url():
5866
'name': 'price',
5967
'in': 'path',
6068
'type': 'float'
69+
}, {
70+
'name': 'category',
71+
'in': 'path',
72+
'type': 'integer'
6173
}]
6274
},
63-
'parameters': [{
64-
'name': 'category',
65-
'in': 'path',
66-
'type': 'integer'
67-
}]
6875
},
6976
'expect': (
7077
'/goods/categories/<int:category>/price-large-than/<float:price>/order-by/<order>',
@@ -81,7 +88,8 @@ def test_swagger_to_flask_url():
8188
}
8289
]
8390
for case in cases:
84-
assert _swagger_to_flask_url(case['url'], case['data']) == case['expect']
91+
a = _swagger_to_flask_url(case['url'], case['data'])
92+
assert a == case['expect']
8593

8694

8795
def test_path_to_endpoint():
@@ -121,7 +129,6 @@ def test_process_data():
121129
'get': {},
122130
'put': {},
123131
'head': {},
124-
'parameters': []
125132
},
126133
'/posts/{post_id}': {
127134
'get': {

0 commit comments

Comments
 (0)