Coverage for src/devboard/projects.py: 39.52%

140 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-19 20:21 +0100

1"""Git utilities.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import re 

7from collections import defaultdict 

8from contextlib import contextmanager, suppress 

9from dataclasses import dataclass 

10from pathlib import Path 

11from threading import Lock 

12from typing import ClassVar, Iterator 

13 

14from git import Commit, GitCommandError, Head, Repo, TagReference # type: ignore[attr-defined] 

15 

16 

@dataclass
class Status:
    """Working-tree changes of a Git repository, grouped by change kind.

    Every attribute is a list of paths; a category with no files is an
    empty list.
    """

    added: list[Path]
    """Paths of newly added files."""
    deleted: list[Path]
    """Paths of removed files."""
    modified: list[Path]
    """Paths of files with modified content."""
    renamed: list[Path]
    """Paths of renamed files."""
    typechanged: list[Path]
    """Paths of files whose file type changed."""
    untracked: list[Path]
    """Paths of files not tracked by Git."""

33 

34 

@dataclass(eq=True, order=True, frozen=True)
class Project:
    """A class representing development projects.

    It is instantiated with a path, and then provides
    many utility properties and methods.

    Instances are frozen; equality, ordering and hashing are derived from
    `path` (the only field), which is what allows them to key `LOCKS`.
    """

    # Class-wide mapping of project -> lock, populated lazily by the
    # `defaultdict` on first access; used by `lock` and `unlock`.
    LOCKS: ClassVar[dict[Project, Lock]] = defaultdict(Lock)

    DEFAULT_BRANCHES: ClassVar[tuple[str, ...]] = ("main", "master")
    """Name of common default branches. Mainly useful to compute unreleased commits."""
    path: Path
    """Path of the project on the file-system."""

    def __str__(self) -> str:
        return self.name

    @property
    def repo(self) -> Repo:
        """GitPython's `Repo` object.

        A new `Repo` is built on every access, so methods needing it several
        times bind it to a local variable once.
        """
        return Repo(self.path)

    @property
    def name(self) -> str:
        """Name of the project (last component of its path)."""
        return self.path.name

    @property
    def is_dirty(self) -> bool:
        """Whether the project is in a "dirty" state (uncommitted modifications, untracked files included)."""
        return self.repo.is_dirty(untracked_files=True)

    @property
    def status(self) -> Status:
        """Status of the project: index-vs-working-tree changes plus untracked files."""
        repo = self.repo
        diff = repo.index.diff(None)
        # `Diff` entries are not path-like, so the path attribute must be read
        # explicitly. Added and renamed entries use `b_path` (the new path);
        # the other kinds use `a_path`.
        return Status(
            added=[Path(change.b_path) for change in diff.iter_change_type("A")],
            deleted=[Path(change.a_path) for change in diff.iter_change_type("D")],
            modified=[Path(change.a_path) for change in diff.iter_change_type("M")],
            renamed=[Path(change.b_path) for change in diff.iter_change_type("R")],
            typechanged=[Path(change.a_path) for change in diff.iter_change_type("T")],
            untracked=[Path(untracked) for untracked in repo.untracked_files],
        )

    @property
    def status_line(self) -> str:
        """Status of the project, as a compact string such as `"1A 2M 3U"`.

        Categories without files are omitted, so the result is empty when
        every category is empty.
        """
        status = self.status
        counts = (
            (len(status.added), "A"),
            (len(status.deleted), "D"),
            (len(status.modified), "M"),
            (len(status.renamed), "R"),
            (len(status.typechanged), "T"),
            (len(status.untracked), "U"),
        )
        return " ".join(f"{count}{letter}" for count, letter in counts if count)

    def _divergence(self, remote: str, *, ahead: bool) -> dict[str, int]:
        """Count, per local branch, commits ahead of (or behind) its `remote` counterpart.

        Parameters:
            remote: Name of the remote whose branches are compared.
            ahead: Count commits only on the local branch (`True`) or only
                on the remote branch (`False`).

        Returns:
            A mapping of branch name to commit count. Branches for which Git
            fails (e.g. no counterpart on the remote) are left out.
        """
        repo = self.repo
        result: dict[str, int] = {}
        for branch in repo.branches:  # type: ignore[attr-defined]
            local, upstream = branch.name, f"{remote}/{branch.name}"
            rev_range = f"{upstream}..{local}" if ahead else f"{local}..{upstream}"
            with contextlib.suppress(GitCommandError):
                # `iter_commits` is lazy: Git errors surface while iterating,
                # which is why the count happens inside the `suppress` block.
                result[local] = sum(1 for _ in repo.iter_commits(rev_range))
        return result

    def unpushed(self, remote: str = "origin") -> dict[str, int]:
        """Number of unpushed commits, per branch.

        Parameters:
            remote: Name of the remote to compare against.

        Returns:
            A mapping of branch name to number of local commits missing from `remote`.
        """
        return self._divergence(remote, ahead=True)

    def unpulled(self, remote: str = "origin") -> dict[str, int]:
        """Number of unpulled commits, per branch.

        Parameters:
            remote: Name of the remote to compare against.

        Returns:
            A mapping of branch name to number of `remote` commits missing locally.
        """
        return self._divergence(remote, ahead=False)

    @property
    def branch(self) -> Head:
        """Currently checked out branch.

        Raises:
            TypeError: Raised by GitPython when HEAD is detached.
        """
        return self.repo.active_branch

    @property
    def default_branch(self) -> str:
        """Default branch (or main branch), as checked out when cloning.

        Names from `DEFAULT_BRANCHES` are first looked up among the repository
        references; otherwise the "HEAD branch" reported by
        `git remote show origin` (which contacts the remote) is used.

        Raises:
            ValueError: When the default branch cannot be inferred.
        """
        repo = self.repo
        for branch in self.DEFAULT_BRANCHES:
            if branch in repo.references:
                return branch
        try:
            origin = repo.git.remote("show", "origin")
        except GitCommandError as error:
            raise ValueError(f"Cannot infer default branch for repo {self.name}") from error
        # Git reports "(unknown)" when it cannot determine the remote HEAD;
        # that is not a usable branch name.
        match = re.search(r"HEAD branch:\s*(\S+)", origin)
        if match and match.group(1) != "(unknown)":
            return match.group(1)
        raise ValueError(f"Cannot infer default branch for repo {self.name}")

    @contextmanager
    def checkout(self, branch: str | None) -> Iterator[None]:
        """Checkout branch, restore previous one when exiting.

        Nothing is checked out when `branch` is empty/None or is already the
        current branch.

        Parameters:
            branch: Name of the local branch to check out.
        """
        if not branch:
            yield
            return
        current = self.branch
        # `current` is a `Head`: compare names, a `str` never equals a `Head`.
        if branch == current.name:
            yield
            return
        self.repo.branches[branch].checkout()  # type: ignore[index]
        try:
            yield
        finally:
            current.checkout()

    def pull(self, branch: str | None = None) -> None:
        """Pull branch from `origin`, checking it out temporarily if given."""
        with self.checkout(branch):
            self.repo.remotes.origin.pull()

    def push(self, branch: str | None = None) -> None:
        """Push branch to `origin`, checking it out temporarily if given."""
        with self.checkout(branch):
            self.repo.remotes.origin.push()

    def delete(self, branch: str) -> None:
        """Delete local branch, even if it is not merged (`force=True`)."""
        self.repo.delete_head(branch, force=True)

    def unreleased(self, branch: str | None = None) -> list[Commit]:
        """List unreleased commits.

        Parameters:
            branch: Branch to inspect; defaults to `default_branch`.

        Returns:
            The commits of `branch`, in `iter_commits` order, that come before
            the commit of the latest tag. All commits are returned when there
            are no tags, or when the tagged commit is not part of the walked
            history. An empty list is returned when no branch is given and the
            default branch cannot be inferred.
        """
        if branch is None:
            try:
                branch = self.default_branch
            except ValueError:
                return []
        iterator = self.repo.iter_commits(branch)
        try:
            latest_tagged_commit = self.latest_tag.commit
        except IndexError:  # no tags at all: everything is unreleased
            return list(iterator)
        commits = []
        for commit in iterator:
            if commit == latest_tagged_commit:
                break
            commits.append(commit)
        return commits

    def fetch(self) -> None:
        """Fetch `origin` and `upstream`, ignoring missing remotes and Git errors."""
        repo = self.repo
        with suppress(AttributeError, GitCommandError):
            repo.remotes.origin.fetch()
        with suppress(AttributeError, GitCommandError):
            repo.remotes.upstream.fetch()

    @property
    def latest_tag(self) -> TagReference:
        """Latest tag, by committed date of the tagged commit.

        Raises:
            IndexError: When the repository has no tags.
        """
        return sorted(self.repo.tags, key=lambda t: t.commit.committed_datetime)[-1]

    def lock(self) -> bool:
        """Lock project without blocking.

        Returns:
            Whether the lock was acquired (`False` if already held).
        """
        return self.LOCKS[self].acquire(blocking=False)

    def unlock(self) -> None:
        """Unlock project.

        Raises:
            RuntimeError: When the project lock is not currently held.
        """
        self.LOCKS[self].release()