allow batch processing for language detection

This commit is contained in:
mammo0 2021-03-11 10:52:38 +01:00
parent d4cb859c8d
commit 36fee9bf1b

View file

@ -9,12 +9,25 @@ __lang_codes = [l.code for l in languages]
def detect_languages(text): def detect_languages(text):
f = Detector(text).languages # detect batch processing
if isinstance(text, list):
is_batch = True
else:
is_batch = False
text = [text]
# get the candidates # get the candidates
candidate_langs = list(filter(lambda l: l.read_bytes != 0 and l.code in __lang_codes, f)) candidates = []
for t in text:
candidates.extend(Detector(t).languages)
# this happens if no language can be detected # total read bytes of the provided text
read_bytes_total = sum(c.read_bytes for c in candidates)
# only use candidates that are supported by argostranslate
candidate_langs = list(filter(lambda l: l.read_bytes != 0 and l.code in __lang_codes, candidates))
# this happens if no language could be detected
if not candidate_langs: if not candidate_langs:
# use language "en" by default but with zero confidence # use language "en" by default but with zero confidence
return [ return [
@ -24,8 +37,29 @@ def detect_languages(text):
} }
] ]
# for multiple occurrences of the same language (can happen on batch detection)
# calculate the average confidence for each language
if is_batch:
temp_average_list = []
for lang_code in __lang_codes:
# get all candidates for a specific language
lc = list(filter(lambda l: l.code == lang_code, candidate_langs))
if len(lc) > 1:
# if more than one is present, calculate the average confidence
lang = lc[0]
lang.confidence = sum(l.confidence for l in lc) / len(lc)
lang.read_bytes = sum(l.read_bytes for l in lc)
temp_average_list.append(lang)
elif lc:
# otherwise just add it to the temporary list
temp_average_list.append(lc[0])
if temp_average_list:
# replace the list
candidate_langs = temp_average_list
# sort the candidates descending based on the detected confidence # sort the candidates descending based on the detected confidence
candidate_langs.sort(key=lambda l: l.confidence, reverse=True) candidate_langs.sort(key=lambda l: (l.confidence * l.read_bytes) / read_bytes_total, reverse=True)
return [ return [
{ {