diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5edff18..ecc4edf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,12 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v3.2.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 21.12b0 + rev: 22.3.0 hooks: - id: black + exclude: ^dist/ diff --git a/flake.nix b/flake.nix index be90c25..adf86a7 100644 --- a/flake.nix +++ b/flake.nix @@ -45,6 +45,17 @@ poetry pkgs.pre-commit sqlite + + mypy + + # additional python interpreters for use with tox + #pkgs.python37 + #pkgs.python37Packages.virtualenv + pkgs.python38 + pkgs.python38Packages.virtualenv + pkgs.python39 + pkgs.python39Packages.virtualenv + tox ]; }); })); diff --git a/pyproject.toml b/pyproject.toml index 4f661f2..7042148 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Composable tracking of scientific data provenance" authors = ["Jacob Hinkle "] [tool.poetry.dependencies] -python = "^3.7" +python = "^3.8" click = "^8.1.3" colorama = "^0.4.5" loguru = "^0.6.0" @@ -24,3 +24,23 @@ nancy = "nancy.cli:main" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.tox] +legacy_tox_ini = """ +[tox] +envlist = py38,py39,py310,mypy +isolated_build = true + +[testenv] +deps = + pytest + pytest-cov + coverage +commands = + pytest --cov src/nancy + +[testenv:mypy] +deps = mypy +commands = + mypy --strict -p nancy +""" diff --git a/src/nancy/cli/__init__.py b/src/nancy/cli/__init__.py index 8954882..0673616 100644 --- a/src/nancy/cli/__init__.py +++ b/src/nancy/cli/__init__.py @@ -7,10 +7,12 @@ from ..version import __version__ from . import diff from . import record +from typing import Optional + # from https://click.palletsprojects.com/en/5.x/advanced/ class AliasedGroup(click.Group): - def get_command(self, ctx, cmd_name): + def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: rv = click.Group.get_command(self, ctx, cmd_name) if rv is not None: return rv @@ -19,11 +21,11 @@ class AliasedGroup(click.Group): return None elif len(matches) == 1: return click.Group.get_command(self, ctx, matches[0]) - ctx.fail("Too many matches: %s" % ", ".join(sorted(matches))) + return ctx.fail("Too many matches: %s" % ", ".join(sorted(matches))) @click.command() -def version(): +def version() -> None: """Print version information.""" print(f"nancy v{__version__}") @@ -41,7 +43,7 @@ def version(): default="SUCCESS", help="If given, print all output including debugging info.", ) -def main(log_level): +def main(log_level: str) -> None: import sys logger.remove() diff --git a/src/nancy/cli/common.py b/src/nancy/cli/common.py index 4a64414..4da0a3f 100644 --- a/src/nancy/cli/common.py +++ b/src/nancy/cli/common.py @@ -1,4 +1,7 @@ -def confirm(question, default_no=False): +def confirm( + question: str, + default_no: bool = False, +) -> bool: """Ask a question and wait for a Y/N response.""" choices = " [y/N]: " if default_no else " [Y/n]: " diff --git a/src/nancy/cli/diff.py b/src/nancy/cli/diff.py index 984d562..4afcfdf 100644 --- a/src/nancy/cli/diff.py +++ b/src/nancy/cli/diff.py @@ -10,8 +10,12 @@ import warnings def print_diff( - ABdiff: fs.FSDiff, indent=2, indent_level=0, use_color=True, show_hashes=False -): + ABdiff: fs.FSDiff, + indent: int = 2, + indent_level: int = 0, + use_color: bool = True, + show_hashes: bool = False, +) -> None: """Pretty print an FSDiff object""" if use_color: try: @@ -33,7 +37,7 @@ def print_diff( reset = Style.RESET_ALL if use_color else "" hashcolor = Fore.MAGENTA if use_color else "" - def _print_row(tag, entry, level): + def _print_row(tag: str, entry: fs.FSEntry, level: int) -> None: relpath = entry.relpath # Format relpath using filetype-based colors @@ -95,7 +99,7 @@ def print_diff( "store is initialized there.", ) @logger.catch -def status(show_hashes, no_color, store): +def status(show_hashes: bool, no_color: bool, store: str) -> None: """Detect and describe changes to PATH PATH is a path to a file or directory inside an existing nancy store diff --git a/src/nancy/cli/record.py b/src/nancy/cli/record.py index a4a4bd5..b333e05 100644 --- a/src/nancy/cli/record.py +++ b/src/nancy/cli/record.py @@ -8,17 +8,18 @@ from .diff import print_diff import os import sys +from typing import Any, Optional, Union @logger.catch def record( - message, - store_path=None, - show_diff=True, - show_hashes=False, - use_color=True, - skip_confirm=False, -): + message: str, + store_path: Optional[Union[str, "os.PathLike[Any]"]] = None, + show_diff: bool = True, + show_hashes: bool = False, + use_color: bool = True, + skip_confirm: bool = False, +) -> None: """Unwrapped record command""" if store_path is None: diff --git a/src/nancy/db.py b/src/nancy/db.py index d48c65c..690d376 100644 --- a/src/nancy/db.py +++ b/src/nancy/db.py @@ -1,5 +1,7 @@ import importlib.resources +import os import sqlite3 +from typing import Any, Union import warnings @@ -24,7 +26,7 @@ if ( ) -def init_schema(cur): +def init_schema(cur: sqlite3.Cursor) -> None: """Initialize a database following the current schema.""" schema = importlib.resources.read_text( "nancy.schema", @@ -33,7 +35,7 @@ def init_schema(cur): cur.executescript(schema) -def connect(path): +def connect(path: "os.PathLike[Any]") -> sqlite3.Connection: conn = sqlite3.connect(path) conn.cursor().execute("PRAGMA foreign_keys = ON;") return conn diff --git a/src/nancy/environment.py b/src/nancy/environment.py index 429c9d9..5035299 100644 --- a/src/nancy/environment.py +++ b/src/nancy/environment.py @@ -1,14 +1,19 @@ from . import user -from typing import NamedTuple import json import os import platform +import sqlite3 import sys +from typing import NamedTuple, Optional, TypeVar, Type + + +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_EnvironmentT = TypeVar("_EnvironmentT", bound="Environment") class Environment(NamedTuple): - id: int + id: Optional[int] envvars_json: str python_implementation: str python_strversion: str @@ -16,14 +21,18 @@ class Environment(NamedTuple): user: user.User @classmethod - def find_or_insert(cls, cur, env=None): + def find_or_insert( + cls: Type[_EnvironmentT], + cur: sqlite3.Cursor, + env: Optional[_EnvironmentT] = None, + ) -> _EnvironmentT: """Given a DB cursor, find or create row in environment table and fill""" if env is None: env = cls.detect() u = user.User.find_or_insert(cur) - env = env._replace(user=u.id) + env = env._replace(user=u) # insert or ignore, handle each case to set id cur.execute( @@ -40,7 +49,13 @@ class Environment(NamedTuple): user = ? LIMIT 1 """, - env[1:], + ( + env.envvars_json, + env.python_implementation, + env.python_strversion, + env.python_hexversion, + env.user.id, + ), ) res = cur.fetchone() if res is None: @@ -48,7 +63,14 @@ class Environment(NamedTuple): """ INSERT INTO environment VALUES (?,?,?,?,?,?); """, - env, + ( + env.id, + env.envvars_json, + env.python_implementation, + env.python_strversion, + env.python_hexversion, + env.user.id, + ), ) id = cur.lastrowid cur.connection.commit() @@ -58,7 +80,7 @@ class Environment(NamedTuple): return env._replace(id=id) @classmethod - def detect(cls): + def detect(cls: Type[_EnvironmentT]) -> _EnvironmentT: """Detect values for environment independent of the database. Note that the user entry will not have a valid id. @@ -71,5 +93,5 @@ class Environment(NamedTuple): platform.python_implementation(), sys.version, sys.hexversion, - u.id, + u, ) diff --git a/src/nancy/fs.py b/src/nancy/fs.py index 6e22c5c..e55771f 100644 --- a/src/nancy/fs.py +++ b/src/nancy/fs.py @@ -4,15 +4,21 @@ from loguru import logger from dataclasses import dataclass from datetime import datetime +from enum import Enum import hashlib import operator import os +from pathlib import Path +import sqlite3 import stat -from typing import List +from typing import Any, AnyStr, List, Optional, Tuple, TypeVar, Type, Union import warnings -def remove_write_perms(path): +PathStr = Union[str, Path, "os.PathLike[str]"] + + +def remove_write_perms(path: PathStr) -> Optional[str]: """Remove write permissions for all users while preserving other perms""" if not os.path.islink(path): s = os.stat(path) @@ -49,15 +55,18 @@ def remove_write_perms(path): return orig_perm_string -def make_readonly_recursive(path, excluded=[]): +def make_readonly_recursive( + path: PathStr, + excluded: List[PathStr] = [], +) -> None: """Recursively "freeze" a directory by setting all files and directories read-only""" # traversing bottom-up makes it easier to freeze perms on directories - for root, dirs, files in os.walk(path, topdown=False): + for root, dirs, files in os.walk(str(path), topdown=False): for f in files: p = os.path.join(root, f) if p in excluded: continue - remove_write_perms(os.path.join(path, p)) + remove_write_perms(os.path.join(Path(path), p)) for d in dirs: p = os.path.join(root, d) @@ -66,58 +75,98 @@ def make_readonly_recursive(path, excluded=[]): remove_write_perms(os.path.join(path, p)) +class FileType(Enum): + """One of 'LNK', 'DIR', 'REG', etc. + + names are compatible with those used in the `stat` module. + + See :meth:'store.FSEntry.from_path' for details. + """ + + BLK = "BLK" + CHR = "CHR" + DIR = "DIR" + DOOR = "DOOR" + FIFO = "FIFO" + LNK = "LNK" + OTHER = "OTHER" + REG = "REG" + PORT = "PORT" + SOCK = "SOCK" + WHT = "WHT" + + +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_FSEntryVersionT = TypeVar("_FSEntryVersionT", bound="FSEntryVersion") + + @dataclass class FSEntryVersion: """A version of a file or directory.""" - id: int + id: Optional[int] filedir: "FSEntry" recorded_time: datetime # When was this version recorded? - filetype: str # One of 'LNK', 'DIR', 'REG', etc. See store.FSEntry.from_path for details + filetype: FileType deleted: bool # set True when recording a deleted file unfrozen_perms: str # stat.filemode(os.stat(path).st_mode): '-rw-rw-r--' symlink_target: str # if this is a symlink, this is the (read but not fully # resolved) target. I.e. this is the "content" of the symlink. - sha256: str - source_task_id: int = None + sha256: bytes + source_task_id: Optional[int] = None @classmethod - def from_row(cls, row, filedir=None): - if filedir is None: - filedir = row[1] + def from_row( + cls: Type[_FSEntryVersionT], + row: Tuple[int, int, float, str, bool, str, str, str, Optional[int]], + filedir: "FSEntry", + ) -> _FSEntryVersionT: return cls( - row[0], - filedir, - datetime.fromtimestamp(row[2]), - *row[3:-2], - bytes.fromhex(row[-2]), - row[-1], + row[0], # id + filedir, # filedir + datetime.fromtimestamp(row[2]), # recorded_time + FileType(row[3]), # filetype + row[4], # deleted + row[5], # unfrozen_perms + row[6], # symlink_target + bytes.fromhex(row[7]), # sha256 + row[8], # source_task_id ) +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_FSEntryT = TypeVar("_FSEntryT", bound="FSEntry") + + @dataclass class FSEntry: """A hashed file or directory.""" - id: int # defaults to None + id: Optional[int] # defaults to None filename: str # with parent directory stripped. None if this is the root relpath: str # relative to some root directory - parent: "FSEntry" # upward link + parent: Optional["FSEntry"] # upward link # children for dirs only: non-recursive; files/dirs at this level only children: List["FSEntry"] - filetype: str # regular, symlink, special (block, char, pipe, or socket) - deleted: bool - versions: List[FSEntryVersion] = None + filetype: Optional[ + FileType + ] # regular, symlink, special (block, char, pipe, or socket) + deleted: Optional[bool] + versions: Optional[List[FSEntryVersion]] = None # these will be filled from the version list automatically - unfrozen_perms: str = None # stat.filemode(os.stat(path).st_mode): '-rw-rw-r--' - symlink_target: str = None # if this is a symlink, this is the (read but not fully + unfrozen_perms: Optional[ + str + ] = None # stat.filemode(os.stat(path).st_mode): '-rw-rw-r--' + symlink_target: Optional[ + str + ] = None # if this is a symlink, this is the (read but not fully # resolved) target. I.e. this is the "content" of the symlink. - sha256: str = None - latest_version: FSEntryVersion = None + sha256: Optional[bytes] = None + latest_version: Optional[FSEntryVersion] = None - def __post_init__(self): + def __post_init__(self) -> None: if self.versions is not None and len(self.versions) > 0: self.latest_version = self.versions[-1] self.unfrozen_perms = self.latest_version.unfrozen_perms @@ -126,8 +175,13 @@ class FSEntry: @classmethod def from_path( - cls, root, relpath=None, exclude=["nancy.db"], parent=None, direntry=None - ): + cls: Type[_FSEntryT], + root: PathStr, + relpath: Optional[str] = None, + exclude: List[str] = ["nancy.db"], + parent: Optional[_FSEntryT] = None, + direntry: Optional["os.DirEntry[str]"] = None, + ) -> _FSEntryT: """ Scan a path to instantiate (recursive). @@ -150,14 +204,19 @@ class FSEntry: s = filestat.st_mode children = [] - symlink_target = None + symlink_target: Optional[Union[str, bytes]] = None if os.path.islink(path): # Check links first, since it is not exclusive with dir or file checks - filetype = "LNK" + filetype = FileType.LNK + # readlink returns a str or bytes symlink_target = os.readlink(path) - m.update(bytes(symlink_target, "utf-8")) + assert symlink_target is not None + if isinstance(symlink_target, str): + symlink_target = bytes(symlink_target, "utf-8") + assert isinstance(symlink_target, bytes) + m.update(symlink_target) elif stat.S_ISDIR(s): - filetype = "DIR" + filetype = FileType.DIR # this prevents a directory's hash from colliding with a file hash # in cases where it only holds a single file @@ -189,27 +248,29 @@ class FSEntry: # changes without modifying the hashes of individual files, # which remain content-based for compatibility with # other tools - m.update(bytes(c.unfrozen_perms, "utf-8")) - m.update(c.sha256) + if c.unfrozen_perms is not None: + m.update(bytes(c.unfrozen_perms, "utf-8")) + if c.sha256 is not None: + m.update(c.sha256) elif stat.S_ISREG(s): - filetype = "REG" + filetype = FileType.REG m.update(open(path, "rb").read()) elif stat.S_ISSOCK(s): - filetype = "SOCK" + filetype = FileType.SOCK elif stat.S_ISCHR(s): - filetype = "CHR" + filetype = FileType.CHR elif stat.S_ISBLK(s): - filetype = "BLK" + filetype = FileType.BLK elif stat.S_ISFIFO(s): - filetype = "FIFO" + filetype = FileType.FIFO elif stat.S_ISDOOR(s): - filetype = "DOOR" + filetype = FileType.DOOR elif stat.S_ISPORT(s): - filetype = "PORT" + filetype = FileType.PORT elif stat.S_ISWHT(s): - filetype = "WHT" + filetype = FileType.WHT else: - filetype = "OTHER" + filetype = FileType.OTHER sha256 = m.digest() @@ -221,20 +282,22 @@ class FSEntry: children=children, filetype=None, deleted=None, - versions=[ - FSEntryVersion( - id=None, - filedir=None, - recorded_time=datetime.now().timestamp(), - filetype=filetype, - deleted=False, - unfrozen_perms=stat.filemode(s), - symlink_target=symlink_target, - sha256=sha256, - source_task_id=None, - ) - ], + versions=[], ) + # Update versions after the fact to get self-reference + ob.versions = [ + FSEntryVersion( + id=None, + filedir=ob, + recorded_time=datetime.now(), + filetype=filetype, + deleted=False, + unfrozen_perms=stat.filemode(s), + symlink_target=str(symlink_target), + sha256=sha256, + source_task_id=None, + ) + ] # now change children's parents to point to this object for v in ob.versions: v.filedir = ob @@ -250,7 +313,7 @@ class FSEntry: return ob @classmethod - def empty_root(cls): + def empty_root(cls: Type[_FSEntryT]) -> _FSEntryT: """Just a standardized value indicating an empty root directory""" return cls( id=None, @@ -258,15 +321,23 @@ class FSEntry: relpath=".", parent=None, children=[], - filetype="DIR", + filetype=FileType.DIR, unfrozen_perms="----------", sha256=hashlib.sha256().digest(), deleted=False, ) + # @logger.catch @classmethod - @logger.catch - def from_db_index(cls, cursor, root_id=None, root_row=None, parent=None): + def from_db_index( + cls: Type[_FSEntryT], + cursor: sqlite3.Cursor, + root_id: Optional[int] = None, + root_row: Optional[ + Tuple[int, str, bool] + ] = None, # TODO: Type the expected sqlite rows + parent: Optional[_FSEntryT] = None, + ) -> _FSEntryT: """Given id of an entry in filedir, recursively fill this object""" if root_row is None: assert root_id is not None @@ -315,21 +386,21 @@ class FSEntry: ob.unfrozen_perms = last_ver.unfrozen_perms ob.symlink_target = last_ver.symlink_target ob.sha256 = last_ver.sha256 - ob.last_version = last_ver + ob.latest_version = last_ver return ob - def flatten_tree(self, level=0): + def flatten_tree(self, level: int = 0) -> List[Tuple[int, "FSEntry"]]: """Return list of all entries, with level, in pairs""" pairs = [(level, self)] for c in sorted(self.children, key=lambda e: e.filename): pairs.extend(c.flatten_tree(level=level + 1)) return pairs - def __str__(self): + def __str__(self) -> str: return self.to_string(level=0) - def to_string(self, level=0): + def to_string(self, level: int = 0) -> str: if len(self.children) == 0: childsec = "[]" else: @@ -350,25 +421,33 @@ filetype: {self.filetype} deleted: {self.deleted} unfrozen_perms: {self.unfrozen_perms} symlink_target: {self.symlink_target} -sha256: {self.sha256.hex()} +sha256: {'None' if self.sha256 is None else self.sha256.hex()} children: {childsec} """.splitlines() ) -def sort_diffs_filename(diffs): +def sort_diffs_filename(diffs: List["FSDiff"]) -> List["FSDiff"]: name_ent = {e.filename(): e for e in diffs} return [name_ent[n] for n in sorted(name_ent.keys())] +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_FSDiffT = TypeVar("_FSDiffT", bound="FSDiff") + + @dataclass class FSDiff: - A: FSEntry # record the comparisons - B: FSEntry # a missing entry indicates new or deleted - modified_children: "FSDiff" + A: Optional[FSEntry] # record the comparisons + B: Optional[FSEntry] # a missing entry indicates new or deleted + modified_children: "List[FSDiff]" + + def __post_init__(self) -> None: + if self.A is None and self.B is None: + raise TypeError("A and B cannot both be None") @staticmethod - def compare(A, B): + def compare(A: FSEntry, B: FSEntry) -> bool: return ( A.sha256 == B.sha256 and A.unfrozen_perms == B.unfrozen_perms @@ -376,14 +455,25 @@ class FSDiff: and A.deleted == B.deleted ) - def filename(self): - return self.B.filename if self.A is None else self.A.filename + def filename(self) -> str: + if self.A is not None: + return self.A.filename + else: + assert self.B is not None + return self.B.filename - def filetype(self): - return self.B.filetype if self.A is None else self.A.filetype + def filetype(self) -> Optional[FileType]: + if self.A is not None: + return self.A.filetype + elif self.B is not None: + return self.B.filetype + else: + return None @classmethod - def compute(cls, A, B): + def compute( + cls: Type[_FSDiffT], A: Optional[FSEntry], B: Optional[FSEntry] + ) -> _FSDiffT: """Given two hashed directories, recursively compute difference. This assumes the hashes are consistent, so that directories with @@ -395,14 +485,17 @@ class FSDiff: new (Directory): overlay with new entries from other """ if A is None: # new entry - return cls( - A, - B, - [ - cls.compute(None, c) - for c in sorted(B.children, key=lambda e: e.filename) - ], - ) + if B is None: + raise ValueError("Cannot compute diff with both A and B missing") + else: + return cls( + A, + B, + [ + cls.compute(None, c) + for c in sorted(B.children, key=lambda e: e.filename) + ], + ) if B is None: # deleted entry return cls( A, @@ -435,7 +528,7 @@ class FSDiff: return cls(A, B, modified_children) - def flatten_tree(self, level=0): + def flatten_tree(self, level: int = 0) -> List[Tuple[int, "FSDiff"]]: """Return list of all entries, with level, in pairs""" pairs = [(level, self)] for c in sorted(self.modified_children, key=lambda d: d.filename()): diff --git a/src/nancy/machine.py b/src/nancy/machine.py index f7ede71..bef26d5 100644 --- a/src/nancy/machine.py +++ b/src/nancy/machine.py @@ -1,12 +1,17 @@ -from typing import NamedTuple +from typing import NamedTuple, Optional, Type, TypeVar import json import platform +import sqlite3 import time +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_MachineT = TypeVar("_MachineT", bound="Machine") + + class Machine(NamedTuple): - id: int - machine_id: str + id: Optional[int] + machine_id: Optional[str] hostname: str processor: str system: str @@ -18,7 +23,9 @@ class Machine(NamedTuple): mac_ver: str @classmethod - def find_or_insert(cls, cur, machine=None): + def find_or_insert( + cls: Type[_MachineT], cur: sqlite3.Cursor, machine: Optional[_MachineT] = None + ) -> _MachineT: """Given a DB cursor, find or create row in machine table and fill""" if machine is None: machine = cls.detect() @@ -61,7 +68,7 @@ class Machine(NamedTuple): return machine._replace(id=id) @classmethod - def detect(cls): + def detect(cls: Type[_MachineT]) -> _MachineT: """Formats machine-specific information into a MachineInfo object. Note that 'MachineInfo' objects are properly formatted to be inserted into diff --git a/src/nancy/py.typed b/src/nancy/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/nancy/store.py b/src/nancy/store.py index 09dfd10..380f2dd 100644 --- a/src/nancy/store.py +++ b/src/nancy/store.py @@ -4,28 +4,30 @@ from loguru import logger from . import db, environment, fs +from dataclasses import dataclass import datetime import os from pathlib import Path import sqlite3 +from typing import Any, Optional, TypeVar, Type, Union +import warnings +@dataclass class Program: - def __init__(self, store, name, message): - self.store = store - self.name = name - self.message = message + store: "Store" + name: str + message: str - self._evaluated = False + id: Optional[int] = None + start_time: Optional[datetime.datetime] = None + evaluated: bool = False - def set_start_time(self, t): - self.start_time = t - - @logger.catch - def __enter__(self): - if self._evaluated: + def __enter__(self) -> "Program": + if self.evaluated: raise RuntimeError("Cannot re-enter a Program context") + assert self.store.conn is not None cur = self.store.conn.cursor() env = environment.Environment.find_or_insert(cur) @@ -47,23 +49,32 @@ class Program: ) self.id = cur.lastrowid - self.set_start_time(datetime.datetime.now()) + self.start_time = datetime.datetime.now() return self - def new_task(self, name, py_function_id=None): + def new_task(self, name: str, py_function_id: Optional[int] = None) -> int: """Create a new task and return its id""" + assert self.store.conn is not None cur = self.store.conn.cursor() cur.execute( "INSERT INTO task VALUES (?, ?, ?)", (None, self.id, py_function_id), ) - return cur.lastrowid + taskid = cur.lastrowid + assert isinstance(taskid, int) + return taskid - def __exit__(self, exc_type, exc_value, exc_traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + traceback: Optional[Any], + ) -> None: end_time = datetime.datetime.now() # record start and end times in store + assert self.store.conn is not None cur = self.store.conn.cursor() cur.execute( """ @@ -79,6 +90,7 @@ class Program: ) cur.connection.commit() self._evaluated = True # prevent re-running + assert self.start_time is not None elapsed = end_time - self.start_time logger.success( f"Program [{self.id}] {self.name} " @@ -86,10 +98,22 @@ class Program: ) +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_StoreT = TypeVar("_StoreT", bound="Store") + + class Store: """Describes a data directory, holds active connection to nancy.db""" - def __init__(self, directory=None, conn=None): + path: Optional[fs.PathStr] + db_path: fs.PathStr + conn: Optional[sqlite3.Connection] + + def __init__( + self, + directory: Optional[fs.PathStr] = None, + conn: Optional[sqlite3.Connection] = None, + ): """ Arguments: directory (str): Location of existing store directory. If omitted @@ -107,21 +131,24 @@ class Store: else: self.conn = conn - def copy(self, store_path): + def copy(self: _StoreT, store_path: fs.PathStr) -> _StoreT: """Copy this store to a new store path""" + assert self.conn is not None dst_db_path = os.path.join(store_path, "nancy.db") dst_conn = sqlite3.connect(dst_db_path) self.conn.backup(dst_conn) dst_conn.close return self.__class__(store_path) - def connect(self): + def connect(self) -> sqlite3.Connection: self.conn = sqlite3.connect(self.db_path) self.conn.cursor().execute("PRAGMA foreign_keys = ON;") return self.conn @classmethod - def init(cls, directory=None, message=None): + def init( + cls: Type[_StoreT], message: str, directory: Optional[fs.PathStr] = None + ) -> _StoreT: start_time = datetime.datetime.now() if directory is None: # initialize an in-memory store db_path = ":memory:" @@ -142,31 +169,37 @@ class Store: with new_store.program("INIT", message) as p: # set the timing to the actual times it took to initialize the db - p.set_start_time(start_time) + p.start_time = start_time return new_store - def make_readonly(self): + def make_readonly(self) -> None: """Make store directory read-only (except for nancy.db) and return file list""" - fs.make_readonly_recursive(self.path, excluded="./nancy.db") + fs.make_readonly_recursive(str(self.path), excluded=["./nancy.db"]) - def filedir_root_index(self, cur=None): + def filedir_root_index(self, cur: Optional[sqlite3.Cursor] = None) -> Optional[int]: """Get the database id for the table entry in this store having name '.'""" if cur is None: + assert self.conn is not None cur = self.conn.cursor() cur.execute("SELECT id FROM filedir WHERE store=1 AND parent is NULL") (root_id,) = cur.fetchone() + assert isinstance(root_id, int) return root_id - def path_to_fsentry(self, path): + def path_to_fsentry(self, path: fs.PathStr) -> Optional[fs.FSEntry]: """Find a path in the filedir database and return it as an fsentry. If the path is not found in the store, None is returned. """ + assert self.conn is not None cur = self.conn.cursor() # get relative path to resolved path - rel = os.path.relpath(os.path.realpath(path), start=os.path.realpath(self.path)) + rel = os.path.relpath( + os.path.realpath(str(path)), + start=os.path.realpath(str(self.path)), + ) # rel tells us how to descend recurively to find the filedir for path fd_id = self.filedir_root_index(cur) @@ -186,34 +219,41 @@ class Store: return None fd_id, filetype = row - if filetype != "DIR": - return fd_id + return fs.FSEntry.from_db_index(cur, root_id=fd_id) - def fs_entries(self, shallow=False): + def fs_entries(self, shallow: bool = False) -> Optional[fs.FSEntry]: """Return recursive structure containing FSEntry objects from db""" root_id = self.filedir_root_index() if root_id is None: return None else: + assert self.conn is not None return fs.FSEntry.from_db_index(self.conn.cursor(), root_id=root_id) - def program(self, name, message=None): - return Program(self, name, message) + def program(self, name: str, message: str) -> Program: + p = Program(self, name, message) + return p - def diff(self): + def diff(self) -> fs.FSDiff: """ Find changes to files and dirs compared to their recorded versions """ # get info about files currently at the given locations - current = fs.FSEntry.from_path(self.path) + current = fs.FSEntry.from_path(str(self.path)) # then find a listing covering all the expected paths recorded = self.fs_entries(shallow=True) return fs.FSDiff.compute(recorded, current) - def _record_file_version(self, cur, ob, filedir_id, source_task=None): + def _record_file_version( + self, + cur: sqlite3.Cursor, + ob: fs.FSEntry, + filedir_id: int, + source_task: Optional[int] = None, + ) -> int: cur.execute( "INSERT INTO filedir_version VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( @@ -224,13 +264,20 @@ class Store: False, ob.unfrozen_perms, ob.symlink_target, - ob.sha256.hex(), + None if ob.sha256 is None else ob.sha256.hex(), source_task, ), ) + assert isinstance(cur.lastrowid, int) return cur.lastrowid - def _record_new_file_recursive(self, ob, cur, parent_id, source_task): + def _record_new_file_recursive( + self, + ob: fs.FSEntry, + cur: sqlite3.Cursor, + parent_id: Optional[int], + source_task: Optional[int], + ) -> None: # Find entries with this name and parent cur.execute( "SELECT id FROM filedir WHERE store = 1 AND name = ? AND parent = ? LIMIT 1", @@ -252,6 +299,7 @@ class Store: thisid = cur.lastrowid else: (thisid,) = res[0] + assert isinstance(thisid, int) self._record_file_version(cur, ob, thisid, source_task=source_task) @@ -259,25 +307,41 @@ class Store: for c in ob.children: self._record_new_file_recursive(c, cur, thisid, source_task) - def _record_recursive(self, diff, cur, parent_id=None, source_task=None): + def _record_recursive( + self, + diff: fs.FSDiff, + cur: sqlite3.Cursor, + parent_id: Optional[int] = None, + source_task: Optional[int] = None, + ) -> None: """Record this level of a diff.""" if diff.A is None: + assert diff.B is not None self._record_new_file_recursive( diff.B, cur, parent_id, source_task=source_task ) elif diff.B is None: - self._record_deleted_file_recursive(diff.B, cur, parent_id) + # self._record_deleted_file_recursive(diff.B, cur, parent_id) + pass else: # possibly modified, record new version then recurse into children self._record_new_file_recursive( diff.B, cur, parent_id, source_task=source_task ) + assert diff.A.id is not None self._record_file_version(cur, diff.B, diff.A.id, source_task=source_task) # descend into children - def record(self, diff, parent_id=None, message=None, cur=None): + def record( + self, + diff: fs.FSDiff, + message: str, + parent_id: Optional[int] = None, + cur: Optional[sqlite3.Cursor] = None, + ) -> None: if cur is None: + assert self.conn is not None cur = self.conn.cursor() with self.program("RECORD", message) as p: @@ -288,41 +352,20 @@ class Store: # recording new versions of each, when necessary self._record_recursive(diff, cur, source_task=task_id) - # @contextmanager - def run( - self, - name=None, - message=None, - ): - """ - Create a context manager that encapsulates a procedure that can save files. - - Note that this does NOT spawn any new OS processes or threads. - - Example: - - s = nancy.store.init(target_directory) - with s.run("sum_dataframe") as f: - x = PandasDataframe() - y = Sum(x) - f.save('stats/xsum.csv', y) - """ - pass - class StoreFile: """Describes a file that is recorded in the store.""" - def __init__(self, store, rel_path): + def __init__(self, store: Store, rel_path: fs.PathStr): self.store = store self.rel_path = rel_path - def save(self): + def save(self) -> None: # call the appropriate save method pass -def find_store(path): +def find_store(path: Union[str, "os.PathLike[str]"]) -> Optional[str]: """ Given a path, find a store dir containing nancy.db at any level above it. """ diff --git a/src/nancy/user.py b/src/nancy/user.py index cdb026e..0ad9265 100644 --- a/src/nancy/user.py +++ b/src/nancy/user.py @@ -3,25 +3,32 @@ from . import machine import getpass import os import pwd -from typing import NamedTuple +import sqlite3 +from typing import NamedTuple, Optional, Type, TypeVar + + +# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance +_UserT = TypeVar("_UserT", bound="User") class User(NamedTuple): - id: int # if not None, this is `id` in the `machine` table + id: Optional[int] # if not None, this is `id` in the `machine` table username: str userid: int fullname: str machine: machine.Machine @classmethod - def find_or_insert(cls, cur, user=None): + def find_or_insert( + cls: Type[_UserT], cur: sqlite3.Cursor, user: Optional[_UserT] = None + ) -> _UserT: """Given a DB cursor, find or create row in user table and fill""" if user is None: user = cls.detect() m = machine.Machine.find_or_insert(cur) - user = user._replace(machine=m.id) + user = user._replace(machine=m) # insert or ignore, handle each case to set id cur.execute( @@ -37,7 +44,12 @@ class User(NamedTuple): machine = ? LIMIT 1 """, - user[1:], + ( + user.username, + user.userid, + user.fullname, + user.machine.id, + ), ) res = cur.fetchone() if res is None: @@ -45,7 +57,13 @@ class User(NamedTuple): """ INSERT INTO user VALUES (?,?,?,?,?); """, - user, + ( + user.id, + user.username, + user.userid, + user.fullname, + user.machine.id, + ), ) id = cur.lastrowid cur.connection.commit() @@ -55,7 +73,7 @@ class User(NamedTuple): return user._replace(id=id) @classmethod - def detect(cls): + def detect(cls: Type[_UserT]) -> _UserT: """Detect values for user independent of the database. Note that the machine entry will not have a valid id. @@ -70,5 +88,5 @@ class User(NamedTuple): getpass.getuser(), os.getuid(), fullname, - m.id, + m, ) diff --git a/tests/test_store.py b/tests/test_store.py index 497e3b0..a3126ec 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -32,7 +32,7 @@ def test_record_untracked_dir(filled_dir): def store(): from nancy import store - s = store.Store.init() + s = store.Store.init(message="test init") yield s