Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions openapi_specgen/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
'''
import dataclasses
import datetime
import enum
import inspect
import typing

Expand All @@ -18,6 +19,8 @@
float: "number",
int: "integer",
bool: "boolean",
enum.IntEnum: "integer",
enum.Enum: "string"
}

OPENAPI_FORMAT_MAP: typing.Dict[type, str] = {
Expand Down Expand Up @@ -95,6 +98,15 @@ def resolve_any(openapi_schema_resolver: "OpenApiSchemaResolver", data_type: typ
return {"$ref": openapi_schema_resolver.get_component_ref("AnyValue")}


def resolve_enum(openapi_schema_resolver: "OpenApiSchemaResolver", data_type: type):
if not isinstance(data_type, enum.EnumMeta):
return

openapi_type = OPENAPI_TYPE_MAP.get(data_type.__base__)

return {"type": openapi_type}
Comment on lines +101 to +107
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def resolve_enum(openapi_schema_resolver: "OpenApiSchemaResolver", data_type: type):
if not isinstance(data_type, enum.EnumMeta):
return
openapi_type = OPENAPI_TYPE_MAP.get(data_type.__base__)
return {"type": openapi_type}
def resolve_enum(openapi_schema_resolver: "OpenApiSchemaResolver", data_type: type):
if not isinstance(data_type, enum.EnumMeta):
return
openapi_schema = openapi_schema_resolver.get_schema(data_type.__base__)
openapi_schema["enum"] = [member.value for member in data_type]
return openapi_schema

if you can apply these changes I'll merge, test and release



class ResolverProto(typing.Protocol):

def __call__(
Expand All @@ -113,6 +125,7 @@ class OpenApiSchemaResolver:
def __init__(self) -> None:
self._resolvers: typing.List[ResolverProto] = [
resolve_basic,
resolve_enum,
resolve_array,
resolve_mapping,
resolve_dataclass,
Expand Down
50 changes: 48 additions & 2 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@

from unittest.mock import create_autospec

from openapi_specgen import (ApiKeyAuth, BearerAuth, OpenApi, OpenApiParam,
OpenApiPath, OpenApiResponse, OpenApiSecurity)
from openapi_specgen.schema import OpenApiSchemaResolver

from .utils import DataclassNestedObject, MarshmallowSchema
from .utils import DataclassNestedObject, MarshmallowSchema, DataclassEnum


def test_openapi_with_dataclass():
Expand Down Expand Up @@ -87,6 +86,53 @@ def test_openapi_with_dataclass():
assert expected_openapi_dict == test_api.as_dict()


def test_openapi_with_enum():
expected_openapi_dict = {
'components': {'schemas': {'DataclassEnum': {'properties': {'any_enum_field': {'type': 'string'},
'boolean_field': {'type': 'boolean'},
'date_field': {'format': 'date',
'type': 'string'},
'datetime_field': {'format': 'date-time',
'type': 'string'},
'float_field': {'type': 'number'},
'int_enum_field': {'type': 'integer'},
'int_enum_field3': {'type': 'integer'},
'int_field': {'type': 'integer'},
'list_field': {'items': {},
'type': 'array'},
'str_field': {'type': 'string'}},
'required': ['str_field',
'int_field',
'float_field',
'boolean_field',
'list_field',
'date_field',
'datetime_field',
'any_enum_field',
'int_enum_field',
'int_enum_field3'],
'title': 'DataclassEnum',
'type': 'object'}}},
'info': {'title': 'test_api', 'version': '3.0.2'},
'openapi': '3.0.2',
'paths': {'/test_path': {'get': {'description': '',
'operationId': '[get]_/test_path',
'parameters': [{'in': 'query',
'name': 'test_param',
'required': True,
'schema': {'title': 'Test_Param',
'type': 'string'}}],
'responses': {'200': {'content': {'application/json': {
'schema': {'$ref': '#/components/schemas/DataclassEnum'}}},
'description': 'test_response'}},
'summary': ''}}}}
test_resp = OpenApiResponse('test_response', data_type=DataclassEnum)
test_param = OpenApiParam('test_param', 'query', data_type=str)
test_path = OpenApiPath('/test_path', 'get', [test_resp], [test_param])
test_api = OpenApi('test_api', [test_path])
assert expected_openapi_dict == test_api.as_dict()


def test_openapi_with_marshmallow():
expected_openapi_dict = {
'openapi': '3.0.2',
Expand Down
23 changes: 23 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import enum
from dataclasses import dataclass
from datetime import date, datetime
from typing import List

from marshmallow import Schema, fields


class AnyEnum(str, enum.Enum):
STR = 'STR'


class IntEnum(enum.IntEnum):
FIRST = enum.auto()


@dataclass
class DataclassObject():
str_field: str
Expand All @@ -22,6 +31,20 @@ class DataclassNestedObject():
nested_object: DataclassObject


@dataclass
class DataclassEnum():
str_field: str
int_field: int
float_field: float
boolean_field: bool
list_field: List
date_field: date
datetime_field: datetime
any_enum_field: AnyEnum
int_enum_field: IntEnum
int_enum_field3: IntEnum


class MarshmallowSchema(Schema):
str_field = fields.String(required=True)
int_field = fields.Integer()
Expand Down