Merge pull request #2963 from jderuiter/mypy-importers

Type annotations for bookwyrm.importers
Mouse Reeve 2023-08-21 20:39:32 -07:00 committed by GitHub
commit e5f8e4babc
6 changed files with 62 additions and 28 deletions

View file

@@ -1,4 +1,6 @@
""" handle reading a csv from calibre """
from typing import Any, Optional
from bookwyrm.models import Shelf
from . import Importer
@@ -9,7 +11,7 @@ class CalibreImporter(Importer):
service = "Calibre"
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
# Add timestamp to row_mappings_guesses for date_added to avoid
# integrity error
row_mappings_guesses = []
@@ -23,6 +25,6 @@ class CalibreImporter(Importer):
self.row_mappings_guesses = row_mappings_guesses
super().__init__(*args, **kwargs)
def get_shelf(self, normalized_row):
def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]:
# Calibre export does not indicate which shelf to use. Use a default one for now
return Shelf.TO_READ
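
BookWyrm's shelf identifiers are plain strings, which is why the Optional[str] return type fits here. A minimal standalone sketch, assuming Shelf.TO_READ is the "to-read" identifier (the constant itself is not shown in this diff):

    from typing import Optional

    TO_READ = "to-read"  # assumed value of bookwyrm.models.Shelf.TO_READ

    def get_shelf(normalized_row: dict[str, Optional[str]]) -> Optional[str]:
        # Calibre exports carry no shelf column, so every imported row gets the default
        return TO_READ

    print(get_shelf({"title": "Dune", "shelf": None}))  # -> to-read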

View file

@@ -1,8 +1,10 @@
""" handle reading a csv from an external service, defaults are from Goodreads """
import csv
from datetime import timedelta
from typing import Iterable, Optional
from django.utils import timezone
from bookwyrm.models import ImportJob, ImportItem, SiteSettings
from bookwyrm.models import ImportJob, ImportItem, SiteSettings, User
class Importer:
@@ -35,19 +37,26 @@ class Importer:
}
# pylint: disable=too-many-locals
def create_job(self, user, csv_file, include_reviews, privacy):
def create_job(
self, user: User, csv_file: Iterable[str], include_reviews: bool, privacy: str
) -> ImportJob:
"""check over a csv and creates a database entry for the job"""
csv_reader = csv.DictReader(csv_file, delimiter=self.delimiter)
rows = list(csv_reader)
if len(rows) < 1:
raise ValueError("CSV file is empty")
rows = enumerate(rows)
mappings = (
self.create_row_mappings(list(fieldnames))
if (fieldnames := csv_reader.fieldnames)
else {}
)
job = ImportJob.objects.create(
user=user,
include_reviews=include_reviews,
privacy=privacy,
mappings=self.create_row_mappings(csv_reader.fieldnames),
mappings=mappings,
source=self.service,
)
@@ -55,16 +64,20 @@ class Importer:
if enforce_limit and allowed_imports <= 0:
job.complete_job()
return job
for index, entry in rows:
for index, entry in enumerate(rows):
if enforce_limit and index >= allowed_imports:
break
self.create_item(job, index, entry)
return job
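
The new guard exists because csv.DictReader.fieldnames is None when the input has no header row, and the empty-file check fires when there are no data rows at all. A self-contained sketch of that stdlib behaviour (sample data only, not BookWyrm code):

    import csv

    empty_reader = csv.DictReader([])
    print(empty_reader.fieldnames)  # -> None, hence the fallback to {} for mappings

    reader = csv.DictReader(["title,isbn\n", "Dune,9780441172719\n"])
    fieldnames = list(reader.fieldnames) if reader.fieldnames else []
    print(fieldnames)    # -> ['title', 'isbn']
    print(list(reader))  # -> [{'title': 'Dune', 'isbn': '9780441172719'}]
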
def update_legacy_job(self, job):
def update_legacy_job(self, job: ImportJob) -> None:
"""patch up a job that was in the old format"""
items = job.items
headers = list(items.first().data.keys())
first_item = items.first()
if first_item is None:
return
headers = list(first_item.data.keys())
job.mappings = self.create_row_mappings(headers)
job.updated_date = timezone.now()
job.save()
@@ -75,24 +88,24 @@ class Importer:
item.normalized_data = normalized
item.save()
def create_row_mappings(self, headers):
def create_row_mappings(self, headers: list[str]) -> dict[str, Optional[str]]:
"""guess what the headers mean"""
mappings = {}
for (key, guesses) in self.row_mappings_guesses:
value = [h for h in headers if h.lower() in guesses]
value = value[0] if len(value) else None
values = [h for h in headers if h.lower() in guesses]
value = values[0] if len(values) else None
if value:
headers.remove(value)
mappings[key] = value
return mappings
def create_item(self, job, index, data):
def create_item(self, job: ImportJob, index: int, data: dict[str, str]) -> None:
"""creates and saves an import item"""
normalized = self.normalize_row(data, job.mappings)
normalized["shelf"] = self.get_shelf(normalized)
ImportItem(job=job, index=index, data=data, normalized_data=normalized).save()
def get_shelf(self, normalized_row):
def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]:
"""determine which shelf to use"""
shelf_name = normalized_row.get("shelf")
if not shelf_name:
@@ -103,11 +116,15 @@ class Importer:
]
return shelf[0] if shelf else None
def normalize_row(self, entry, mappings): # pylint: disable=no-self-use
# pylint: disable=no-self-use
def normalize_row(
self, entry: dict[str, str], mappings: dict[str, Optional[str]]
) -> dict[str, Optional[str]]:
"""use the dataclass to create the formatted row of data"""
return {k: entry.get(v) for k, v in mappings.items()}
return {k: entry.get(v) if v else None for k, v in mappings.items()}
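
A standalone walk-through of the two typed helpers, create_row_mappings and normalize_row, using an illustrative guess table (the real row_mappings_guesses on the importer classes is larger):

    from typing import Optional

    row_mappings_guesses = [("title", ["title"]), ("isbn_13", ["isbn13", "isbn_13"])]

    def create_row_mappings(headers: list[str]) -> dict[str, Optional[str]]:
        mappings: dict[str, Optional[str]] = {}
        for key, guesses in row_mappings_guesses:
            values = [h for h in headers if h.lower() in guesses]
            value = values[0] if values else None
            if value:
                headers.remove(value)
            mappings[key] = value
        return mappings

    def normalize_row(
        entry: dict[str, str], mappings: dict[str, Optional[str]]
    ) -> dict[str, Optional[str]]:
        return {k: entry.get(v) if v else None for k, v in mappings.items()}

    mappings = create_row_mappings(["Title", "ISBN13"])
    print(mappings)  # -> {'title': 'Title', 'isbn_13': 'ISBN13'}
    print(normalize_row({"Title": "Dune", "ISBN13": "9780441172719"}, mappings))
    # -> {'title': 'Dune', 'isbn_13': '9780441172719'}
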
def get_import_limit(self, user): # pylint: disable=no-self-use
# pylint: disable=no-self-use
def get_import_limit(self, user: User) -> tuple[int, int]:
"""check if import limit is set and return how many imports are left"""
site_settings = SiteSettings.objects.get()
import_size_limit = site_settings.import_size_limit
@@ -125,7 +142,9 @@ class Importer:
allowed_imports = import_size_limit - imported_books
return enforce_limit, allowed_imports
def create_retry_job(self, user, original_job, items):
def create_retry_job(
self, user: User, original_job: ImportJob, items: list[ImportItem]
) -> ImportJob:
"""retry items that didn't import"""
job = ImportJob.objects.create(
user=user,

View file

@@ -1,11 +1,16 @@
""" handle reading a tsv from librarything """
import re
from typing import Optional
from bookwyrm.models import Shelf
from . import Importer
def _remove_brackets(value: Optional[str]) -> Optional[str]:
return re.sub(r"\[|\]", "", value) if value else None
class LibrarythingImporter(Importer):
"""csv downloads from librarything"""
@@ -13,16 +18,19 @@ class LibrarythingImporter(Importer):
delimiter = "\t"
encoding = "ISO-8859-1"
def normalize_row(self, entry, mappings): # pylint: disable=no-self-use
def normalize_row(
self, entry: dict[str, str], mappings: dict[str, Optional[str]]
) -> dict[str, Optional[str]]: # pylint: disable=no-self-use
"""use the dataclass to create the formatted row of data"""
remove_brackets = lambda v: re.sub(r"\[|\]", "", v) if v else None
normalized = {k: remove_brackets(entry.get(v)) for k, v in mappings.items()}
isbn_13 = normalized.get("isbn_13")
isbn_13 = isbn_13.split(", ") if isbn_13 else []
normalized = {
k: _remove_brackets(entry.get(v) if v else None)
for k, v in mappings.items()
}
isbn_13 = value.split(", ") if (value := normalized.get("isbn_13")) else []
normalized["isbn_13"] = isbn_13[1] if len(isbn_13) > 1 else None
return normalized
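
A standalone sketch of what the new module-level _remove_brackets helper plus the isbn_13 handling do to a LibraryThing-style value (the sample string is illustrative, not taken from a real export):

    import re
    from typing import Optional

    def _remove_brackets(value: Optional[str]) -> Optional[str]:
        return re.sub(r"\[|\]", "", value) if value else None

    raw = "[0441172717, 9780441172719]"              # e.g. ISBN-10 followed by ISBN-13
    cleaned = _remove_brackets(raw)                  # '0441172717, 9780441172719'
    isbn_13 = cleaned.split(", ") if cleaned else []
    print(isbn_13[1] if len(isbn_13) > 1 else None)  # -> 9780441172719
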
def get_shelf(self, normalized_row):
def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]:
if normalized_row["date_finished"]:
return Shelf.READ_FINISHED
if normalized_row["date_started"]:

View file

@@ -1,4 +1,6 @@
""" handle reading a csv from openlibrary"""
from typing import Any
from . import Importer
@@ -7,7 +9,7 @@ class OpenLibraryImporter(Importer):
service = "OpenLibrary"
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
self.row_mappings_guesses.append(("openlibrary_key", ["edition id"]))
self.row_mappings_guesses.append(("openlibrary_work_key", ["work id"]))
super().__init__(*args, **kwargs)
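
The two appended guesses mean OpenLibrary export headers are matched by the same lowercase-lookup logic as every other column; a tiny self-contained illustration (the header strings are assumptions about what an OpenLibrary CSV contains):

    extra_guesses = [("openlibrary_key", ["edition id"]), ("openlibrary_work_key", ["work id"])]
    headers = ["Edition Id", "Work Id", "Title"]
    for key, guesses in extra_guesses:
        match = next((h for h in headers if h.lower() in guesses), None)
        print(key, "->", match)
    # openlibrary_key -> Edition Id
    # openlibrary_work_key -> Work Id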

View file

@@ -54,10 +54,10 @@ ImportStatuses = [
class ImportJob(models.Model):
"""entry for a specific request for book data import"""
user = models.ForeignKey(User, on_delete=models.CASCADE)
user: User = models.ForeignKey(User, on_delete=models.CASCADE)
created_date = models.DateTimeField(default=timezone.now)
updated_date = models.DateTimeField(default=timezone.now)
include_reviews = models.BooleanField(default=True)
include_reviews: bool = models.BooleanField(default=True)
mappings = models.JSONField()
source = models.CharField(max_length=100)
privacy = models.CharField(max_length=255, default="public", choices=PrivacyLevels)
@@ -76,7 +76,7 @@ class ImportJob(models.Model):
self.save(update_fields=["task_id"])
def complete_job(self):
def complete_job(self) -> None:
"""Report that the job has completed"""
self.status = "complete"
self.complete = True
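
The field annotations above are plain type hints rather than django-stubs field types, so mypy reads the attributes as ordinary values at call sites. A hedged sketch of what that looks like from calling code (the describe function is made up, and running it assumes a configured BookWyrm/Django environment):

    from bookwyrm.models import ImportJob

    def describe(job: ImportJob) -> str:
        # mypy now sees include_reviews as bool and user as User
        reviews = "with" if job.include_reviews else "without"
        return f"import for {job.user} {reviews} reviews"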

View file

@@ -13,6 +13,9 @@ implicit_reexport = True
[mypy-bookwyrm.connectors.*]
ignore_errors = False
[mypy-bookwyrm.importers.*]
ignore_errors = False
[mypy-celerywyrm.*]
ignore_errors = False
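
With ignore_errors = False for bookwyrm.importers, mypy now type-checks the package against the new annotations. An illustrative, made-up snippet of the kind of mistake it will flag, given that normalized rows are typed dict[str, Optional[str]]:

    from typing import Optional

    def pick_shelf(normalized_row: dict[str, Optional[str]]) -> str:
        # mypy: expression has type "Optional[str]", variable has type "str"
        shelf: str = normalized_row["shelf"]
        return shelf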