Coverage for src/yore/_internal/config.py: 41.28%

79 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-19 16:19 +0100

1from __future__ import annotations 

2 

3import logging 

4import sys 

5from dataclasses import dataclass, fields 

6from dataclasses import field as dataclass_field 

7from pathlib import Path 

8from typing import TYPE_CHECKING, Any 

9from typing import Annotated as An 

10 

11from typing_extensions import Doc 

12 

13# DUE: EOL 3.10: Replace block with line 2. 

14if sys.version_info >= (3, 11): 

15 import tomllib 

16else: 

17 import tomli as tomllib 

18 

19if TYPE_CHECKING: 

20 from collections.abc import Callable, Mapping 

21 

22 

# Module-level logger; `Config.from_data` uses it to warn about unknown configuration keys.
_logger: logging.Logger = logging.getLogger(__name__)

24 

25 

class Unset:
    """Sentinel standing in for a configuration option the user did not set.

    It remembers the TOML key it belongs to, a Python-friendly version of that key,
    and optionally the name of a transformation method. It is always falsy.
    """

    def __init__(
        self,
        key: An[str, Doc("TOML key.")],
        transform: An[str | None, Doc("Name of the method to call to transform the config value.")] = None,
    ) -> None:
        # Dashes and dots are not valid in identifiers: map both to underscores.
        normalized = key
        for separator in ("-", "."):
            normalized = normalized.replace(separator, "_")

        self.key: An[str, Doc("TOML key.")] = key
        self.name: An[str, Doc("Transformed key name.")] = normalized
        self.transform: An[str | None, Doc("Name of the method to call to transform the config value.")] = transform

    def __bool__(self) -> bool:
        """Report the sentinel as falsy, regardless of its key."""
        return False

    def __repr__(self) -> str:
        return f"<Unset({self.name!r})>"

    def __str__(self) -> str:
        # Shown by the CLI when it displays default values.
        return f"`{self.key}` config-value"

48 

49 

def config_field(
    key: An[str, Doc("Key within the config file.")],
    transform: An[str | None, Doc("Name of transformation method to apply.")] = None,
) -> An[Unset, Doc("Configuration field.")]:
    """Build a dataclass field whose default is an `Unset` sentinel for `key`.

    The sentinel records the TOML key and the optional transform-method name.
    `Config.from_data` later reads both back from each field's default.
    """
    sentinel = Unset(key, transform=transform)
    return dataclass_field(default=sentinel)

56 

57 

# DUE: EOL 3.9: Remove block.
_dataclass_opts: dict[str, bool] = {}
if sys.version_info >= (3, 10):  # `dataclass(kw_only=...)` only exists on Python 3.10+.
    _dataclass_opts.update(kw_only=True)

62 

63 

# DUE: EOL 3.9: Replace `**_dataclass_opts` with `kw_only=True` within line.
@dataclass(**_dataclass_opts)
class Config:
    """Configuration for Yore.

    Every field defaults to an `Unset` sentinel (see `config_field`), which
    records the TOML key the field is read from.
    """

    prefix: An[list[str] | Unset, Doc("The prefix for Yore comments.")] = config_field("prefix")  # noqa: RUF009
    diff_highlight: An[str | Unset, Doc("The command to highlight diffs.")] = config_field("diff.highlight")  # noqa: RUF009

    @classmethod
    def _get(
        cls,
        data: An[Mapping[str, Any], Doc("Data to get value from.")],
        *keys: An[str, Doc("Keys to access nested dictionary.")],
        default: An[Unset, Doc("Default value if key is not found.")],
        transform: An[Callable[[Any], Any] | None, Doc("Transformation function to apply to the value.")] = None,
    ) -> An[Any, Doc("Value from the nested dictionary.")]:
        """Get a value from a nested dictionary.

        Each key descends one level into `data`. If a key is missing, or an
        intermediate value is not a table (e.g. `diff = "..."` instead of a
        `[diff]` table), `default` is returned as-is. Otherwise the value found is
        passed through `transform` when one is given.
        """
        for key in keys:
            try:
                data = data[key]
            except (KeyError, TypeError):
                # KeyError: key absent. TypeError: `data` is a scalar or list, which
                # cannot be indexed by a string key. A plain `key in data` test would
                # do a substring match on strings or raise on ints, so use EAFP.
                return default
        if transform:
            return transform(data)
        return data

    @classmethod
    def from_data(
        cls,
        data: An[Mapping[str, Any], Doc("Data to load configuration from.")],
    ) -> An[Config, Doc("Loaded configuration.")]:
        """Load configuration from data.

        Top-level keys, and the `table.key` sub-keys of table values, are compared
        against the keys declared by the fields. Unknown keys are logged as a
        warning and otherwise ignored.
        """
        config_fields = fields(cls)
        # A set gives O(1) membership tests in the unknown-key scan below.
        field_keys = {field.default.key for field in config_fields}  # type: ignore[union-attr]

        # Check for unknown configuration keys.
        unknown_keys = []
        for top_level_key, top_level_value in data.items():
            if isinstance(top_level_value, dict):
                for key in top_level_value:
                    final_key = f"{top_level_key}.{key}"
                    if final_key not in field_keys:
                        unknown_keys.append(final_key)
            elif top_level_key not in field_keys:
                unknown_keys.append(top_level_key)
        if unknown_keys:
            _logger.warning("Unknown configuration keys: %s", ", ".join(unknown_keys))

        # Create a configuration instance. Dotted keys are split into nested
        # lookups, and `transform` names a method on the class, if any.
        return cls(
            **{
                field.name: cls._get(
                    data,
                    *field.default.key.split("."),  # type: ignore[union-attr]
                    default=field.default,  # type: ignore[arg-type]
                    transform=getattr(cls, field.default.transform or "", None),  # type: ignore[union-attr]
                )
                for field in config_fields
            },
        )

    @classmethod
    def from_file(
        cls,
        path: An[str | Path, Doc("Path to the configuration file.")],
    ) -> An[Config, Doc("Loaded configuration.")]:
        """Load configuration from a TOML file whose top level holds the Yore keys."""
        with open(path, "rb") as file:
            return cls.from_data(tomllib.load(file))

    @classmethod
    def from_pyproject(
        cls,
        path: An[str | Path, Doc("Path to the pyproject.toml file.")],
    ) -> An[Config, Doc("Loaded configuration.")]:
        """Load configuration from the `[tool.yore]` table of a pyproject.toml file.

        A missing `[tool]` or `[tool.yore]` table yields an all-unset configuration.
        """
        with open(path, "rb") as file:
            return cls.from_data(tomllib.load(file).get("tool", {}).get("yore", {}))

    @classmethod
    def from_default_locations(cls) -> An[Config, Doc("Loaded configuration.")]:
        """Load configuration from the default locations.

        The search starts in the current working directory and walks up to the
        filesystem root. In each directory, `config/yore.toml`, `yore.toml` and
        `pyproject.toml` are tried in that order, and the first existing file wins.
        If no file is found anywhere, an all-unset configuration is returned.
        """
        paths = ("config/yore.toml", "yore.toml", "pyproject.toml")
        cwd = Path.cwd()
        # `cwd` is absolute, so `cwd.parents` ends at the filesystem root.
        for directory in (cwd, *cwd.parents):
            for path in paths:
                candidate = directory / path
                if candidate.exists():
                    if path == "pyproject.toml":
                        return cls.from_pyproject(candidate)
                    return cls.from_file(candidate)
        return cls()