Merge pull request #412 from pierotofy/secret

Add shared storage option
This commit is contained in:
Piero Toffanin 2023-03-09 23:14:45 -05:00 committed by GitHub
commit 246ad84419
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 305 additions and 60 deletions

View file

@ -192,6 +192,7 @@ docker-compose -f docker-compose.cuda.yml up -d --build
| --get-api-key-link | Show a link in the UI where to direct users to get an API key | `Don't show a link` | LT_GET_API_KEY_LINK |
| --require-api-key-origin | Require use of an API key for programmatic access to the API, unless the request origin matches this domain | `No restrictions on domain origin` | LT_REQUIRE_API_KEY_ORIGIN |
| --require-api-key-secret | Require use of an API key for programmatic access to the API, unless the client also sends a matching secret | `No secrets required` | LT_REQUIRE_API_KEY_SECRET |
| --shared-storage | Shared storage URI to use for multi-process data sharing (e.g. when using gunicorn) | `memory://` | LT_SHARED_STORAGE |
| --load-only | Set available languages | `all from argostranslate` | LT_LOAD_ONLY |
| --threads | Set number of threads | `4` | LT_THREADS |
| --suggestions | Allow user suggestions | `False` | LT_SUGGESTIONS |

View file

@ -21,7 +21,7 @@ from werkzeug.exceptions import HTTPException
from werkzeug.http import http_date
from flask_babel import Babel
from libretranslate import flood, remove_translated_files, security
from libretranslate import scheduler, flood, secret, remove_translated_files, security, storage
from libretranslate.language import detect_languages, improve_translation_formatting
from libretranslate.locales import (_, _lazy, get_available_locales, get_available_locale_codes, gettext_escaped,
gettext_html, lazy_swag, get_alternate_locale_links)
@ -127,6 +127,8 @@ def create_app(args):
bp = Blueprint('Main app', __name__)
storage.setup(args.shared_storage)
if not args.disable_files_translation:
remove_translated_files.setup(get_upload_dir())
languages = load_languages()
@ -202,8 +204,12 @@ def create_app(args):
limiter = Limiter()
if args.req_flood_threshold > 0:
flood.setup(args.req_flood_threshold)
if not "gunicorn" in os.environ.get("SERVER_SOFTWARE", ""):
# Gunicorn starts the scheduler in the master process
scheduler.setup(args)
flood.setup(args)
secret.setup(args)
measure_request = None
gauge_request = None
@ -261,16 +267,16 @@ def create_app(args):
if (args.require_api_key_secret
and key_missing
and not flood.secret_match(get_req_secret())
and not secret.secret_match(get_req_secret())
):
need_key = True
if need_key:
description = _("Please contact the server operator to get an API key")
if args.get_api_key_link:
description = _("Visit %(url)s to get an API key", url=args.get_api_key_link)
abort(
403,
400,
description=description,
)
return f(*a, **kw)
@ -347,7 +353,7 @@ def create_app(args):
response = Response(render_template("app.js.template",
url_prefix=args.url_prefix,
get_api_key_link=args.get_api_key_link,
api_secret=flood.get_current_secret() if args.require_api_key_secret else ""), content_type='application/javascript; charset=utf-8')
api_secret=secret.get_current_secret() if args.require_api_key_secret else ""), content_type='application/javascript; charset=utf-8')
if args.require_api_key_secret:
response.headers['Last-Modified'] = http_date(datetime.now())

View file

@ -136,6 +136,11 @@ _default_options_objects = [
'default_value': False,
'value_type': 'bool'
},
{
'name': 'SHARED_STORAGE',
'default_value': 'memory://',
'value_type': 'str'
},
{
'name': 'LOAD_ONLY',
'default_value': None,

View file

@ -1,75 +1,47 @@
import atexit
import random
import string
from libretranslate.storage import get_storage
from apscheduler.schedulers.background import BackgroundScheduler
def generate_secret():
return ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
banned = {}
active = False
threshold = -1
secrets = [generate_secret(), generate_secret()]
def forgive_banned():
global banned
global threshold
clear_list = []
s = get_storage()
banned = s.get_all_hash_int("banned")
for ip in banned:
if banned[ip] <= 0:
clear_list.append(ip)
else:
banned[ip] = min(threshold, banned[ip]) - 1
s.set_hash_int("banned", ip, min(threshold, banned[ip]) - 1)
for ip in clear_list:
del banned[ip]
s.del_hash("banned", ip)
def rotate_secrets():
global secrets
secrets[0] = secrets[1]
secrets[1] = generate_secret()
def secret_match(s):
return s in secrets
def get_current_secret():
return secrets[1]
def setup(violations_threshold=100):
def setup(args):
global active
global threshold
active = True
threshold = violations_threshold
scheduler = BackgroundScheduler()
scheduler.add_job(func=forgive_banned, trigger="interval", minutes=30)
scheduler.add_job(func=rotate_secrets, trigger="interval", minutes=30)
scheduler.start()
# Shut down the scheduler when exiting the app
atexit.register(lambda: scheduler.shutdown())
if args.req_flood_threshold > 0:
active = True
threshold = args.req_flood_threshold
def report(request_ip):
if active:
banned[request_ip] = banned.get(request_ip, 0)
banned[request_ip] += 1
get_storage().inc_hash_int("banned", request_ip)
def decrease(request_ip):
if banned[request_ip] > 0:
banned[request_ip] -= 1
s = get_storage()
if s.get_hash_int("banned", request_ip) > 0:
s.dec_hash_int("banned", request_ip)
def has_violation(request_ip):
return request_ip in banned and banned[request_ip] > 0
s = get_storage()
return s.get_hash_int("banned", request_ip) > 0
def is_banned(request_ip):
s = get_storage()
# More than X offences?
return active and banned.get(request_ip, 0) >= threshold
return active and s.get_hash_int("banned", request_ip) >= threshold

View file

@ -126,6 +126,13 @@ def get_args():
action="store_true",
help="Require use of an API key for programmatic access to the API, unless the client also sends a secret match",
)
parser.add_argument(
"--shared-storage",
type=str,
default=DEFARGS['SHARED_STORAGE'],
metavar="<Storage URI>",
help="Shared storage URI to use for multi-process data sharing (e.g. via gunicorn)",
)
parser.add_argument(
"--load-only",
type=operator.methodcaller("split", ","),

View file

@ -0,0 +1,23 @@
import atexit
from apscheduler.schedulers.background import BackgroundScheduler
# Process-wide scheduler instance; stays None until setup() has run.
scheduler = None


def setup(args):
    """Create and start the background scheduler the first time it is called.

    Later calls are no-ops because the module-level ``scheduler`` is
    already set.

    Args:
        args: Parsed options. Reads ``req_flood_threshold``, ``api_keys``
            and ``require_api_key_secret`` to decide which periodic jobs
            to register.
    """
    # Imported inside the function rather than at module top
    # (presumably to avoid an import cycle -- TODO confirm).
    from libretranslate.flood import forgive_banned
    from libretranslate.secret import rotate_secrets

    global scheduler

    if scheduler is not None:
        return

    scheduler = BackgroundScheduler()

    # (enabled?, job function, interval in minutes)
    periodic_jobs = (
        (args.req_flood_threshold > 0, forgive_banned, 10),
        (args.api_keys and args.require_api_key_secret, rotate_secrets, 30),
    )
    for enabled, job, every_minutes in periodic_jobs:
        if enabled:
            scheduler.add_job(func=job, trigger="interval", minutes=every_minutes)

    scheduler.start()

    # Shut down the scheduler when exiting the app
    atexit.register(lambda: scheduler.shutdown())

28
libretranslate/secret.py Normal file
View file

@ -0,0 +1,28 @@
import atexit
import random
import string
from libretranslate.storage import get_storage
def generate_secret():
    """Return a new 7-character secret of uppercase ASCII letters and digits.

    The secret gates API access, so characters are drawn from the
    operating system's CSPRNG (``random.SystemRandom``) instead of the
    default, predictable Mersenne Twister generator.

    Returns:
        str: The freshly generated secret.
    """
    alphabet = string.ascii_uppercase + string.digits
    return ''.join(random.SystemRandom().choices(alphabet, k=7))
def rotate_secrets():
    """Shift the current secret into the previous slot and mint a new one.

    ``secret_1`` is copied to ``secret_0``, then ``secret_1`` is replaced
    with a newly generated value, so the previously current secret is
    still stored after the rotation.
    """
    store = get_storage()
    store.set_str("secret_0", store.get_str("secret_1"))
    store.set_str("secret_1", generate_secret())
def secret_match(secret):
    """Tell whether ``secret`` equals one of the two stored secrets.

    Args:
        secret: Candidate value supplied by the client (may be empty or
            ``None``).

    Returns:
        bool: ``True`` if the candidate equals ``secret_0`` or
        ``secret_1`` in shared storage, ``False`` otherwise.
    """
    # An empty candidate must never match: the storage getters return ""
    # for keys that were never set, so without this guard a blank secret
    # would be accepted whenever the secrets are not initialised.
    if not secret:
        return False

    s = get_storage()
    return secret == s.get_str("secret_0") or secret == s.get_str("secret_1")
def get_current_secret():
    """Return the value stored under ``secret_1`` in shared storage."""
    store = get_storage()
    return store.get_str("secret_1")
def setup(args):
if args.api_keys and args.require_api_key_secret:
s = get_storage()
s.set_str("secret_0", generate_secret())
s.set_str("secret_1", generate_secret())

158
libretranslate/storage.py Normal file
View file

@ -0,0 +1,158 @@
import redis
# Active storage backend for this process; None until setup() assigns one.
storage = None


def get_storage():
    """Return the module-level storage backend.

    Returns:
        The object currently bound to ``storage``; ``None`` if no backend
        has been assigned yet.
    """
    return storage
class Storage:
    """Interface for the key/value store used to share state.

    Every method raises ``NotImplementedError`` here; concrete backends
    override the ones they support. Two kinds of values are addressed:
    plain keys (bool/int/str) and integer fields grouped under a
    namespace ``ns`` (the "hash" methods).
    """

    def set_bool(self, key, value):
        """Store a boolean under ``key``."""
        raise NotImplementedError("not implemented")

    def get_bool(self, key):
        """Return the boolean stored under ``key``."""
        raise NotImplementedError("not implemented")

    def set_int(self, key, value):
        """Store an integer under ``key``."""
        raise NotImplementedError("not implemented")

    def get_int(self, key):
        """Return the integer stored under ``key``."""
        raise NotImplementedError("not implemented")

    def set_str(self, key, value):
        """Store a string under ``key``."""
        raise NotImplementedError("not implemented")

    def get_str(self, key):
        """Return the string stored under ``key``."""
        raise NotImplementedError("not implemented")

    def set_hash_int(self, ns, key, value):
        """Store integer ``value`` in field ``key`` of namespace ``ns``."""
        raise NotImplementedError("not implemented")

    def get_hash_int(self, ns, key):
        """Return the integer in field ``key`` of namespace ``ns``."""
        raise NotImplementedError("not implemented")

    def inc_hash_int(self, ns, key):
        """Add one to field ``key`` of namespace ``ns``."""
        raise NotImplementedError("not implemented")

    def dec_hash_int(self, ns, key):
        """Subtract one from field ``key`` of namespace ``ns``."""
        raise NotImplementedError("not implemented")

    def get_hash_keys(self, ns):
        """Return the field names present in namespace ``ns``."""
        raise NotImplementedError("not implemented")

    def get_all_hash_int(self, ns):
        """Return every field of namespace ``ns`` as a ``{name: int}`` dict."""
        raise NotImplementedError("not implemented")

    def del_hash(self, ns, key):
        """Remove field ``key`` from namespace ``ns``."""
        raise NotImplementedError("not implemented")
class MemoryStorage(Storage):
    """Storage backend that keeps everything in a dict on the instance.

    Plain keys map directly to values in ``self.store``; each hash
    namespace is a nested dict stored under its namespace name.
    Missing entries read as ``False`` / ``0`` / ``""``.
    """

    def __init__(self):
        self.store = {}

    def set_bool(self, key, value):
        """Store ``value`` coerced to ``bool`` under ``key``."""
        self.store[key] = bool(value)

    def get_bool(self, key):
        """Return the boolean under ``key``; ``False`` when it is unset."""
        # .get() instead of indexing so a missing key reads as False
        # rather than raising KeyError, like the other getters.
        return bool(self.store.get(key, False))

    def set_int(self, key, value):
        """Store ``value`` coerced to ``int`` under ``key``."""
        self.store[key] = int(value)

    def get_int(self, key):
        """Return the integer under ``key``; ``0`` when it is unset."""
        return int(self.store.get(key, 0))

    def set_str(self, key, value):
        """Store ``value`` under ``key`` as given."""
        self.store[key] = value

    def get_str(self, key):
        """Return the value under ``key`` as ``str``; ``""`` when unset."""
        return str(self.store.get(key, ""))

    def set_hash_int(self, ns, key, value):
        """Set field ``key`` of namespace ``ns`` to ``int(value)``."""
        self.store.setdefault(ns, {})[key] = int(value)

    def get_hash_int(self, ns, key):
        """Return field ``key`` of namespace ``ns``; ``0`` when absent."""
        d = self.store.get(ns, {})
        return int(d.get(key, 0))

    def inc_hash_int(self, ns, key):
        """Add one to field ``key`` of ``ns`` and return the new value.

        An absent field counts as 0, so the first increment yields 1.
        """
        return self._add_hash_int(ns, key, 1)

    def dec_hash_int(self, ns, key):
        """Subtract one from field ``key`` of ``ns`` and return the new value.

        An absent field counts as 0, so the first decrement yields -1.
        """
        return self._add_hash_int(ns, key, -1)

    def _add_hash_int(self, ns, key, delta):
        """Add ``delta`` to a hash field, creating namespace/field as needed."""
        bucket = self.store.setdefault(ns, {})
        bucket[key] = int(bucket.get(key, 0)) + delta
        return bucket[key]

    def get_all_hash_int(self, ns):
        """Return all fields of ``ns`` as one ``{str: int}`` dict.

        An unknown namespace yields an empty dict. A single mapping is
        returned so callers can iterate names and index values by name.
        """
        return {str(k): int(v) for k, v in self.store.get(ns, {}).items()}

    def del_hash(self, ns, key):
        """Remove field ``key`` from ``ns``; absent entries are ignored."""
        self.store.get(ns, {}).pop(key, None)
class RedisStorage(Storage):
    """Storage backend that keeps values in a Redis server.

    Plain keys use GET/SET; hash namespaces use the Redis hash commands
    (HGET/HSET/HINCRBY/HGETALL/HDEL). Missing entries read as
    ``False`` / ``0`` / ``""``.
    """

    def __init__(self, redis_uri):
        """Connect to ``redis_uri`` and ping the server once."""
        self.conn = redis.from_url(redis_uri)
        # Issue a command right away so connection problems surface here.
        self.conn.ping()

    def set_bool(self, key, value):
        """Store ``"1"`` for a truthy ``value`` and ``"0"`` otherwise."""
        self.conn.set(key, "1" if value else "0")

    def get_bool(self, key):
        """Return ``True`` only when ``key`` holds the stored ``"1"``."""
        # bool() on the raw reply would be True for the stored "0" as
        # well (a non-empty value), so compare against the true marker.
        # Both bytes and str are accepted for the reply.
        return self.conn.get(key) in (b"1", "1")

    def set_int(self, key, value):
        """Store ``value`` under ``key`` in its ``str()`` form."""
        self.conn.set(key, str(value))

    def get_int(self, key):
        """Return the integer under ``key``; ``0`` when it is unset."""
        v = self.conn.get(key)
        if v is None:
            return 0
        # Convert the raw reply so callers always get an int.
        return int(v)

    def set_str(self, key, value):
        """Store ``value`` under ``key``."""
        self.conn.set(key, value)

    def get_str(self, key):
        """Return the UTF-8 decoded value under ``key``; ``""`` when unset."""
        v = self.conn.get(key)
        if v is None:
            return ""
        return v.decode('utf-8')

    def get_hash_int(self, ns, key):
        """Return field ``key`` of hash ``ns`` as int; ``0`` when absent."""
        v = self.conn.hget(ns, key)
        if v is None:
            return 0
        return int(v)

    def set_hash_int(self, ns, key, value):
        """Set field ``key`` of hash ``ns`` to ``value``."""
        self.conn.hset(ns, key, value)

    def inc_hash_int(self, ns, key):
        """Increment field ``key`` of hash ``ns``; return the new value."""
        return int(self.conn.hincrby(ns, key))

    def dec_hash_int(self, ns, key):
        """Decrement field ``key`` of hash ``ns``; return the new value."""
        return int(self.conn.hincrby(ns, key, -1))

    def get_all_hash_int(self, ns):
        """Return every field of hash ``ns`` as a ``{str: int}`` dict."""
        return {k.decode("utf-8"): int(v) for k, v in self.conn.hgetall(ns).items()}

    def del_hash(self, ns, key):
        """Delete field ``key`` from hash ``ns``."""
        self.conn.hdel(ns, key)
def setup(storage_uri):
global storage
if storage_uri.startswith("memory://"):
storage = MemoryStorage()
elif storage_uri.startswith("redis://"):
storage = RedisStorage(storage_uri)
else:
raise Exception("Invalid storage URI: " + storage_uri)
return storage

View file

@ -243,9 +243,8 @@ document.addEventListener('DOMContentLoaded', function(){
request.onload = function() {
try{
{% if api_secret != "" %}
if (this.status === 403){
window.location.reload(true);
return;
if (this.status === 400){
if (self.refreshOnce()) return;
}
{% endif %}
@ -362,6 +361,15 @@ document.addEventListener('DOMContentLoaded', function(){
this.translatedFileUrl = false;
this.loadingFileTranslation = false;
},
refreshOnce: function(){
var lastRefreshed = parseInt(localStorage.getItem("refreshed") || 0);
var now = new Date().getTime();
if (now - lastRefreshed > 1000 * 60 * 1){
localStorage.setItem("refreshed", now);
window.location.reload();
return true;
}
},
translateFile: function(e) {
e.preventDefault();
@ -383,9 +391,8 @@ document.addEventListener('DOMContentLoaded', function(){
if (translateFileRequest.readyState === 4 && translateFileRequest.status === 200) {
try{
{% if api_secret != "" %}
if (this.status === 403){
window.location.reload(true);
return;
if (this.status === 400){
if (self.refreshOnce()) return;
}
{% endif %}
self.loadingFileTranslation = false;

View file

@ -1,4 +1,42 @@
from prometheus_client import multiprocess
import re
import sys
def child_exit(server, worker):
    """Gunicorn ``child_exit`` hook: report the exited worker as dead.

    Args:
        server: The gunicorn server object (unused).
        worker: The worker that exited; its ``pid`` is passed to
            ``multiprocess.mark_process_dead``.
    """
    # One call per exited worker is enough; the duplicated statement
    # was redundant.
    multiprocess.mark_process_dead(worker.pid)
def _parse_wsgi_kwargs(proc_name):
    """Extract ``key=value`` pairs from a ``wsgi:app(...)`` process name.

    Values spelled ``true``/``false`` (any case) become booleans, values
    starting with a double quote lose their first and last character,
    and everything else is kept as a stripped string. Pairs are split on
    commas, so a value cannot itself contain a comma.
    """
    kwargs = {}
    if not proc_name.startswith("wsgi:app"):
        return kwargs

    arg_list = re.sub(r'wsgi:app\s*\(\s*(.*)\s*\)', '\\1', proc_name).strip()
    for pair in arg_list.split(","):
        if "=" not in pair:
            continue
        # Split on the first "=" only so values such as URIs with a
        # query string ("redis://host?db=1") stay intact.
        k, v = pair.split("=", 1)
        k = k.strip()
        v = v.strip()

        if v.lower() in ["true", "false"]:
            v = v.lower() == "true"
        elif v.startswith('"'):  # startswith() is safe on an empty value
            v = v[1:-1]
        kwargs[k] = v
    return kwargs


def _kwargs_to_argv(kwargs):
    """Translate parsed keyword arguments into a command-line list."""
    argv = ['--wsgi']
    for k, v in kwargs.items():
        flag = "--" + k.replace("_", "-")
        if isinstance(v, bool):
            # A boolean option is a bare switch: emit it when True and
            # leave it out when False. Emitting "--flag False" would turn
            # the switch on and put a non-string into argv.
            if v:
                argv.append(flag)
        else:
            argv.extend([flag, v])
    return argv


def on_starting(server):
    """Gunicorn ``on_starting`` hook: set up the shared services.

    The arguments written inside ``wsgi:app(...)`` on the gunicorn
    command line are converted to a ``sys.argv`` list, parsed with the
    application's own ``get_args()``, and the result is used to set up
    storage, the scheduler, flood control and the secrets.

    Args:
        server: The gunicorn server object; only
            ``server.cfg.default_proc_name`` is read.
    """
    # Parse command line arguments
    kwargs = _parse_wsgi_kwargs(server.cfg.default_proc_name)

    from libretranslate.main import get_args
    sys.argv = _kwargs_to_argv(kwargs)
    args = get_args()

    from libretranslate import storage, scheduler, flood, secret
    storage.setup(args.shared_storage)
    scheduler.setup(args)
    flood.setup(args)
    secret.setup(args)