Merge pull request #317 from dingedi/main

Move improve_translation into language.py and use it in transliteration
This commit is contained in:
Piero Toffanin 2022-09-23 22:06:13 -04:00 committed by GitHub
commit 052f0ae1cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 34 deletions

View file

@ -15,7 +15,7 @@ from translatehtml import translate_html
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
from app import flood, remove_translated_files, security from app import flood, remove_translated_files, security
from app.language import detect_languages, transliterate from app.language import detect_languages, transliterate, improve_translation_formatting
from .api_keys import Database, RemoteDatabase from .api_keys import Database, RemoteDatabase
from .suggestions import Database as SuggestionsDatabase from .suggestions import Database as SuggestionsDatabase
@ -483,46 +483,15 @@ def create_app(args):
if text_format not in ["text", "html"]: if text_format not in ["text", "html"]:
abort(400, description="%s format is not supported" % text_format) abort(400, description="%s format is not supported" % text_format)
def improve_translation(source, translation):
    """Make a translation mirror the source's trailing punctuation and casing.

    The source is stripped of surrounding whitespace first. The final
    character of the translation is then aligned with the final character of
    the source (for the marks ``! ? . , ;`` only), and the translation's case
    is adjusted to follow the source: all-lowercase, all-uppercase, or just
    the case of the first character.

    Args:
        source: The original, untranslated text.
        translation: The translated text to adjust.

    Returns:
        The adjusted translation. An empty string is returned when the source
        is empty or whitespace-only; an empty translation is returned as-is.
    """
    source = source.strip()

    # Guard clauses: the indexing below would raise IndexError on empty text.
    if not source:
        return ""
    if not translation:
        return translation

    punctuation_chars = ("!", "?", ".", ",", ";")
    source_last_char = source[-1]
    translation_last_char = translation[-1]

    if source_last_char in punctuation_chars:
        if translation_last_char != source_last_char:
            # Replace a mismatched trailing mark, or append the missing one.
            if translation_last_char in punctuation_chars:
                translation = translation[:-1]
            translation += source_last_char
    elif translation_last_char in punctuation_chars:
        # The source has no trailing mark, so drop the one that was added.
        translation = translation[:-1]

    # Stripping the trailing mark can leave nothing to re-case.
    if not translation:
        return translation

    if source.islower():
        return translation.lower()
    if source.isupper():
        return translation.upper()

    # Mixed-case source: only mirror the case of the first character.
    if source[0].islower():
        return translation[0].lower() + translation[1:]
    if source[0].isupper():
        return translation[0].upper() + translation[1:]

    return translation
try: try:
if batch: if batch:
results = [] results = []
for idx, text in enumerate(q): for idx, text in enumerate(q):
translator = src_langs[idx].get_translation(tgt_lang) translator = src_langs[idx].get_translation(tgt_lang)
if text_format == "html": if text_format == "html":
translated_text = str(translate_html(translator, text)) translated_text = str(translate_html(translator, text))
else: else:
translated_text = improve_translation(text, translator.translate( translated_text = improve_translation_formatting(text, translator.translate(
transliterate(text, target_lang=source_langs[idx]["language"]))) transliterate(text, target_lang=source_langs[idx]["language"])))
results.append(unescape(translated_text)) results.append(unescape(translated_text))
@ -545,7 +514,7 @@ def create_app(args):
if text_format == "html": if text_format == "html":
translated_text = str(translate_html(translator, q)) translated_text = str(translate_html(translator, q))
else: else:
translated_text = improve_translation(q, translator.translate( translated_text = improve_translation_formatting(q, translator.translate(
transliterate(q, target_lang=source_langs[0]["language"]))) transliterate(q, target_lang=source_langs[0]["language"])))
if source_lang == "auto": if source_lang == "auto":

View file

@ -79,6 +79,41 @@ def detect_languages(text):
return [{"confidence": l.confidence, "language": l.code} for l in candidate_langs] return [{"confidence": l.confidence, "language": l.code} for l in candidate_langs]
def improve_translation_formatting(source, translation, improve_punctuation=True):
    """Make a translation mirror the source's trailing punctuation and casing.

    The source is stripped of surrounding whitespace first. When
    ``improve_punctuation`` is true, the final character of the translation is
    aligned with the final character of the source (for the marks
    ``! ? . , ;`` only). The translation's case is then adjusted to follow the
    source: all-lowercase, all-uppercase, or just the case of the first
    character.

    Args:
        source: The original, untranslated text.
        translation: The translated (or transliterated) text to adjust.
        improve_punctuation: When false, skip the trailing-punctuation step
            and only adjust casing.

    Returns:
        The adjusted translation. An empty string is returned when the source
        is empty or whitespace-only; an empty translation is returned as-is.
    """
    source = source.strip()

    if not source:
        return ""
    if not translation:
        # Nothing to format; the indexing below would raise IndexError.
        return translation

    if improve_punctuation:
        punctuation_chars = ("!", "?", ".", ",", ";")
        source_last_char = source[-1]
        translation_last_char = translation[-1]

        if source_last_char in punctuation_chars:
            if translation_last_char != source_last_char:
                # Replace a mismatched trailing mark, or append the missing one.
                if translation_last_char in punctuation_chars:
                    translation = translation[:-1]
                translation += source_last_char
        elif translation_last_char in punctuation_chars:
            # The source has no trailing mark, so drop the one that was added.
            translation = translation[:-1]

        # Stripping the trailing mark can leave nothing to re-case.
        if not translation:
            return translation

    if source.islower():
        return translation.lower()
    if source.isupper():
        return translation.upper()

    # Mixed-case source: only mirror the case of the first character.
    if source[0].islower():
        return translation[0].lower() + translation[1:]
    if source[0].isupper():
        return translation[0].upper() + translation[1:]

    return translation
def __transliterate_line(transliterator, line_text): def __transliterate_line(transliterator, line_text):
new_text = [] new_text = []
@ -98,6 +133,8 @@ def __transliterate_line(transliterator, line_text):
if not t_word: if not t_word:
t_word = orig_word t_word = orig_word
else: else:
t_word = improve_translation_formatting(orig_word.strip(string.punctuation), t_word, improve_punctuation=False)
# add back any stripped punctuation # add back any stripped punctuation
if r_diff: if r_diff:
t_word = t_word + "".join(r_diff) t_word = t_word + "".join(r_diff)