Coverage for src/shellman/context.py: 96.05%

52 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-27 14:35 +0100

1"""Jinja-context related utilities.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import json 

7import os 

8from typing import TYPE_CHECKING, Any 

9 

10if TYPE_CHECKING: 

11 import argparse 

12 from collections.abc import Sequence 

13 

# Environment variables whose names start with this prefix are read into the
# context by _get_env_context; the prefix is stripped and the rest lower-cased.
ENV_VAR_PREFIX = "SHELLMAN_CONTEXT_"
# JSON file (relative to the current working directory) that is loaded, if it
# can be read, when no explicit context file is given.
DEFAULT_JSON_FILE = ".shellman.json"

16 

17 

18def _get_cli_context(args: Sequence[str]) -> dict: 

19 context: dict[str, Any] = {} 

20 if args: 

21 for context_arg in args: 

22 if not context_arg: 

23 continue 

24 if context_arg[0] == "{": 

25 context.update(json.loads(context_arg)) 

26 elif "=" in context_arg: 

27 name, value = context_arg.split("=", 1) 

28 if "." in name: 

29 name_dict: dict[str, Any] = {} 

30 d = name_dict 

31 parts = name.split(".") 

32 for name_part in parts[1:-1]: 

33 d[name_part] = d = {} 

34 d[parts[-1]] = value 

35 context[parts[0]] = name_dict 

36 else: 

37 context[name] = value 

38 # else invalid arg 

39 return context 

40 

41 

def _get_env_context() -> dict:
    """Collect context variables from prefixed environment variables.

    Every environment variable whose name starts with ``ENV_VAR_PREFIX``
    contributes one entry: the key is the name with the prefix removed and
    lower-cased, and the value is the variable's string value.

    Returns:
        A flat mapping of context variable names to string values.
    """
    prefix_length = len(ENV_VAR_PREFIX)
    return {
        variable[prefix_length:].lower(): content
        for variable, content in os.environ.items()
        if variable.startswith(ENV_VAR_PREFIX)
    }

49 

50 

51def _get_file_context(file: str) -> dict: 

52 with open(file) as stream: 

53 return json.load(stream) 

54 

55 

def _get_context(args: argparse.Namespace) -> dict:
    """Assemble the full Jinja context from every configured source.

    Sources are layered from lowest to highest precedence:

    1. a JSON file: ``args.context_file`` when given (errors propagate),
       otherwise ``DEFAULT_JSON_FILE`` if it can be opened (``OSError`` is
       ignored);
    2. ``SHELLMAN_CONTEXT_*`` environment variables;
    3. command-line context arguments from ``args.context``.

    Higher layers are deep-merged over lower ones via ``_update``.

    Args:
        args: Parsed command-line arguments; ``context_file`` and ``context``
            are read.

    Returns:
        The merged context dictionary.
    """
    context: dict[str, Any] = {}

    if not args.context_file:
        # The default file is optional: skip it quietly when it cannot be opened.
        with contextlib.suppress(OSError):
            context.update(_get_file_context(DEFAULT_JSON_FILE))
    else:
        # An explicitly requested file is mandatory, so failures surface.
        context.update(_get_file_context(args.context_file))

    environment_layer = _get_env_context()
    _update(context, environment_layer)
    command_line_layer = _get_cli_context(args.context)
    _update(context, command_line_layer)

    return context

69 

70 

71def _update(base: dict, added: dict) -> dict: 

72 for key, value in added.items(): 

73 if isinstance(value, dict): 

74 base[key] = _update(base.get(key, {}), value) 

75 else: 

76 base[key] = value 

77 return base