diff --git a/metabolomics_spectrum_resolver/app.py b/metabolomics_spectrum_resolver/app.py index 223c01a..1f9e714 100644 --- a/metabolomics_spectrum_resolver/app.py +++ b/metabolomics_spectrum_resolver/app.py @@ -1,12 +1,40 @@ +import ipaddress import os +import flask from flask import Flask - -from metabolomics_spectrum_resolver import views +from flask_limiter import Limiter APP_ROOT = os.path.dirname(os.path.realpath(__file__)) +WHITELISTED_RANGES = [ + ipaddress.ip_network("138.23.0.0/16"), # UCR + ipaddress.ip_network("169.235.0.0/16"), # UCR + ipaddress.ip_network("132.239.0.0/16"), # UCSD + ipaddress.ip_network("137.110.0.0/16"), # UCSD + ipaddress.ip_network("192.31.146.0/24"), # UCSD +] + + +def get_ip(): + if flask.request.headers.getlist("X-Forwarded-For"): + ip = flask.request.headers.getlist("X-Forwarded-For")[0] + else: + ip = flask.request.remote_addr + return ip.split(",")[0].strip() + + +def get_ip_or_exempt(): + try: + client_ip = ipaddress.ip_address(get_ip()) + for network in WHITELISTED_RANGES: + if client_ip in network: + return "whitelisted-user" + except ValueError: + pass + return get_ip() + class CustomFlask(Flask): jinja_options = Flask.jinja_options.copy() @@ -24,4 +52,14 @@ class CustomFlask(Flask): app = CustomFlask(__name__) app.config.from_object(__name__) + +limiter = Limiter( + key_func=get_ip_or_exempt, + app=app, + default_limits=[], + storage_uri="redis://metabolomicsusi-redis:6379", +) + +# Import views after limiter is created to avoid circular import +from metabolomics_spectrum_resolver import views # noqa: E402 app.register_blueprint(views.blueprint) diff --git a/metabolomics_spectrum_resolver/views.py b/metabolomics_spectrum_resolver/views.py index 7be9950..682916b 100644 --- a/metabolomics_spectrum_resolver/views.py +++ b/metabolomics_spectrum_resolver/views.py @@ -11,6 +11,7 @@ from spectrum_utils import spectrum as sus from metabolomics_spectrum_resolver import similarity, tasks +from metabolomics_spectrum_resolver.app import limiter from metabolomics_spectrum_resolver.error import UsiError @@ -73,6 +74,7 @@ def mirror_forward(): @blueprint.route("/png/", methods=["GET", "POST"]) +@limiter.limit("10/minute") def generate_png(): request_params = flask.request.values.to_dict() @@ -96,6 +98,7 @@ def generate_png(): @blueprint.route("/png/mirror/", methods=["GET", "POST"]) +@limiter.limit("10/minute") def generate_mirror_png(): request_params = flask.request.values.to_dict() @@ -123,6 +126,7 @@ def generate_mirror_png(): @blueprint.route("/svg/", methods=["GET", "POST"]) +@limiter.limit("10/minute") def generate_svg(): request_params = flask.request.values.to_dict() @@ -146,6 +150,7 @@ def generate_svg(): @blueprint.route("/svg/mirror/", methods=["GET", "POST"]) +@limiter.limit("10/minute") def generate_mirror_svg(): request_params = flask.request.values.to_dict() @@ -470,6 +475,7 @@ def _prepare_mirror_spectra( @blueprint.route("/json/") +@limiter.limit("10/minute") def peak_json(): try: spectrum, _, splash_key = tasks.parse_usi( @@ -497,6 +503,7 @@ def peak_json(): @blueprint.route("/json/mirror/") +@limiter.limit("10/minute") def mirror_json(): try: drawing_controls = get_drawing_controls( diff --git a/requirements.txt b/requirements.txt index 0431c50..829f617 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ celery_once dash dash-bootstrap-components flask +Flask-Limiter flex joblib locust