Coverage for src/devboard/projects.py: 39.52%
140 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-19 20:21 +0100
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-19 20:21 +0100
1"""Git utilities."""
3from __future__ import annotations
5import contextlib
6import re
7from collections import defaultdict
8from contextlib import contextmanager, suppress
9from dataclasses import dataclass
10from pathlib import Path
11from threading import Lock
12from typing import ClassVar, Iterator
14from git import Commit, GitCommandError, Head, Repo, TagReference # type: ignore[attr-defined]
@dataclass
class Status:
    """Working-tree changes of a repository, grouped by kind of change.

    Each attribute holds the list of file paths falling into that category.
    """

    added: list[Path]
    """Paths reported as newly added."""
    deleted: list[Path]
    """Paths reported as deleted."""
    modified: list[Path]
    """Paths whose contents were modified."""
    renamed: list[Path]
    """Paths reported as renamed."""
    typechanged: list[Path]
    """Paths whose file type changed."""
    untracked: list[Path]
    """Paths not tracked by git."""
@dataclass(eq=True, order=True, frozen=True)
class Project:
    """A class representing development projects.

    It is instantiated with a path, and then provides
    many utility properties and methods.

    Instances are frozen, and equality, ordering and hashing are based on
    `path`, which is what allows projects to be used as keys of `LOCKS`.
    """

    LOCKS: ClassVar[dict[Project, Lock]] = defaultdict(Lock)
    """One lock per project, created on first access. See `lock` and `unlock`."""

    DEFAULT_BRANCHES: ClassVar[tuple[str, ...]] = ("main", "master")
    """Name of common default branches. Mainly useful to compute unreleased commits."""

    path: Path
    """Path of the project on the file-system."""

    def __str__(self) -> str:
        return self.name

    @property
    def repo(self) -> Repo:
        """GitPython's `Repo` object.

        A new `Repo` is built on every access, so methods needing it
        several times bind it to a local variable once.
        """
        return Repo(self.path)

    @property
    def name(self) -> str:
        """Name of the project (last component of its path)."""
        return self.path.name

    @property
    def is_dirty(self) -> bool:
        """Whether the project is in a "dirty" state (uncommitted modifications, untracked files included)."""
        return self.repo.is_dirty(untracked_files=True)

    @property
    def status(self) -> Status:
        """Status of the project.

        Compares the index with the working tree (`index.diff(None)`),
        and lists untracked files.
        """
        repo = self.repo
        diff = repo.index.diff(None)
        return Status(
            # `iter_change_type` yields `Diff` objects, not paths: take the
            # new-side path for additions and renames (falling back to the
            # other side when git left it empty).
            added=[Path(change.b_path or change.a_path) for change in diff.iter_change_type("A")],
            deleted=[Path(change.a_path) for change in diff.iter_change_type("D")],
            modified=[Path(change.a_path) for change in diff.iter_change_type("M")],
            renamed=[Path(change.b_path or change.a_path) for change in diff.iter_change_type("R")],
            typechanged=[Path(change.a_path or change.b_path) for change in diff.iter_change_type("T")],
            untracked=[Path(untracked) for untracked in repo.untracked_files],
        )

    @property
    def status_line(self) -> str:
        """Status of the project, as a string.

        Non-zero counts only, space-separated, e.g. `"2A 1M 3U"`.
        """
        st = self.status
        counts = (
            (len(st.added), "A"),
            (len(st.deleted), "D"),
            (len(st.modified), "M"),
            (len(st.renamed), "R"),
            (len(st.typechanged), "T"),
            (len(st.untracked), "U"),
        )
        return " ".join(f"{count}{code}" for count, code in counts if count)

    def _commit_counts(self, remote: str, rev_range: str) -> dict[str, int]:
        """Count the commits of a revision range, for each local branch.

        Parameters:
            remote: Name of the remote to compare against.
            rev_range: Range template, formatted with `remote` and `branch`
                (the local branch name), e.g. `"{remote}/{branch}..{branch}"`.

        Returns:
            A mapping of local branch names to commit counts. Branches for which
            git fails (for example when `<remote>/<branch>` does not exist) are omitted.
        """
        repo = self.repo
        counts: dict[str, int] = {}
        for head in repo.branches:  # type: ignore[attr-defined]
            with suppress(GitCommandError):
                commits = repo.iter_commits(rev_range.format(remote=remote, branch=head.name))
                counts[head.name] = sum(1 for _ in commits)
        return counts

    def unpushed(self, remote: str = "origin") -> dict[str, int]:
        """Number of unpushed commits, per branch.

        Counts commits reachable from each local branch but not from
        `<remote>/<branch>`. Branches for which git fails are omitted.
        """
        return self._commit_counts(remote, "{remote}/{branch}..{branch}")

    def unpulled(self, remote: str = "origin") -> dict[str, int]:
        """Number of unpulled commits, per branch.

        Counts commits reachable from `<remote>/<branch>` but not from
        the local branch. Branches for which git fails are omitted.
        """
        return self._commit_counts(remote, "{branch}..{remote}/{branch}")

    @property
    def branch(self) -> Head:
        """Currently checked out branch."""
        return self.repo.active_branch

    @property
    def default_branch(self) -> str:
        """Default branch (or main branch), as checked out when cloning.

        The first name of `DEFAULT_BRANCHES` found among the repository
        references wins; otherwise the `HEAD branch` line of
        `git remote show origin` is used.

        Raises:
            ValueError: When the default branch cannot be inferred.
        """
        repo = self.repo
        references = repo.references
        for candidate in self.DEFAULT_BRANCHES:
            if candidate in references:
                return candidate
        try:
            origin = repo.git.remote("show", "origin")
        except GitCommandError as error:
            raise ValueError(f"Cannot infer default branch for repo {self.name}") from error
        if match := re.search(r"\s*HEAD branch:\s*(.*)", origin):
            return match.group(1)
        raise ValueError(f"Cannot infer default branch for repo {self.name}")

    @contextmanager
    def checkout(self, branch: str | None) -> Iterator[None]:
        """Checkout branch, restore previous one when exiting.

        Nothing is checked out when `branch` is empty/None or is already
        the active branch.
        """
        if not branch:
            yield
            return
        current = self.branch
        # `current` is a `Head`: compare names, a `str` never equals a `Head`.
        if branch == current.name:
            yield
            return
        self.repo.branches[branch].checkout()  # type: ignore[index]
        try:
            yield
        finally:
            current.checkout()

    def pull(self, branch: str | None = None) -> None:
        """Pull branch from `origin` (the current branch when `branch` is None)."""
        with self.checkout(branch):
            self.repo.remotes.origin.pull()

    def push(self, branch: str | None = None) -> None:
        """Push branch to `origin` (the current branch when `branch` is None)."""
        with self.checkout(branch):
            self.repo.remotes.origin.push()

    def delete(self, branch: str) -> None:
        """Delete local branch, even if it is not merged (`force=True`)."""
        self.repo.delete_head(branch, force=True)

    def unreleased(self, branch: str | None = None) -> list[Commit]:
        """List unreleased commits.

        Commits of `branch` (the default branch when None) are listed in
        `iter_commits` order until the latest tagged commit is reached.
        If there are no tags, or if the tagged commit is not reachable from
        `branch`, every commit of the branch is returned. An empty list is
        returned when no branch is given and the default one cannot be inferred.
        """
        if branch is None:
            try:
                branch = self.default_branch
            except ValueError:
                return []
        iterator = self.repo.iter_commits(branch)
        try:
            latest_tagged_commit = self.latest_tag.commit
        except IndexError:
            return list(iterator)
        commits = []
        for commit in iterator:
            if commit == latest_tagged_commit:
                break
            commits.append(commit)
        return commits

    def fetch(self) -> None:
        """Fetch `origin` and `upstream` remotes, skipping missing remotes and git failures."""
        with suppress(AttributeError, GitCommandError):
            self.repo.remotes.origin.fetch()
        with suppress(AttributeError, GitCommandError):
            self.repo.remotes.upstream.fetch()

    @property
    def latest_tag(self) -> TagReference:
        """Latest tag, by committed date of the tagged commit.

        Raises:
            IndexError: When the repository has no tags.
        """
        return sorted(self.repo.tags, key=lambda t: t.commit.committed_datetime)[-1]

    def lock(self) -> bool:
        """Lock project without blocking; return whether the lock was acquired."""
        return self.LOCKS[self].acquire(blocking=False)

    def unlock(self) -> None:
        """Unlock project (raises `RuntimeError` if it is not locked)."""
        self.LOCKS[self].release()