Coverage for src/duty/validation.py: 96.06%
95 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-17 17:18 +0200
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-17 17:18 +0200
1"""This module contains logic used to validate parameters passed to duties.
3We validate the parameters before running the duties,
4effectively checking all CLI arguments and failing early
5if they are incorrect.
6"""
8from __future__ import annotations
10import sys
11import textwrap
12from contextlib import suppress
13from functools import cached_property, partial
14from inspect import Parameter, Signature, signature
15from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Union, get_args, get_origin
17if TYPE_CHECKING:
18 from collections.abc import Sequence
20# YORE: EOL 3.9: Replace block with lines 6-13.
21if sys.version_info < (3, 10):
22 from eval_type_backport import eval_type_backport as eval_type
24 union_types = (Union,)
25else:
26 from types import UnionType
27 from typing import _eval_type # type: ignore[attr-defined]
29 if sys.version_info >= (3, 13):
30 eval_type = partial(_eval_type, type_params=None)
31 else:
32 eval_type = _eval_type
33 union_types = (Union, UnionType)
def to_bool(value: str) -> bool:
    """Interpret a string as a boolean flag.

    The comparison is case-insensitive.

    Parameters:
        value: The string to interpret.

    Returns:
        False when the string is empty or is one of
        `0`, `no`, `n`, `false`, `off`; True for anything else.
    """
    falsy_words = ("", "0", "no", "n", "false", "off")
    lowered = value.lower()
    if lowered in falsy_words:
        return False
    return True
def _cast_strict(arg: Any, annotation: Any) -> Any:
    """Cast an argument using a type annotation, raising when the cast fails.

    Parameters:
        arg: The argument value.
        annotation: A type annotation.

    Returns:
        The cast value.

    Raises:
        Exception: Whatever the annotation raises when called with the argument,
            or `TypeError` when no member of a union accepts the argument.
    """
    if annotation is Parameter.empty:
        return arg
    if annotation is bool:
        # `bool("false")` is True, so strings go through `to_bool` instead.
        annotation = to_bool
    if get_origin(annotation) in union_types:
        for sub_annotation in get_args(annotation):
            if sub_annotation is type(None):
                continue
            # A member that rejects the value must not stop the search:
            # the strict cast raises, so the next member gets its chance.
            with suppress(Exception):
                return _cast_strict(arg, sub_annotation)
        raise TypeError(f"Cannot cast {arg!r} to any member of {annotation!r}")
    return annotation(arg)


def cast_arg(arg: Any, annotation: Any) -> Any:
    """Cast an argument using a type annotation.

    The annotation is called with the argument. `bool` annotations are handled
    by `to_bool`. For unions, each member except `None` is tried in declaration
    order and the first successful cast wins. Casting is best-effort:
    when it fails, the argument is returned unchanged.

    Parameters:
        arg: The argument value.
        annotation: A type annotation.

    Returns:
        The cast value, or the argument itself if there is no annotation
        or if it could not be cast.
    """
    try:
        return _cast_strict(arg, annotation)
    except Exception:  # noqa: BLE001
        return arg
class ParamsCaster:
    """Cast call arguments according to the annotations of a signature."""

    def __init__(self, signature: Signature) -> None:
        """Store the parameters of the signature, by name and by position.

        Parameters:
            signature: The signature to use to cast arguments.
        """
        parameters = signature.parameters
        self.params_dict = parameters
        self.params_list = [*parameters.values()]

    @cached_property
    def var_positional_position(self) -> int:
        """Locate the variable positional parameter (`*args`) in the signature.

        Returns:
            The index of the variable positional parameter, or -1 if there is none.
        """
        return next(
            (
                index
                for index, parameter in enumerate(self.params_list)
                if parameter.kind is Parameter.VAR_POSITIONAL
            ),
            -1,
        )

    @cached_property
    def has_var_positional(self) -> bool:
        """Tell whether the signature has a variable positional parameter.

        Returns:
            True or False.
        """
        return not self.var_positional_position < 0

    @cached_property
    def var_positional_annotation(self) -> Any:
        """Give the annotation of the variable positional parameter (`*args`).

        Returns:
            The annotation of the parameter found at `var_positional_position`.
        """
        index = self.var_positional_position
        return self.params_list[index].annotation

    @cached_property
    def var_keyword_annotation(self) -> Any:
        """Give the annotation of the variable keyword parameter (`**kwargs`).

        Returns:
            The annotation of the first variable keyword parameter,
            or `Parameter.empty` if the signature has none.
        """
        return next(
            (
                parameter.annotation
                for parameter in self.params_list
                if parameter.kind is Parameter.VAR_KEYWORD
            ),
            Parameter.empty,
        )

    def annotation_at_pos(self, pos: int) -> Any:
        """Give the annotation of the parameter at the given position.

        Parameters:
            pos: The position of the parameter.

        Returns:
            The annotation of that parameter.
        """
        parameter = self.params_list[pos]
        return parameter.annotation

    def eaten_by_var_positional(self, pos: int) -> bool:
        """Tell whether an argument at this position lands in the variable positional parameter.

        Parameters:
            pos: The position of the argument.

        Returns:
            True when there is a variable positional parameter
            at or before this position, False otherwise.
        """
        if not self.has_var_positional:
            return False
        return pos >= self.var_positional_position

    def cast_posarg(self, pos: int, arg: Any) -> Any:
        """Cast one positional argument.

        Parameters:
            pos: The position of the argument in the signature.
            arg: The argument value.

        Returns:
            The cast value.
        """
        if self.eaten_by_var_positional(pos):
            annotation = self.var_positional_annotation
        else:
            annotation = self.annotation_at_pos(pos)
        return cast_arg(arg, annotation)

    def cast_kwarg(self, name: str, value: Any) -> Any:
        """Cast one keyword argument.

        Parameters:
            name: The name of the argument in the signature.
            value: The argument value.

        Returns:
            The cast value.
        """
        parameter = self.params_dict.get(name)
        # Names missing from the signature fall back to the `**kwargs` annotation.
        annotation = self.var_keyword_annotation if parameter is None else parameter.annotation
        return cast_arg(value, annotation)

    def cast(self, *args: Any, **kwargs: Any) -> tuple[Sequence, dict[str, Any]]:
        """Cast every positional and keyword argument.

        Parameters:
            *args: The positional arguments.
            **kwargs: The keyword arguments.

        Returns:
            A tuple of the cast positional arguments
            and a dictionary of the cast keyword arguments.
        """
        cast_positional = []
        for pos, arg in enumerate(args):
            cast_positional.append(self.cast_posarg(pos, arg))
        cast_keyword = {}
        for name, value in kwargs.items():
            cast_keyword[name] = self.cast_kwarg(name, value)
        return tuple(cast_positional), cast_keyword
def _get_params_caster(func: Callable, *args: Any, **kwargs: Any) -> ParamsCaster:
    """Check arguments against a function's signature and build a caster for it.

    A body-less clone of the function, without its first parameter, is generated
    and called with the given arguments so that incorrect arguments fail early.

    Parameters:
        func: The function whose signature is used. Its first parameter is ignored.
        *args: The positional arguments to check.
        **kwargs: The keyword arguments to check.

    Returns:
        A caster built from the signature (without the first parameter).

    Raises:
        TypeError: When the arguments do not match the signature.
    """
    duties_module = sys.modules[func.__module__]
    exec_globals = dict(duties_module.__dict__)
    # `annotations` is the `__future__` feature imported at the top of this module.
    # Finding the same object in the function's module means that module also used
    # `from __future__ import annotations`, so its annotations must be evaluated.
    eval_str = False
    for name in list(exec_globals.keys()):
        if exec_globals[name] is annotations:
            eval_str = True
            del exec_globals[name]
            break
    exec_globals["__context_above"] = {}

    # Don't keep first parameter: context.
    params = list(signature(func).parameters.values())[1:]
    params_no_types = [Parameter(param.name, param.kind, default=param.default) for param in params]
    code_sig = Signature(parameters=params_no_types)
    if eval_str:
        params_types = []
        for param in params:
            if param.annotation is not Parameter.empty:
                annotation = eval_type(
                    ForwardRef(param.annotation) if isinstance(param.annotation, str) else param.annotation,
                    exec_globals,
                    {},
                )
            elif param.default is not Parameter.empty:
                # No annotation: infer the type from the default value.
                annotation = type(param.default)
            else:
                # Neither annotation nor default: there is nothing to cast to.
                # `type(Parameter.empty)` is `type`, and casting with `type`
                # would replace each argument with its class.
                annotation = Parameter.empty
            params_types.append(
                Parameter(param.name, param.kind, default=param.default, annotation=annotation),
            )
    else:
        params_types = params
    cast_sig = Signature(parameters=params_types)

    # The clone has the same name and parameters as the function, but no body.
    code = f"""
        import inspect
        def {func.__name__}{code_sig}: ...
        __context_above['func'] = {func.__name__}
    """

    exec(textwrap.dedent(code), exec_globals)  # noqa: S102
    func = exec_globals["__context_above"]["func"]

    # Trigger TypeError early.
    func(*args, **kwargs)

    return ParamsCaster(cast_sig)
def validate(
    func: Callable,
    *args: Any,
    **kwargs: Any,
) -> tuple[Sequence, dict[str, Any]]:
    """Validate and cast positional and keyword arguments against a function.

    The function is first cloned without its first parameter (the context)
    and without its body. Calling the clone fails early with a `TypeError`
    if the arguments are incorrect: not enough, too many, in the wrong order, etc.

    The arguments are then cast using the function's signature and returned.

    Parameters:
        func: The function to copy.
        *args: The positional arguments.
        **kwargs: The keyword arguments.

    Returns:
        The cast positional arguments and the cast keyword arguments.
    """
    caster = _get_params_caster(func, *args, **kwargs)
    return caster.cast(*args, **kwargs)