Coverage for src/fluree_py/query/select/pydantic/type_checker.py: 96%
54 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 03:03 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-02 03:03 +0000
1from types import UnionType
2from pydantic import BaseModel
5from typing import (
6 Any,
7 List,
8 Type,
9 TypeGuard,
10 Union,
11 get_args,
12 get_origin,
13 get_type_hints,
14)
17class TypeChecker:
18 """Handles type checking and validation for Pydantic models.
20 This class provides utility methods for checking various types that can appear
21 in Pydantic models, such as lists, dictionaries, and nested models.
22 """
24 @classmethod
25 def is_list_type(cls, field_type: Any) -> TypeGuard[Type[List[Any]]]:
26 """Check if a type is a list type."""
27 return hasattr(field_type, "__origin__") and field_type.__origin__ is list
29 @classmethod
30 def is_dict_type(cls, field_type: Any) -> TypeGuard[Type[dict[str, Any]]]:
31 """Check if a type is a dict type."""
32 return hasattr(field_type, "__origin__") and field_type.__origin__ is dict
34 @classmethod
35 def is_tuple_type(cls, field_type: Any) -> TypeGuard[Type[tuple[Any, ...]]]:
36 """Check if a type is a tuple type."""
37 return hasattr(field_type, "__origin__") and field_type.__origin__ is tuple
39 @classmethod
40 def is_primitive_type(cls, field_type: Any) -> bool:
41 """Check if a type is a primitive type (str, int, float, bool)."""
42 return isinstance(field_type, type) and field_type in {str, int, float, bool}
44 @classmethod
45 def is_id_field(cls, field_name: str) -> bool:
46 """Check if a field name is the id field."""
47 return field_name == "id"
49 @classmethod
50 def dict_max_depth(cls, field_type: Any, depth: int = 0) -> int:
51 """Recursively count dictionary nesting depth."""
52 while TypeChecker.is_dict_type(field_type):
53 args = get_args(field_type)
54 if not args or len(args) < 2:
55 break
56 field_type = args[1] # Move to the value type
57 depth += 1
58 return depth
60 @classmethod
61 def is_base_model(cls, field_type: Any) -> TypeGuard[Type[BaseModel]]:
62 """Check if a type is a BaseModel."""
63 return isinstance(field_type, type) and issubclass(field_type, BaseModel)
65 @classmethod
66 def get_real_type(cls, field_type: Any) -> Any:
67 """Get the real type from a potentially optional/union type."""
68 origin = get_origin(field_type)
69 if origin in {Union, UnionType}:
70 types = [t for t in get_args(field_type) if t is not type(None)]
71 if types:
72 return next((t for t in types if cls.is_list_type(t)), types[0])
73 return field_type
75 @classmethod
76 def check_model_has_id(cls, model: Type[BaseModel]) -> bool:
77 """Check if a model has an id field in its type hints."""
78 return "id" in get_type_hints(model, include_extras=True)
80 @classmethod
81 def has_model_config(cls, model: Type[BaseModel]) -> bool:
82 """Check if a model has a model_config attribute."""
83 return hasattr(model, "model_config")
85 @classmethod
86 def check_model_requires_id(cls, model: Type[BaseModel]) -> bool:
87 """Check if a model requires an id field based on its configuration."""
88 if not cls.has_model_config(model):
89 return True
91 config = model.model_config
92 if not config:
93 return True
95 extra = config.get("extra", "ignore")
96 return extra in ("forbid", None)