Coverage for src/duty/validation.py: 96.06%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-17 17:18 +0200

1"""This module contains logic used to validate parameters passed to duties. 

2 

3We validate the parameters before running the duties, 

4effectively checking all CLI arguments and failing early 

5if they are incorrect. 

6""" 

7 

8from __future__ import annotations 

9 

10import sys 

11import textwrap 

12from contextlib import suppress 

13from functools import cached_property, partial 

14from inspect import Parameter, Signature, signature 

15from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Union, get_args, get_origin 

16 

17if TYPE_CHECKING: 

18 from collections.abc import Sequence 

19 

20# YORE: EOL 3.9: Replace block with lines 6-13. 

21if sys.version_info < (3, 10): 

22 from eval_type_backport import eval_type_backport as eval_type 

23 

24 union_types = (Union,) 

25else: 

26 from types import UnionType 

27 from typing import _eval_type # type: ignore[attr-defined] 

28 

29 if sys.version_info >= (3, 13): 

30 eval_type = partial(_eval_type, type_params=None) 

31 else: 

32 eval_type = _eval_type 

33 union_types = (Union, UnionType) 

34 

35 

def to_bool(value: str) -> bool:
    """Interpret a string as a boolean flag.

    Parameters:
        value: The text to interpret, compared case-insensitively.

    Returns:
        False for "", "0", "no", "n", "false" or "off"; True for anything else.
    """
    normalized = value.lower()
    if normalized in {"", "0", "no", "n", "false", "off"}:
        return False
    return True

46 

47 

def _cast_strict(arg: Any, annotation: Any) -> Any:
    """Cast an argument using a type annotation, raising when no cast applies.

    Parameters:
        arg: The argument value.
        annotation: A type annotation.

    Raises:
        Exception: Whatever the annotation raises when called with `arg`,
            or `TypeError` when no member of a union annotation accepts `arg`.

    Returns:
        The cast value.
    """
    if annotation is Parameter.empty:
        return arg
    if annotation is bool:
        # `bool("false")` is True: parse common CLI spellings instead.
        annotation = to_bool
    if get_origin(annotation) in union_types:
        # Try each member in declaration order; `None` is never a cast target.
        # Failures must propagate from the recursive call, otherwise the first
        # member would always "succeed" by returning `arg` unchanged.
        for sub_annotation in get_args(annotation):
            if sub_annotation is type(None):
                continue
            with suppress(Exception):
                return _cast_strict(arg, sub_annotation)
        raise TypeError(f"{arg!r} cannot be cast to any member of {annotation!r}")
    return annotation(arg)


def cast_arg(arg: Any, annotation: Any) -> Any:
    """Cast an argument using a type annotation.

    Casting is best-effort: if the annotation cannot convert the value,
    the value is returned unchanged. For union annotations (`Union[...]`,
    `Optional[...]`, `X | Y`), every non-`None` member is tried in order
    and the first successful conversion wins.

    Parameters:
        arg: The argument value.
        annotation: A type annotation.

    Returns:
        The cast value, or `arg` unchanged if no conversion succeeded.
    """
    try:
        return _cast_strict(arg, annotation)
    except Exception:  # noqa: BLE001
        return arg

72 

73 

class ParamsCaster:
    """A helper class to cast parameters based on a function's signature annotations."""

    def __init__(self, signature: Signature) -> None:
        """Initialize the object.

        Parameters:
            signature: The signature to use to cast arguments.
        """
        self.params_dict = signature.parameters
        self.params_list = list(self.params_dict.values())

    @cached_property
    def var_positional_position(self) -> int:
        """Give the position of the variable positional parameter in the signature.

        Returns:
            The position of the variable positional parameter, or -1 if there is none.
        """
        return next(
            (pos for pos, param in enumerate(self.params_list) if param.kind is Parameter.VAR_POSITIONAL),
            -1,
        )

    @cached_property
    def has_var_positional(self) -> bool:
        """Tell if there is a variable positional parameter.

        Returns:
            True or False.
        """
        return self.var_positional_position >= 0

    @cached_property
    def var_positional_annotation(self) -> Any:
        """Give the variable positional parameter (`*args`) annotation if any.

        Returns:
            The variable positional parameter annotation,
            or `Parameter.empty` if there is no such parameter.
        """
        # Without this guard, position -1 would silently index the *last* parameter
        # (or raise IndexError on an empty signature).
        if not self.has_var_positional:
            return Parameter.empty
        return self.params_list[self.var_positional_position].annotation

    @cached_property
    def var_keyword_annotation(self) -> Any:
        """Give the variable keyword parameter (`**kwargs`) annotation if any.

        Returns:
            The variable keyword parameter annotation,
            or `Parameter.empty` if there is no such parameter.
        """
        return next(
            (param.annotation for param in self.params_list if param.kind is Parameter.VAR_KEYWORD),
            Parameter.empty,
        )

    def annotation_at_pos(self, pos: int) -> Any:
        """Give the annotation for the parameter at the given position.

        Parameters:
            pos: The position of the parameter.

        Returns:
            The positional parameter annotation.
        """
        return self.params_list[pos].annotation

    def eaten_by_var_positional(self, pos: int) -> bool:
        """Tell if the parameter at this position is eaten by a variable positional parameter.

        Parameters:
            pos: The position of the parameter.

        Returns:
            Whether the parameter is eaten.
        """
        return self.has_var_positional and pos >= self.var_positional_position

    def kwarg_annotation(self, name: str) -> Any:
        """Give the annotation used to cast a keyword argument.

        Only parameters that can actually be passed by keyword
        (positional-or-keyword and keyword-only) match by name. Any other name,
        including names of positional-only parameters and of the `*args`/`**kwargs`
        parameters themselves, is collected by `**kwargs` and gets its annotation.

        Parameters:
            name: The name of the keyword argument.

        Returns:
            The annotation to cast the argument with.
        """
        param = self.params_dict.get(name)
        if param is not None and param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY):
            return param.annotation
        return self.var_keyword_annotation

    def cast_posarg(self, pos: int, arg: Any) -> Any:
        """Cast a positional argument.

        Parameters:
            pos: The position of the argument in the signature.
            arg: The argument value.

        Returns:
            The cast value.
        """
        if self.eaten_by_var_positional(pos):
            return cast_arg(arg, self.var_positional_annotation)
        return cast_arg(arg, self.annotation_at_pos(pos))

    def cast_kwarg(self, name: str, value: Any) -> Any:
        """Cast a keyword argument.

        Parameters:
            name: The name of the argument in the signature.
            value: The argument value.

        Returns:
            The cast value.
        """
        return cast_arg(value, self.kwarg_annotation(name))

    def cast(self, *args: Any, **kwargs: Any) -> tuple[Sequence, dict[str, Any]]:
        """Cast all positional and keyword arguments.

        Parameters:
            *args: The positional arguments.
            **kwargs: The keyword arguments.

        Returns:
            The cast arguments.
        """
        positional = tuple(self.cast_posarg(pos, arg) for pos, arg in enumerate(args))
        keyword = {name: self.cast_kwarg(name, value) for name, value in kwargs.items()}
        return positional, keyword

191 

192 

def _module_globals(func: Callable) -> tuple[dict[str, Any], bool]:
    """Copy the globals of the module defining `func`, for use with `exec`.

    Parameters:
        func: The duty function.

    Returns:
        The copied globals (without the `__future__.annotations` feature object,
        plus an empty `__context_above` dict), and whether that feature object
        was found, i.e. whether annotations must be evaluated from strings.
    """
    exec_globals = dict(sys.modules[func.__module__].__dict__)
    eval_str = False
    for name, value in list(exec_globals.items()):
        if value is annotations:
            eval_str = True
            del exec_globals[name]
            break
    exec_globals["__context_above"] = {}
    return exec_globals, eval_str


def _resolve_annotation(param: Parameter, globalns: dict[str, Any]) -> Any:
    """Evaluate a parameter annotation within the duty module's globals.

    Parameters:
        param: The parameter whose annotation to resolve.
        globalns: The globals to evaluate the annotation in.

    Returns:
        The evaluated annotation. An unannotated parameter gets the type of its
        default value, or `Parameter.empty` when it has no default either.
        `Parameter.empty` is also returned when the annotation references an
        unknown name, so the argument is left uncast.
    """
    if param.annotation is Parameter.empty:
        # `type(Parameter.empty)` is `type`, which would turn every argument into
        # its class (e.g. "abc" -> str): only infer from actual default values.
        if param.default is Parameter.empty:
            return Parameter.empty
        return type(param.default)
    annotation = param.annotation
    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
    try:
        return eval_type(annotation, globalns, {})
    except NameError:
        # For example a name only imported under `if TYPE_CHECKING:`.
        # Casting is best-effort, so don't fail validation for it.
        return Parameter.empty


def _get_params_caster(func: Callable, *args: Any, **kwargs: Any) -> ParamsCaster:
    """Check arguments against a duty's signature and build a caster for them.

    The first parameter of `func` (the context) is dropped. A body-less clone of
    the remaining signature is compiled and called with the given arguments, so
    that invalid arguments raise a `TypeError` before the duty runs.

    Parameters:
        func: The duty function.
        *args: The positional arguments.
        **kwargs: The keyword arguments.

    Raises:
        TypeError: When the arguments do not match the signature.

    Returns:
        A caster built from the signature, with annotations evaluated when the
        duty module uses `from __future__ import annotations`.
    """
    exec_globals, eval_str = _module_globals(func)

    # Don't keep first parameter: context.
    params = list(signature(func).parameters.values())[1:]
    code_sig = Signature(
        parameters=[Parameter(param.name, param.kind, default=param.default) for param in params],
    )
    if eval_str:
        cast_sig = Signature(
            parameters=[param.replace(annotation=_resolve_annotation(param, exec_globals)) for param in params],
        )
    else:
        cast_sig = Signature(parameters=params)

    code = f"""
    import inspect
    def {func.__name__}{code_sig}: ...
    __context_above['func'] = {func.__name__}
    """

    exec(textwrap.dedent(code), exec_globals)  # noqa: S102
    clone = exec_globals["__context_above"]["func"]

    # Trigger TypeError early.
    clone(*args, **kwargs)

    return ParamsCaster(cast_sig)

243 

244 

def validate(
    func: Callable,
    *args: Any,
    **kwargs: Any,
) -> tuple[Sequence, dict[str, Any]]:
    """Check arguments against a duty function, then cast them.

    The function is cloned without its first parameter (the context) and
    without its body, and the clone is called with the arguments. This fails
    early with a `TypeError` when they are incorrect: too few, too many,
    in the wrong order, etc.

    The arguments are then cast according to the function's signature.

    Parameters:
        func: The function to check the arguments against.
        *args: The positional arguments.
        **kwargs: The keyword arguments.

    Returns:
        The cast positional and keyword arguments.
    """
    caster = _get_params_caster(func, *args, **kwargs)
    return caster.cast(*args, **kwargs)