"""Source code for ``spectree.spec``: the ``SpecTree`` interface."""

from collections import defaultdict
from copy import deepcopy
from functools import wraps
from importlib import import_module
from typing import (
    Any,
    Callable,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Type,
    get_type_hints,
)

from ._types import FunctionDecorator, ModelType, NamingStrategy, NestedNamingStrategy
from .config import Configuration, ModeEnum
from .models import Tag, ValidationError
from .plugins import PLUGINS, BasePlugin
from .response import Response
from .utils import (
    default_after_handler,
    default_before_handler,
    get_model_key,
    get_model_schema,
    get_nested_key,
    get_security,
    parse_comments,
    parse_name,
    parse_params,
    parse_request,
    parse_resp,
)


class SpecTree:
    """
    Interface

    :param str backend_name: choose from ('flask', 'quart', 'falcon',
        'falcon-asgi', 'starlette')
    :param backend: a backend that inherits `SpecTree.plugins.base.BasePlugin`,
        this will override the `backend_name` if provided
    :param app: backend framework application instance (can be registered later)
    :param before: a callback function of the form
        :meth:`spectree.utils.default_before_handler`
        ``func(req, resp, req_validation_error, instance)``
        that will be called after the request validation before the endpoint
        function
    :param after: a callback function of the form
        :meth:`spectree.utils.default_after_handler`
        ``func(req, resp, resp_validation_error, instance)``
        that will be called after the response validation
    :param validation_error_status: The default response status code to use in
        the event of a validation error. This value can be overridden for
        specific endpoints if needed.
    :param validation_error_model: model used to document validation error
        responses; defaults to :class:`spectree.models.ValidationError`
    :param naming_strategy: callable mapping a model to the key it is stored
        under in ``components/schemas``
    :param nested_naming_strategy: callable ``(model_key, definition_name)``
        returning the key under which a nested model definition is stored
    :param kwargs: init :class:`spectree.config.Configuration`, they can also be
        configured through the environment variables with prefix `spectree_`
    """

    def __init__(
        self,
        backend_name: str = "base",
        backend: Optional[Type[BasePlugin]] = None,
        app: Any = None,
        before: Callable = default_before_handler,
        after: Callable = default_after_handler,
        validation_error_status: int = 422,
        validation_error_model: Optional[ModelType] = None,
        naming_strategy: NamingStrategy = get_model_key,
        nested_naming_strategy: NestedNamingStrategy = get_nested_key,
        **kwargs: Any,
    ):
        self.naming_strategy = naming_strategy
        self.nested_naming_strategy = nested_naming_strategy
        self.before = before
        self.after = after
        self.validation_error_status = validation_error_status
        self.validation_error_model = validation_error_model or ValidationError
        self.config: Configuration = Configuration.parse_obj(kwargs)
        self.backend_name = backend_name
        if backend:
            self.backend = backend(self)
        else:
            # Import the plugin module on demand: only the one registered
            # under ``backend_name`` in PLUGINS is loaded.
            plugin = PLUGINS[backend_name]
            module = import_module(plugin.name, plugin.package)
            self.backend = getattr(module, plugin.class_name)(self)
        # model key -> JSON schema of every model registered via ``validate``
        self.models: Dict[str, Any] = {}
        if app:
            self.register(app)

    def register(self, app: Any):
        """
        register to backend application

        This will be automatically triggered if the app is passed into the
        init step.
        """
        self.app = app
        self.backend.register_route(self.app)

    @property
    def spec(self):
        """
        get the OpenAPI spec

        The spec is generated on first access and cached afterwards, so routes
        or models added after that first access are not reflected.
        """
        if not hasattr(self, "_spec"):
            self._spec = self._generate_spec()
        return self._spec

    def bypass(self, func: Callable):
        """
        bypass rules for routes (mode defined in config)

        Returns ``True`` when ``func`` must be left out of this instance's spec.

        :normal:    collect all the routes exclude those decorated by other
                    `SpecTree` instance
        :greedy:    collect all the routes
        :strict:    collect all the routes decorated by this instance
        """
        if self.config.mode == ModeEnum.greedy:
            return False
        elif self.config.mode == ModeEnum.strict:
            return getattr(func, "_decorator", None) != self
        else:
            decorator = getattr(func, "_decorator", None)
            return bool(decorator and decorator != self)

    def validate(  # noqa: PLR0913 [too-many-arguments]
        self,
        query: Optional[ModelType] = None,
        json: Optional[ModelType] = None,
        form: Optional[ModelType] = None,
        headers: Optional[ModelType] = None,
        cookies: Optional[ModelType] = None,
        resp: Optional[Response] = None,
        tags: Sequence = (),
        security: Any = None,
        deprecated: bool = False,
        before: Optional[Callable] = None,
        after: Optional[Callable] = None,
        validation_error_status: int = 0,
        path_parameter_descriptions: Optional[Mapping[str, str]] = None,
        skip_validation: bool = False,
        operation_id: Optional[str] = None,
    ) -> Callable:
        """
        - validate query, json, form, headers and cookies in request
        - validate response body and status code
        - add tags to this API route
        - add security to this API route

        :param query: `pydantic.BaseModel`, query in uri like `?name=value`
        :param json: `pydantic.BaseModel`, JSON format request body
        :param form: `pydantic.BaseModel`, form-data request body
        :param headers: `pydantic.BaseModel`, if you have specific headers
        :param cookies: `pydantic.BaseModel`, if you have cookies for this route
        :param resp: `spectree.Response`
        :param tags: a tuple of strings or :class:`spectree.models.Tag`
        :param security: dict with security config for current route and method
        :param deprecated: bool, if endpoint is marked as deprecated
        :param before: :meth:`spectree.utils.default_before_handler` for
            specific endpoint
        :param after: :meth:`spectree.utils.default_after_handler` for
            specific endpoint
        :param validation_error_status: The response status code to use for the
            specific endpoint, in the event of a validation error. If not
            specified, the global `validation_error_status` is used instead,
            defined in :meth:`spectree.spec.SpecTree`.
        :param path_parameter_descriptions: A dictionary of path parameter names
            and their description.
        :param skip_validation: forwarded unchanged to the backend plugin's
            ``validate`` call; its exact effect is defined by the plugin.
        :param operation_id: a string override for operationId for the given
            endpoint

        When ``annotations`` is enabled in the config, type hints on the view
        function's ``query``/``json``/``form``/``headers``/``cookies``
        parameters take precedence over the models passed here.
        """
        # If the status code for validation errors is not overridden on the level of
        # the view function, use the globally set status code for validation errors.
        if validation_error_status == 0:
            validation_error_status = self.validation_error_status

        def decorate_validation(func: Callable):
            # for sync framework
            @wraps(func)
            def sync_validate(*args: Any, **kwargs: Any):
                return self.backend.validate(
                    func,
                    query,
                    json,
                    form,
                    headers,
                    cookies,
                    resp,
                    before or self.before,
                    after or self.after,
                    validation_error_status,
                    skip_validation,
                    *args,
                    **kwargs,
                )

            # for async framework
            @wraps(func)
            async def async_validate(*args: Any, **kwargs: Any):
                return await self.backend.validate(
                    func,
                    query,
                    json,
                    form,
                    headers,
                    cookies,
                    resp,
                    before or self.before,
                    after or self.after,
                    validation_error_status,
                    skip_validation,
                    *args,
                    **kwargs,
                )

            validation: FunctionDecorator = (
                async_validate if self.backend.ASYNC else sync_validate  # type: ignore
            )

            if self.config.annotations:
                # Rebinding the enclosing names also changes what the
                # sync/async wrappers above pass to the backend.
                nonlocal query, json, form, headers, cookies
                annotations = get_type_hints(func)
                query = annotations.get("query", query)
                json = annotations.get("json", json)
                form = annotations.get("form", form)
                headers = annotations.get("headers", headers)
                cookies = annotations.get("cookies", cookies)

            # register
            for name, model in zip(
                ("query", "json", "form", "headers", "cookies"),
                (query, json, form, headers, cookies),
            ):
                if model is not None:
                    model_key = self._add_model(model=model)
                    setattr(validation, name, model_key)

            if resp:
                # Make sure that the endpoint specific status code and data model for
                # validation errors shows up in the response spec.
                resp.add_model(
                    validation_error_status, self.validation_error_model, replace=False
                )
                for model in resp.models:
                    self._add_model(model=model)
                validation.resp = resp

            if tags:
                validation.tags = tags
            validation.security = security
            validation.deprecated = deprecated
            validation.path_parameter_descriptions = path_parameter_descriptions
            validation.operation_id = operation_id
            # register decorator
            validation._decorator = self
            return validation

        return decorate_validation

    def _add_model(self, model: ModelType) -> str:
        """
        unified model processing

        Store the JSON schema of ``model`` under its naming-strategy key and
        return that key. The schema is deep-copied so the stored entry is
        independent of the object returned by ``get_model_schema``.
        """
        model_key = self.naming_strategy(model)
        self.models[model_key] = deepcopy(
            get_model_schema(
                model=model,
                naming_strategy=self.naming_strategy,
                nested_naming_strategy=self.nested_naming_strategy,
            )
        )
        return model_key

    def _generate_spec(self) -> Dict[str, Any]:
        """
        generate OpenAPI spec according to routes and decorators

        ``self.models`` is not modified, so calling this repeatedly yields the
        same ``components/schemas`` (including nested model definitions).
        """
        routes: Dict[str, Dict] = defaultdict(dict)
        tags = {}
        for route in self.backend.find_routes():
            for method, func in self.backend.parse_func(route):
                if self.backend.bypass(func, method) or self.bypass(func):
                    continue

                path_parameter_descriptions = getattr(
                    func, "path_parameter_descriptions", None
                )
                path, parameters = self.backend.parse_path(
                    route, path_parameter_descriptions
                )

                name = parse_name(func)
                summary, desc = parse_comments(func)
                func_tags = getattr(func, "tags", ())
                for tag in func_tags:
                    if str(tag) not in tags:
                        tags[str(tag)] = (
                            tag.dict() if isinstance(tag, Tag) else {"name": tag}
                        )

                operation: Dict[str, Any] = {
                    "summary": summary or f"{name} <{method}>",
                    "operationId": self.backend.get_func_operation_id(
                        func, path, method
                    ),
                    "description": desc or "",
                    "tags": [str(x) for x in func_tags],
                    # pass a copy: parse_params may extend the parameter list
                    "parameters": parse_params(func, parameters[:], self.models),
                    "responses": parse_resp(func, self.naming_strategy),
                }

                security = getattr(func, "security", None)
                if security is not None:
                    operation["security"] = get_security(security)

                deprecated = getattr(func, "deprecated", False)
                if deprecated:
                    operation["deprecated"] = deprecated

                request_body = parse_request(func)
                if request_body:
                    operation["requestBody"] = request_body

                routes[path][method.lower()] = operation

        # Drop the per-model "definitions" on copies instead of deleting them
        # from self.models, so a later regeneration still sees them.
        schemas: Dict[str, Any] = {
            key: {k: v for k, v in schema.items() if k != "definitions"}
            for key, schema in self.models.items()
        }
        # Nested definitions are lifted to the top level of components/schemas;
        # on a key clash the nested definition wins (same as a dict merge).
        schemas.update(self._get_model_definitions())

        spec: Dict[str, Any] = {
            "openapi": self.config.openapi_version,
            "info": self.config.openapi_info(),
            "tags": list(tags.values()),
            "paths": {**routes},
            "components": {
                "schemas": schemas,
            },
        }

        if self.config.servers:
            spec["servers"] = [
                server.dict(exclude_none=True) for server in self.config.servers
            ]

        if self.config.security_schemes:
            spec["components"]["securitySchemes"] = {
                scheme.name: scheme.data.dict(exclude_none=True, by_alias=True)
                for scheme in self.config.security_schemes
            }

        spec["security"] = get_security(self.config.security)
        return spec

    def _get_model_definitions(self) -> Dict[str, Any]:
        """
        handle nested models

        Collect the ``definitions`` of every registered model schema into one
        flat mapping keyed by ``nested_naming_strategy(model_key, name)``.
        When several models yield the same composed key, the first one wins.
        ``self.models`` is left untouched.
        """
        definitions: Dict[str, Any] = {}
        for name, schema in self.models.items():
            for key, value in schema.get("definitions", {}).items():
                composed_key = self.nested_naming_strategy(name, key)
                definitions.setdefault(composed_key, value)
        return definitions