Coverage for src/fluree_py/query/select/pydantic/type_checker.py: 96%

54 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-02 03:03 +0000

1from types import UnionType 

2from pydantic import BaseModel 

3 

4 

5from typing import ( 

6 Any, 

7 List, 

8 Type, 

9 TypeGuard, 

10 Union, 

11 get_args, 

12 get_origin, 

13 get_type_hints, 

14) 

15 

16 

class TypeChecker:
    """Handles type checking and validation for Pydantic models.

    This class provides utility methods for checking various types that can appear
    in Pydantic models, such as lists, dictionaries, and nested models. Every helper
    is a classmethod, so they can be called on the class directly.
    """

    @classmethod
    def is_list_type(cls, field_type: Any) -> TypeGuard[Type[List[Any]]]:
        """Check if a type is a list alias.

        Matches ``typing.List[...]``, builtin ``list[...]`` and bare ``typing.List``.
        The plain ``list`` class is not matched because it has no typing origin.
        """
        return get_origin(field_type) is list

    @classmethod
    def is_dict_type(cls, field_type: Any) -> TypeGuard[Type[dict[str, Any]]]:
        """Check if a type is a dict alias (``typing.Dict[...]`` or ``dict[...]``)."""
        return get_origin(field_type) is dict

    @classmethod
    def is_tuple_type(cls, field_type: Any) -> TypeGuard[Type[tuple[Any, ...]]]:
        """Check if a type is a tuple alias (``typing.Tuple[...]`` or ``tuple[...]``)."""
        return get_origin(field_type) is tuple

    @classmethod
    def is_primitive_type(cls, field_type: Any) -> bool:
        """Check if a type is exactly one of the primitive classes str, int, float, bool."""
        # The isinstance guard keeps unhashable annotation objects away from the
        # set membership test below.
        return isinstance(field_type, type) and field_type in {str, int, float, bool}

    @classmethod
    def is_id_field(cls, field_name: str) -> bool:
        """Check if a field name is the literal ``"id"`` field."""
        return field_name == "id"

    @classmethod
    def dict_max_depth(cls, field_type: Any, depth: int = 0) -> int:
        """Count the dict levels nested along the value-type chain.

        Starting from ``field_type``, this repeatedly descends into the value type
        (the second type argument) while the current type is a parameterized dict.
        Each level adds one to ``depth``. Any non-dict value type ends the descent,
        including an ``Optional``/union-wrapped dict.

        Args:
            field_type: The annotation to inspect.
            depth: Starting count; the number of levels found is added to it.

        Returns:
            ``depth`` plus the number of nested dict levels. For example, with the
            default ``depth=0``, ``dict[str, dict[str, int]]`` yields 2.
        """
        while cls.is_dict_type(field_type):
            args = get_args(field_type)
            # Bare ``typing.Dict`` carries no type arguments, so there is no value
            # type to descend into; stop without counting that level.
            if len(args) < 2:
                break
            field_type = args[1]  # Move to the value type
            depth += 1
        return depth

    @classmethod
    def is_base_model(cls, field_type: Any) -> TypeGuard[Type[BaseModel]]:
        """Check if a type is a pydantic ``BaseModel`` subclass (or BaseModel itself)."""
        if not isinstance(field_type, type):
            return False
        try:
            return issubclass(field_type, BaseModel)
        except TypeError:
            # On Python 3.10, generic aliases such as ``list[int]`` pass the
            # isinstance(..., type) check but make issubclass() raise; they are
            # never models.
            return False

    @classmethod
    def get_real_type(cls, field_type: Any) -> Any:
        """Get the real type from a potentially optional/union type.

        For ``Union[...]`` / ``X | Y`` annotations, ``NoneType`` members are dropped.
        The first remaining list alias is returned if there is one; otherwise the
        first remaining member is returned. Any other annotation is returned unchanged.
        """
        # A tuple (rather than a set) avoids hashing the origin object.
        if get_origin(field_type) in (Union, UnionType):
            members = [t for t in get_args(field_type) if t is not type(None)]
            if members:
                # Prefer a list member over other members of the union.
                return next((t for t in members if cls.is_list_type(t)), members[0])
        return field_type

    @classmethod
    def check_model_has_id(cls, model: Type[BaseModel]) -> bool:
        """Check if a model has an ``id`` entry in its resolved type hints.

        ``get_type_hints`` raises if a forward reference cannot be resolved; that
        error is not caught here.
        """
        return "id" in get_type_hints(model, include_extras=True)

    @classmethod
    def has_model_config(cls, model: Type[BaseModel]) -> bool:
        """Check if a model has a ``model_config`` attribute."""
        return hasattr(model, "model_config")

    @classmethod
    def check_model_requires_id(cls, model: Type[BaseModel]) -> bool:
        """Check if a model requires an id field based on its configuration.

        Returns True when:
            - the model has no ``model_config`` attribute,
            - the config is empty/falsy, or
            - ``config["extra"]`` is ``"forbid"`` or ``None``.

        Any other ``extra`` value returns False. A missing ``extra`` key in a
        non-empty config is treated as ``"ignore"`` and also returns False.
        """
        if not cls.has_model_config(model):
            return True

        config = model.model_config
        if not config:
            return True

        # NOTE(review): a non-empty config without an "extra" key falls back to
        # "ignore" (-> False), while an empty config returns True above, even though
        # pydantic applies its default extra handling in both cases. Confirm which
        # result is intended before relying on it.
        extra = config.get("extra", "ignore")
        return extra in ("forbid", None)