Add tox with mypy, fix typehints. Upgrade black pre-commit hook

This commit is contained in:
Jacob Hinkle 2022-10-04 20:56:38 -04:00
parent 25ab58bcda
commit c4648ec042
15 changed files with 418 additions and 191 deletions

View File

@ -1,11 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v3.2.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 21.12b0
rev: 22.3.0
hooks:
- id: black
exclude: ^dist/

View File

@ -45,6 +45,17 @@
poetry
pkgs.pre-commit
sqlite
mypy
# additional python interpreters for use with tox
#pkgs.python37
#pkgs.python37Packages.virtualenv
pkgs.python38
pkgs.python38Packages.virtualenv
pkgs.python39
pkgs.python39Packages.virtualenv
tox
];
});
}));

View File

@ -5,7 +5,7 @@ description = "Composable tracking of scientific data provenance"
authors = ["Jacob Hinkle <jacob.hinkle@jhink.org>"]
[tool.poetry.dependencies]
python = "^3.7"
python = "^3.8"
click = "^8.1.3"
colorama = "^0.4.5"
loguru = "^0.6.0"
@ -24,3 +24,23 @@ nancy = "nancy.cli:main"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.tox]
legacy_tox_ini = """
[tox]
envlist = py38,py39,py310,mypy
isolated_build = true
[testenv]
deps =
pytest
pytest-cov
coverage
commands =
pytest --cov src/nancy
[testenv:mypy]
deps = mypy
commands =
mypy --strict -p nancy
"""

View File

@ -7,10 +7,12 @@ from ..version import __version__
from . import diff
from . import record
from typing import Optional
# from https://click.palletsprojects.com/en/5.x/advanced/
class AliasedGroup(click.Group):
def get_command(self, ctx, cmd_name):
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
rv = click.Group.get_command(self, ctx, cmd_name)
if rv is not None:
return rv
@ -19,11 +21,11 @@ class AliasedGroup(click.Group):
return None
elif len(matches) == 1:
return click.Group.get_command(self, ctx, matches[0])
ctx.fail("Too many matches: %s" % ", ".join(sorted(matches)))
return ctx.fail("Too many matches: %s" % ", ".join(sorted(matches)))
@click.command()
def version():
def version() -> None:
"""Print version information."""
print(f"nancy v{__version__}")
@ -41,7 +43,7 @@ def version():
default="SUCCESS",
help="If given, print all output including debugging info.",
)
def main(log_level):
def main(log_level: str) -> None:
import sys
logger.remove()

View File

@ -1,4 +1,7 @@
def confirm(question, default_no=False):
def confirm(
question: str,
default_no: bool = False,
) -> bool:
"""Ask a question and wait for a Y/N response."""
choices = " [y/N]: " if default_no else " [Y/n]: "

View File

@ -10,8 +10,12 @@ import warnings
def print_diff(
ABdiff: fs.FSDiff, indent=2, indent_level=0, use_color=True, show_hashes=False
):
ABdiff: fs.FSDiff,
indent: int = 2,
indent_level: int = 0,
use_color: bool = True,
show_hashes: bool = False,
) -> None:
"""Pretty print an FSDiff object"""
if use_color:
try:
@ -33,7 +37,7 @@ def print_diff(
reset = Style.RESET_ALL if use_color else ""
hashcolor = Fore.MAGENTA if use_color else ""
def _print_row(tag, entry, level):
def _print_row(tag: str, entry: fs.FSEntry, level: int) -> None:
relpath = entry.relpath
# Format relpath using filetype-based colors
@ -95,7 +99,7 @@ def print_diff(
"store is initialized there.",
)
@logger.catch
def status(show_hashes, no_color, store):
def status(show_hashes: bool, no_color: bool, store: str) -> None:
"""Detect and describe changes to PATH
PATH is a path to a file or directory inside an existing nancy store

View File

@ -8,17 +8,18 @@ from .diff import print_diff
import os
import sys
from typing import Any, Optional, Union
@logger.catch
def record(
message,
store_path=None,
show_diff=True,
show_hashes=False,
use_color=True,
skip_confirm=False,
):
message: str,
store_path: Optional[Union[str, "os.PathLike[Any]"]] = None,
show_diff: bool = True,
show_hashes: bool = False,
use_color: bool = True,
skip_confirm: bool = False,
) -> None:
"""Unwrapped record command"""
if store_path is None:

View File

@ -1,5 +1,7 @@
import importlib.resources
import os
import sqlite3
from typing import Any, Union
import warnings
@ -24,7 +26,7 @@ if (
)
def init_schema(cur):
def init_schema(cur: sqlite3.Cursor) -> None:
"""Initialize a database following the current schema."""
schema = importlib.resources.read_text(
"nancy.schema",
@ -33,7 +35,7 @@ def init_schema(cur):
cur.executescript(schema)
def connect(path: Union[str, "os.PathLike[Any]"]) -> sqlite3.Connection:
    """Open the SQLite database at ``path`` with foreign keys switched on.

    Args:
        path: Location of the database, given either as a plain string or as
            a path-like object. It is passed to ``sqlite3.connect`` unchanged.

    Returns:
        The open connection, after ``PRAGMA foreign_keys = ON;`` has been
        executed on it.
    """
    conn = sqlite3.connect(path)
    # The pragma is issued on every new connection because it is a
    # per-connection setting in SQLite.
    conn.cursor().execute("PRAGMA foreign_keys = ON;")
    return conn

View File

@ -1,14 +1,19 @@
from . import user
from typing import NamedTuple
import json
import os
import platform
import sqlite3
import sys
from typing import NamedTuple, Optional, TypeVar, Type
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_EnvironmentT = TypeVar("_EnvironmentT", bound="Environment")
class Environment(NamedTuple):
id: int
id: Optional[int]
envvars_json: str
python_implementation: str
python_strversion: str
@ -16,14 +21,18 @@ class Environment(NamedTuple):
user: user.User
@classmethod
def find_or_insert(cls, cur, env=None):
def find_or_insert(
cls: Type[_EnvironmentT],
cur: sqlite3.Cursor,
env: Optional[_EnvironmentT] = None,
) -> _EnvironmentT:
"""Given a DB cursor, find or create row in environment table and fill"""
if env is None:
env = cls.detect()
u = user.User.find_or_insert(cur)
env = env._replace(user=u.id)
env = env._replace(user=u)
# insert or ignore, handle each case to set id
cur.execute(
@ -40,7 +49,13 @@ class Environment(NamedTuple):
user = ?
LIMIT 1
""",
env[1:],
(
env.envvars_json,
env.python_implementation,
env.python_strversion,
env.python_hexversion,
env.user.id,
),
)
res = cur.fetchone()
if res is None:
@ -48,7 +63,14 @@ class Environment(NamedTuple):
"""
INSERT INTO environment VALUES (?,?,?,?,?,?);
""",
env,
(
env.id,
env.envvars_json,
env.python_implementation,
env.python_strversion,
env.python_hexversion,
env.user.id,
),
)
id = cur.lastrowid
cur.connection.commit()
@ -58,7 +80,7 @@ class Environment(NamedTuple):
return env._replace(id=id)
@classmethod
def detect(cls):
def detect(cls: Type[_EnvironmentT]) -> _EnvironmentT:
"""Detect values for environment independent of the database.
Note that the user entry will not have a valid id.
@ -71,5 +93,5 @@ class Environment(NamedTuple):
platform.python_implementation(),
sys.version,
sys.hexversion,
u.id,
u,
)

View File

@ -4,15 +4,21 @@ from loguru import logger
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
import hashlib
import operator
import os
from pathlib import Path
import sqlite3
import stat
from typing import List
from typing import Any, AnyStr, List, Optional, Tuple, TypeVar, Type, Union
import warnings
def remove_write_perms(path):
PathStr = Union[str, Path, "os.PathLike[str]"]
def remove_write_perms(path: PathStr) -> Optional[str]:
"""Remove write permissions for all users while preserving other perms"""
if not os.path.islink(path):
s = os.stat(path)
@ -49,15 +55,18 @@ def remove_write_perms(path):
return orig_perm_string
def make_readonly_recursive(path, excluded=[]):
def make_readonly_recursive(
path: PathStr,
excluded: List[PathStr] = [],
) -> None:
"""Recursively "freeze" a directory by setting all files and directories read-only"""
# traversing bottom-up makes it easier to freeze perms on directories
for root, dirs, files in os.walk(path, topdown=False):
for root, dirs, files in os.walk(str(path), topdown=False):
for f in files:
p = os.path.join(root, f)
if p in excluded:
continue
remove_write_perms(os.path.join(path, p))
remove_write_perms(os.path.join(Path(path), p))
for d in dirs:
p = os.path.join(root, d)
@ -66,58 +75,98 @@ def make_readonly_recursive(path, excluded=[]):
remove_write_perms(os.path.join(path, p))
class FileType(Enum):
    """Kind of filesystem entry: one of 'LNK', 'DIR', 'REG', etc.

    Member names follow the suffixes of the ``stat.S_IS*`` predicates in the
    `stat` module (``S_ISDIR`` -> ``DIR``, ``S_ISREG`` -> ``REG``, ...).
    See ``FSEntry.from_path`` for how a path is classified.
    """

    BLK = "BLK"
    CHR = "CHR"
    DIR = "DIR"
    DOOR = "DOOR"
    FIFO = "FIFO"
    LNK = "LNK"  # symbolic link; from_path tests for links before anything else
    OTHER = "OTHER"  # fallback used by from_path when no other check matches
    REG = "REG"
    PORT = "PORT"
    SOCK = "SOCK"
    WHT = "WHT"
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_FSEntryVersionT = TypeVar("_FSEntryVersionT", bound="FSEntryVersion")
@dataclass
class FSEntryVersion:
    """One recorded version of a file or directory."""

    id: Optional[int]
    filedir: "FSEntry"
    recorded_time: datetime  # When was this version recorded?
    filetype: FileType
    deleted: bool  # set True when recording a deleted file
    unfrozen_perms: str  # stat.filemode(os.stat(path).st_mode): '-rw-rw-r--'
    symlink_target: str  # if this is a symlink, this is the (read but not fully
    # resolved) target. I.e. this is the "content" of the symlink.
    sha256: bytes
    source_task_id: Optional[int] = None

    @classmethod
    def from_row(
        cls: Type[_FSEntryVersionT],
        row: Tuple[int, int, float, str, bool, str, str, str, Optional[int]],
        filedir: "FSEntry",
    ) -> _FSEntryVersionT:
        """Build a version from a database row and the entry that owns it.

        Args:
            row: Nine-column row. Column 1 is skipped in favour of the
                ``filedir`` argument; column 2 is a POSIX timestamp and
                column 7 is the hex form of the digest.
            filedir: The entry this version belongs to.

        Returns:
            A new instance of ``cls`` populated from ``row``.
        """
        return cls(
            id=row[0],
            filedir=filedir,
            recorded_time=datetime.fromtimestamp(row[2]),
            filetype=FileType(row[3]),
            deleted=row[4],
            unfrozen_perms=row[5],
            symlink_target=row[6],
            sha256=bytes.fromhex(row[7]),
            source_task_id=row[8],
        )
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_FSEntryT = TypeVar("_FSEntryT", bound="FSEntry")
@dataclass
class FSEntry:
"""A hashed file or directory."""
id: int # defaults to None
id: Optional[int] # defaults to None
filename: str # with parent directory stripped. None if this is the root
relpath: str # relative to some root directory
parent: "FSEntry" # upward link
parent: Optional["FSEntry"] # upward link
# children for dirs only: non-recursive; files/dirs at this level only
children: List["FSEntry"]
filetype: str # regular, symlink, special (block, char, pipe, or socket)
deleted: bool
versions: List[FSEntryVersion] = None
filetype: Optional[
FileType
] # regular, symlink, special (block, char, pipe, or socket)
deleted: Optional[bool]
versions: Optional[List[FSEntryVersion]] = None
# these will be filled from the version list automatically
unfrozen_perms: str = None # stat.filemode(os.stat(path).st_mode): '-rw-rw-r--'
symlink_target: str = None # if this is a symlink, this is the (read but not fully
unfrozen_perms: Optional[
str
] = None # stat.filemode(os.stat(path).st_mode): '-rw-rw-r--'
symlink_target: Optional[
str
] = None # if this is a symlink, this is the (read but not fully
# resolved) target. I.e. this is the "content" of the symlink.
sha256: str = None
latest_version: FSEntryVersion = None
sha256: Optional[bytes] = None
latest_version: Optional[FSEntryVersion] = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.versions is not None and len(self.versions) > 0:
self.latest_version = self.versions[-1]
self.unfrozen_perms = self.latest_version.unfrozen_perms
@ -126,8 +175,13 @@ class FSEntry:
@classmethod
def from_path(
cls, root, relpath=None, exclude=["nancy.db"], parent=None, direntry=None
):
cls: Type[_FSEntryT],
root: PathStr,
relpath: Optional[str] = None,
exclude: List[str] = ["nancy.db"],
parent: Optional[_FSEntryT] = None,
direntry: Optional["os.DirEntry[str]"] = None,
) -> _FSEntryT:
"""
Scan a path to instantiate (recursive).
@ -150,14 +204,19 @@ class FSEntry:
s = filestat.st_mode
children = []
symlink_target = None
symlink_target: Optional[Union[str, bytes]] = None
if os.path.islink(path):
# Check links first, since it is not exclusive with dir or file checks
filetype = "LNK"
filetype = FileType.LNK
# readlink returns a str or bytes
symlink_target = os.readlink(path)
m.update(bytes(symlink_target, "utf-8"))
assert symlink_target is not None
if isinstance(symlink_target, str):
symlink_target = bytes(symlink_target, "utf-8")
assert isinstance(symlink_target, bytes)
m.update(symlink_target)
elif stat.S_ISDIR(s):
filetype = "DIR"
filetype = FileType.DIR
# this prevents a directory's hash from colliding with a file hash
# in cases where it only holds a single file
@ -189,27 +248,29 @@ class FSEntry:
# changes without modifying the hashes of individual files,
# which remain content-based for compatibility with
# other tools
if c.unfrozen_perms is not None:
m.update(bytes(c.unfrozen_perms, "utf-8"))
if c.sha256 is not None:
m.update(c.sha256)
elif stat.S_ISREG(s):
filetype = "REG"
filetype = FileType.REG
m.update(open(path, "rb").read())
elif stat.S_ISSOCK(s):
filetype = "SOCK"
filetype = FileType.SOCK
elif stat.S_ISCHR(s):
filetype = "CHR"
filetype = FileType.CHR
elif stat.S_ISBLK(s):
filetype = "BLK"
filetype = FileType.BLK
elif stat.S_ISFIFO(s):
filetype = "FIFO"
filetype = FileType.FIFO
elif stat.S_ISDOOR(s):
filetype = "DOOR"
filetype = FileType.DOOR
elif stat.S_ISPORT(s):
filetype = "PORT"
filetype = FileType.PORT
elif stat.S_ISWHT(s):
filetype = "WHT"
filetype = FileType.WHT
else:
filetype = "OTHER"
filetype = FileType.OTHER
sha256 = m.digest()
@ -221,20 +282,22 @@ class FSEntry:
children=children,
filetype=None,
deleted=None,
versions=[
versions=[],
)
# Update versions after the fact to get self-reference
ob.versions = [
FSEntryVersion(
id=None,
filedir=None,
recorded_time=datetime.now().timestamp(),
filedir=ob,
recorded_time=datetime.now(),
filetype=filetype,
deleted=False,
unfrozen_perms=stat.filemode(s),
symlink_target=symlink_target,
symlink_target=str(symlink_target),
sha256=sha256,
source_task_id=None,
)
],
)
]
# now change children's parents to point to this object
for v in ob.versions:
v.filedir = ob
@ -250,7 +313,7 @@ class FSEntry:
return ob
@classmethod
def empty_root(cls):
def empty_root(cls: Type[_FSEntryT]) -> _FSEntryT:
"""Just a standardized value indicating an empty root directory"""
return cls(
id=None,
@ -258,15 +321,23 @@ class FSEntry:
relpath=".",
parent=None,
children=[],
filetype="DIR",
filetype=FileType.DIR,
unfrozen_perms="----------",
sha256=hashlib.sha256().digest(),
deleted=False,
)
# @logger.catch
@classmethod
@logger.catch
def from_db_index(cls, cursor, root_id=None, root_row=None, parent=None):
def from_db_index(
cls: Type[_FSEntryT],
cursor: sqlite3.Cursor,
root_id: Optional[int] = None,
root_row: Optional[
Tuple[int, str, bool]
] = None, # TODO: Type the expected sqlite rows
parent: Optional[_FSEntryT] = None,
) -> _FSEntryT:
"""Given id of an entry in filedir, recursively fill this object"""
if root_row is None:
assert root_id is not None
@ -315,21 +386,21 @@ class FSEntry:
ob.unfrozen_perms = last_ver.unfrozen_perms
ob.symlink_target = last_ver.symlink_target
ob.sha256 = last_ver.sha256
ob.last_version = last_ver
ob.latest_version = last_ver
return ob
def flatten_tree(self, level: int = 0) -> List[Tuple[int, "FSEntry"]]:
    """Flatten this subtree into ``(level, entry)`` pairs.

    The entry itself comes first at ``level``; each child subtree follows at
    ``level + 1``, with children visited in filename order.
    """
    flattened: List[Tuple[int, "FSEntry"]] = [(level, self)]
    ordered_children = sorted(self.children, key=operator.attrgetter("filename"))
    for child in ordered_children:
        flattened += child.flatten_tree(level=level + 1)
    return flattened
def __str__(self):
def __str__(self) -> str:
return self.to_string(level=0)
def to_string(self, level=0):
def to_string(self, level: int = 0) -> str:
if len(self.children) == 0:
childsec = "[]"
else:
@ -350,25 +421,33 @@ filetype: {self.filetype}
deleted: {self.deleted}
unfrozen_perms: {self.unfrozen_perms}
symlink_target: {self.symlink_target}
sha256: {self.sha256.hex()}
sha256: {'None' if self.sha256 is None else self.sha256.hex()}
children: {childsec}
""".splitlines()
)
def sort_diffs_filename(diffs: List["FSDiff"]) -> List["FSDiff"]:
    """Return a new list of ``diffs`` ordered by each diff's filename.

    The sort is stable, and every input diff appears in the result even when
    several of them report the same filename.

    Args:
        diffs: Diff objects exposing a ``filename()`` method.

    Returns:
        A new list; the input list is left untouched.
    """
    return sorted(diffs, key=lambda d: d.filename())
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_FSDiffT = TypeVar("_FSDiffT", bound="FSDiff")
@dataclass
class FSDiff:
A: FSEntry # record the comparisons
B: FSEntry # a missing entry indicates new or deleted
modified_children: "FSDiff"
A: Optional[FSEntry] # record the comparisons
B: Optional[FSEntry] # a missing entry indicates new or deleted
modified_children: "List[FSDiff]"
def __post_init__(self) -> None:
if self.A is None and self.B is None:
raise TypeError("A and B cannot both be None")
@staticmethod
def compare(A, B):
def compare(A: FSEntry, B: FSEntry) -> bool:
return (
A.sha256 == B.sha256
and A.unfrozen_perms == B.unfrozen_perms
@ -376,14 +455,25 @@ class FSDiff:
and A.deleted == B.deleted
)
def filename(self) -> str:
    """Return the filename of ``A`` when present, otherwise that of ``B``."""
    if self.A is None:
        # With A missing, B is the only side left to name the entry.
        assert self.B is not None
        return self.B.filename
    return self.A.filename
def filetype(self) -> Optional[FileType]:
    """Return the filetype of ``A`` when present, otherwise that of ``B``.

    Returns None only when both sides are missing.
    """
    for entry in (self.A, self.B):
        if entry is not None:
            return entry.filetype
    return None
@classmethod
def compute(cls, A, B):
def compute(
cls: Type[_FSDiffT], A: Optional[FSEntry], B: Optional[FSEntry]
) -> _FSDiffT:
"""Given two hashed directories, recursively compute difference.
This assumes the hashes are consistent, so that directories with
@ -395,6 +485,9 @@ class FSDiff:
new (Directory): overlay with new entries from other
"""
if A is None: # new entry
if B is None:
raise ValueError("Cannot compute diff with both A and B missing")
else:
return cls(
A,
B,
@ -435,7 +528,7 @@ class FSDiff:
return cls(A, B, modified_children)
def flatten_tree(self, level=0):
def flatten_tree(self, level: int = 0) -> List[Tuple[int, "FSDiff"]]:
"""Return list of all entries, with level, in pairs"""
pairs = [(level, self)]
for c in sorted(self.modified_children, key=lambda d: d.filename()):

View File

@ -1,12 +1,17 @@
from typing import NamedTuple
from typing import NamedTuple, Optional, Type, TypeVar
import json
import platform
import sqlite3
import time
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_MachineT = TypeVar("_MachineT", bound="Machine")
class Machine(NamedTuple):
id: int
machine_id: str
id: Optional[int]
machine_id: Optional[str]
hostname: str
processor: str
system: str
@ -18,7 +23,9 @@ class Machine(NamedTuple):
mac_ver: str
@classmethod
def find_or_insert(cls, cur, machine=None):
def find_or_insert(
cls: Type[_MachineT], cur: sqlite3.Cursor, machine: Optional[_MachineT] = None
) -> _MachineT:
"""Given a DB cursor, find or create row in machine table and fill"""
if machine is None:
machine = cls.detect()
@ -61,7 +68,7 @@ class Machine(NamedTuple):
return machine._replace(id=id)
@classmethod
def detect(cls):
def detect(cls: Type[_MachineT]) -> _MachineT:
"""Formats machine-specific information into a MachineInfo object.
Note that 'MachineInfo' objects are properly formatted to be inserted into

0
src/nancy/py.typed Normal file
View File

View File

@ -4,28 +4,30 @@ from loguru import logger
from . import db, environment, fs
from dataclasses import dataclass
import datetime
import os
from pathlib import Path
import sqlite3
from typing import Any, Optional, TypeVar, Type, Union
import warnings
@dataclass
class Program:
def __init__(self, store, name, message):
self.store = store
self.name = name
self.message = message
store: "Store"
name: str
message: str
self._evaluated = False
id: Optional[int] = None
start_time: Optional[datetime.datetime] = None
evaluated: bool = False
def set_start_time(self, t):
self.start_time = t
@logger.catch
def __enter__(self):
if self._evaluated:
def __enter__(self) -> "Program":
if self.evaluated:
raise RuntimeError("Cannot re-enter a Program context")
assert self.store.conn is not None
cur = self.store.conn.cursor()
env = environment.Environment.find_or_insert(cur)
@ -47,23 +49,32 @@ class Program:
)
self.id = cur.lastrowid
self.set_start_time(datetime.datetime.now())
self.start_time = datetime.datetime.now()
return self
def new_task(self, name: str, py_function_id: Optional[int] = None) -> int:
    """Insert a new row into the ``task`` table and return its id.

    Args:
        name: Accepted but not written to the inserted row.
        py_function_id: Stored in the row's third column; may be None.

    Returns:
        The row id reported by the cursor after the insert.
    """
    connection = self.store.conn
    assert connection is not None
    cursor = connection.cursor()
    # A None first column leaves the id for the database to assign.
    row = (None, self.id, py_function_id)
    cursor.execute("INSERT INTO task VALUES (?, ?, ?)", row)
    new_id = cursor.lastrowid
    assert isinstance(new_id, int)
    return new_id
def __exit__(self, exc_type, exc_value, exc_traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[Any],
) -> None:
end_time = datetime.datetime.now()
# record start and end times in store
assert self.store.conn is not None
cur = self.store.conn.cursor()
cur.execute(
"""
@ -79,6 +90,7 @@ class Program:
)
cur.connection.commit()
self._evaluated = True # prevent re-running
assert self.start_time is not None
elapsed = end_time - self.start_time
logger.success(
f"Program [{self.id}] {self.name} "
@ -86,10 +98,22 @@ class Program:
)
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_StoreT = TypeVar("_StoreT", bound="Store")
class Store:
"""Describes a data directory, holds active connection to nancy.db"""
def __init__(self, directory=None, conn=None):
path: Optional[fs.PathStr]
db_path: fs.PathStr
conn: Optional[sqlite3.Connection]
def __init__(
self,
directory: Optional[fs.PathStr] = None,
conn: Optional[sqlite3.Connection] = None,
):
"""
Arguments:
directory (str): Location of existing store directory. If omitted
@ -107,21 +131,24 @@ class Store:
else:
self.conn = conn
def copy(self: _StoreT, store_path: fs.PathStr) -> _StoreT:
    """Copy this store's database into ``store_path`` and return a new store.

    Only the database is copied: the live connection is backed up into
    ``<store_path>/nancy.db``.

    Args:
        store_path: Directory that will hold the copied ``nancy.db``.

    Returns:
        A new instance of this store's class constructed on ``store_path``.
    """
    assert self.conn is not None
    dst_db_path = os.path.join(store_path, "nancy.db")
    dst_conn = sqlite3.connect(dst_db_path)
    try:
        self.conn.backup(dst_conn)
    finally:
        # Close the destination connection even if the backup fails, so it is
        # released before a new store is constructed on the same directory.
        dst_conn.close()
    return self.__class__(store_path)
def connect(self) -> sqlite3.Connection:
    """Open ``self.db_path``, keep the connection on ``self.conn``, return it.

    ``PRAGMA foreign_keys = ON;`` is executed on the new connection.
    """
    connection = sqlite3.connect(self.db_path)
    self.conn = connection
    pragma_cursor = connection.cursor()
    pragma_cursor.execute("PRAGMA foreign_keys = ON;")
    return connection
@classmethod
def init(cls, directory=None, message=None):
def init(
cls: Type[_StoreT], message: str, directory: Optional[fs.PathStr] = None
) -> _StoreT:
start_time = datetime.datetime.now()
if directory is None: # initialize an in-memory store
db_path = ":memory:"
@ -142,31 +169,37 @@ class Store:
with new_store.program("INIT", message) as p:
# set the timing to the actual times it took to initialize the db
p.set_start_time(start_time)
p.start_time = start_time
return new_store
def make_readonly(self) -> None:
    """Make the store directory read-only, leaving ``nancy.db`` writable.

    Does nothing when the store has no directory (``self.path`` is None).
    """
    if self.path is None:
        # Without this guard, str(None) would target a directory literally
        # named "None".
        return
    root = str(self.path)
    # make_readonly_recursive skips a file when os.path.join(<walked dir>, name)
    # is in ``excluded``, so the database path is built the same way here.
    # For a store at "." this is still "./nancy.db".
    fs.make_readonly_recursive(root, excluded=[os.path.join(root, "nancy.db")])
def filedir_root_index(self, cur: Optional[sqlite3.Cursor] = None) -> Optional[int]:
    """Return the id of this store's root ``filedir`` row, or None if absent.

    The root is the ``filedir`` row with ``store=1`` and no parent.

    Args:
        cur: Cursor to run the query on. When omitted, a new cursor is
            created from ``self.conn``.

    Returns:
        The root row's id, or None when no such row exists.
    """
    if cur is None:
        assert self.conn is not None
        cur = self.conn.cursor()
    cur.execute("SELECT id FROM filedir WHERE store=1 AND parent is NULL")
    row = cur.fetchone()
    if row is None:
        # No root row recorded; callers such as fs_entries test for None.
        return None
    (root_id,) = row
    assert isinstance(root_id, int)
    return root_id
def path_to_fsentry(self, path):
def path_to_fsentry(self, path: fs.PathStr) -> Optional[fs.FSEntry]:
"""Find a path in the filedir database and return it as an fsentry.
If the path is not found in the store, None is returned.
"""
assert self.conn is not None
cur = self.conn.cursor()
# get relative path to resolved path
rel = os.path.relpath(os.path.realpath(path), start=os.path.realpath(self.path))
rel = os.path.relpath(
os.path.realpath(str(path)),
start=os.path.realpath(str(self.path)),
)
# rel tells us how to descend recurively to find the filedir for path
fd_id = self.filedir_root_index(cur)
@ -186,34 +219,41 @@ class Store:
return None
fd_id, filetype = row
if filetype != "DIR":
return fd_id
return fs.FSEntry.from_db_index(cur, root_id=fd_id)
def fs_entries(self, shallow: bool = False) -> Optional[fs.FSEntry]:
    """Load the recorded tree of FSEntry objects from the database.

    Args:
        shallow: Accepted but not used in the body.

    Returns:
        The entry built from the store's root row, or None when
        ``filedir_root_index`` finds no root.
    """
    root_id = self.filedir_root_index()
    if root_id is None:
        return None
    assert self.conn is not None
    cursor = self.conn.cursor()
    return fs.FSEntry.from_db_index(cursor, root_id=root_id)
def program(self, name: str, message: str) -> Program:
    """Create a ``Program`` bound to this store with the given name and message."""
    return Program(self, name, message)
def diff(self) -> fs.FSDiff:
    """Compare the files currently on disk with their recorded versions.

    Returns:
        The diff computed with the recorded tree as ``A`` and the freshly
        scanned tree as ``B``.
    """
    # Scan the store directory first, then load what the database recorded.
    on_disk = fs.FSEntry.from_path(str(self.path))
    in_db = self.fs_entries(shallow=True)
    return fs.FSDiff.compute(in_db, on_disk)
def _record_file_version(self, cur, ob, filedir_id, source_task=None):
def _record_file_version(
self,
cur: sqlite3.Cursor,
ob: fs.FSEntry,
filedir_id: int,
source_task: Optional[int] = None,
) -> int:
cur.execute(
"INSERT INTO filedir_version VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
@ -224,13 +264,20 @@ class Store:
False,
ob.unfrozen_perms,
ob.symlink_target,
ob.sha256.hex(),
None if ob.sha256 is None else ob.sha256.hex(),
source_task,
),
)
assert isinstance(cur.lastrowid, int)
return cur.lastrowid
def _record_new_file_recursive(self, ob, cur, parent_id, source_task):
def _record_new_file_recursive(
self,
ob: fs.FSEntry,
cur: sqlite3.Cursor,
parent_id: Optional[int],
source_task: Optional[int],
) -> None:
# Find entries with this name and parent
cur.execute(
"SELECT id FROM filedir WHERE store = 1 AND name = ? AND parent = ? LIMIT 1",
@ -252,6 +299,7 @@ class Store:
thisid = cur.lastrowid
else:
(thisid,) = res[0]
assert isinstance(thisid, int)
self._record_file_version(cur, ob, thisid, source_task=source_task)
@ -259,25 +307,41 @@ class Store:
for c in ob.children:
self._record_new_file_recursive(c, cur, thisid, source_task)
def _record_recursive(self, diff, cur, parent_id=None, source_task=None):
def _record_recursive(
self,
diff: fs.FSDiff,
cur: sqlite3.Cursor,
parent_id: Optional[int] = None,
source_task: Optional[int] = None,
) -> None:
"""Record this level of a diff."""
if diff.A is None:
assert diff.B is not None
self._record_new_file_recursive(
diff.B, cur, parent_id, source_task=source_task
)
elif diff.B is None:
self._record_deleted_file_recursive(diff.B, cur, parent_id)
# self._record_deleted_file_recursive(diff.B, cur, parent_id)
pass
else:
# possibly modified, record new version then recurse into children
self._record_new_file_recursive(
diff.B, cur, parent_id, source_task=source_task
)
assert diff.A.id is not None
self._record_file_version(cur, diff.B, diff.A.id, source_task=source_task)
# descend into children
def record(self, diff, parent_id=None, message=None, cur=None):
def record(
self,
diff: fs.FSDiff,
message: str,
parent_id: Optional[int] = None,
cur: Optional[sqlite3.Cursor] = None,
) -> None:
if cur is None:
assert self.conn is not None
cur = self.conn.cursor()
with self.program("RECORD", message) as p:
@ -288,41 +352,20 @@ class Store:
# recording new versions of each, when necessary
self._record_recursive(diff, cur, source_task=task_id)
# @contextmanager
def run(
self,
name=None,
message=None,
):
"""
Create a context manager that encapsulates a procedure that can save files.
Note that this does NOT spawn any new OS processes or threads.
Example:
s = nancy.store.init(target_directory)
with s.run("sum_dataframe") as f:
x = PandasDataframe()
y = Sum(x)
f.save('stats/xsum.csv', y)
"""
pass
class StoreFile:
    """Describes a file that is recorded in the store."""

    def __init__(self, store: Store, rel_path: fs.PathStr):
        """Remember the owning store and the file's path.

        Args:
            store: The store this file is associated with.
            rel_path: Path of the file; stored as given.
        """
        self.rel_path = rel_path
        self.store = store

    def save(self) -> None:
        """Placeholder; currently does nothing."""
        # call the appropriate save method
        return None
def find_store(path):
def find_store(path: Union[str, "os.PathLike[str]"]) -> Optional[str]:
"""
Given a path, find a store dir containing nancy.db at any level above it.
"""

View File

@ -3,25 +3,32 @@ from . import machine
import getpass
import os
import pwd
from typing import NamedTuple
import sqlite3
from typing import NamedTuple, Optional, Type, TypeVar
# see https://stackoverflow.com/questions/44640479/type-annotation-for-classmethod-returning-instance
_UserT = TypeVar("_UserT", bound="User")
class User(NamedTuple):
id: int # if not None, this is `id` in the `machine` table
id: Optional[int] # if not None, this is `id` in the `machine` table
username: str
userid: int
fullname: str
machine: machine.Machine
@classmethod
def find_or_insert(cls, cur, user=None):
def find_or_insert(
cls: Type[_UserT], cur: sqlite3.Cursor, user: Optional[_UserT] = None
) -> _UserT:
"""Given a DB cursor, find or create row in user table and fill"""
if user is None:
user = cls.detect()
m = machine.Machine.find_or_insert(cur)
user = user._replace(machine=m.id)
user = user._replace(machine=m)
# insert or ignore, handle each case to set id
cur.execute(
@ -37,7 +44,12 @@ class User(NamedTuple):
machine = ?
LIMIT 1
""",
user[1:],
(
user.username,
user.userid,
user.fullname,
user.machine.id,
),
)
res = cur.fetchone()
if res is None:
@ -45,7 +57,13 @@ class User(NamedTuple):
"""
INSERT INTO user VALUES (?,?,?,?,?);
""",
user,
(
user.id,
user.username,
user.userid,
user.fullname,
user.machine.id,
),
)
id = cur.lastrowid
cur.connection.commit()
@ -55,7 +73,7 @@ class User(NamedTuple):
return user._replace(id=id)
@classmethod
def detect(cls):
def detect(cls: Type[_UserT]) -> _UserT:
"""Detect values for user independent of the database.
Note that the machine entry will not have a valid id.
@ -70,5 +88,5 @@ class User(NamedTuple):
getpass.getuser(),
os.getuid(),
fullname,
m.id,
m,
)

View File

@ -32,7 +32,7 @@ def test_record_untracked_dir(filled_dir):
def store():
from nancy import store
s = store.Store.init()
s = store.Store.init(message="test init")
yield s