Source code for xsdata.formats.dataclass.context

import sys
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Type

from xsdata.exceptions import XmlContextError
from xsdata.formats.bindings import T
from xsdata.formats.dataclass.compat import class_types
from xsdata.formats.dataclass.models.builders import XmlMetaBuilder
from xsdata.formats.dataclass.models.elements import XmlMeta
from xsdata.models.enums import DataType
from xsdata.utils.constants import return_input


[docs] class XmlContext: """ The service provider for binding operations' metadata. :param element_name_generator: Default element name generator :param attribute_name_generator: Default attribute name generator :param class_type: Default class type `dataclasses` :param models_package: Restrict auto locate to a specific package """ __slots__ = ( "element_name_generator", "attribute_name_generator", "class_type", "cache", "xsi_cache", "sys_modules", "models_package", ) def __init__( self, element_name_generator: Callable = return_input, attribute_name_generator: Callable = return_input, class_type: str = "dataclasses", models_package: Optional[str] = None, ): self.element_name_generator = element_name_generator self.attribute_name_generator = attribute_name_generator self.class_type = class_types.get_type(class_type) self.cache: Dict[Type, XmlMeta] = {} self.xsi_cache: Dict[str, List[Type]] = defaultdict(list) self.models_package = models_package self.sys_modules = 0 def reset(self): self.cache.clear() self.xsi_cache.clear() self.sys_modules = 0 def get_builder( self, globalns: Optional[Dict[str, Callable]] = None ) -> XmlMetaBuilder: return XmlMetaBuilder( class_type=self.class_type, element_name_generator=self.element_name_generator, attribute_name_generator=self.attribute_name_generator, globalns=globalns, )
[docs] def fetch( self, clazz: Type, parent_ns: Optional[str] = None, xsi_type: Optional[str] = None, ) -> XmlMeta: """ Fetch the model metadata of the given dataclass type, namespace and xsi type. :param clazz: The requested dataclass type :param parent_ns: The inherited parent namespace :param xsi_type: if present it means that the given clazz is derived and the lookup procedure needs to check and match a dataclass model to the qualified name instead """ meta = self.build(clazz, parent_ns) subclass = None if xsi_type and meta.target_qname != xsi_type: subclass = self.find_subclass(clazz, xsi_type) return self.build(subclass, parent_ns) if subclass else meta
[docs] def build_xsi_cache(self): """Index all imported dataclasses by their xsi:type qualified name.""" if len(sys.modules) == self.sys_modules: return self.xsi_cache.clear() builder = self.get_builder() for clazz in self.get_subclasses(object): if self.is_binding_model(clazz): meta = builder.build_class_meta(clazz) if meta.target_qname: self.xsi_cache[meta.target_qname].append(clazz) self.sys_modules = len(sys.modules)
def is_binding_model(self, clazz: Type[T]) -> bool: if not self.class_type.is_model(clazz): return False return not self.models_package or ( hasattr(clazz, "__module__") and isinstance(clazz.__module__, str) and clazz.__module__.startswith(self.models_package) )
[docs] def find_types(self, qname: str) -> List[Type[T]]: """ Find all classes that match the given xsi:type qname. - Ignores native schema types, xs:string, xs:float, xs:int, ... - Rebuild cache if new modules were imported since last run :param qname: Qualified name """ if not DataType.from_qname(qname): self.build_xsi_cache() if qname in self.xsi_cache: return self.xsi_cache[qname] return []
[docs] def find_type(self, qname: str) -> Optional[Type[T]]: """ Return the most recently imported class that matches the given xsi:type qname. :param qname: Qualified name """ types: List[Type] = self.find_types(qname) return types[-1] if types else None
[docs] def find_type_by_fields(self, field_names: Set[str]) -> Optional[Type[T]]: """ Find a dataclass from all the imported modules that matches the given list of field names. :param field_names: A unique list of field names """ def get_field_diff(clazz: Type) -> int: meta = self.cache[clazz] local_names = {var.local_name for var in meta.get_all_vars()} return len(local_names - field_names) self.build_xsi_cache() choices = [ (clazz, get_field_diff(clazz)) for types in self.xsi_cache.values() for clazz in types if self.local_names_match(field_names, clazz) ] choices.sort(key=lambda x: (x[1], x[0].__name__)) return choices[0][0] if len(choices) > 0 else None
[docs] def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]: """ Compare all classes that match the given xsi:type qname and return the first one that is either a subclass or shares the same parent class as the original class. :param clazz: The search dataclass type :param qname: Qualified name """ types: List[Type] = self.find_types(qname) for tp in types: # Why would an xml node with have an xsi:type that points # to parent class is beyond me but it happens, let's protect # against that scenario <node xsi:type="nodeAbstract" /> if issubclass(clazz, tp): continue for tp_mro in tp.__mro__: if tp_mro is not object and tp_mro in clazz.__mro__: return tp return None
[docs] def build( self, clazz: Type, parent_ns: Optional[str] = None, globalns: Optional[Dict[str, Callable]] = None, ) -> XmlMeta: """ Fetch from cache or build the binding metadata for the given class and parent namespace. :param clazz: A dataclass type :param parent_ns: The inherited parent namespace """ if clazz not in self.cache: builder = self.get_builder(globalns) self.cache[clazz] = builder.build(clazz, parent_ns) return self.cache[clazz]
[docs] def build_recursive(self, clazz: Type, parent_ns: Optional[str] = None): """Build the binding metadata for the given class and all of its dependencies.""" if clazz not in self.cache: meta = self.build(clazz, parent_ns) for var in meta.get_all_vars(): types = var.element_types if var.elements else var.types for tp in types: if self.class_type.is_model(tp): self.build_recursive(tp, meta.namespace)
def local_names_match(self, names: Set[str], clazz: Type) -> bool: try: meta = self.build(clazz) local_names = {var.local_name for var in meta.get_all_vars()} return not names.difference(local_names) except (XmlContextError, NameError, TypeError): # The dataclass includes unsupported typing annotations # Let's remove it from xsi_cache builder = self.get_builder() target_qname = builder.build_class_meta(clazz).target_qname if target_qname and target_qname in self.xsi_cache: self.xsi_cache[target_qname].remove(clazz) return False
[docs] @classmethod def is_derived(cls, obj: Any, clazz: Type) -> bool: """ Return whether the given obj is derived from the given dataclass type. :param obj: A dataclass instance :param clazz: A dataclass type """ if obj is None: return False if isinstance(obj, clazz): return True return any(x is not object and isinstance(obj, x) for x in clazz.__bases__)
@classmethod def get_subclasses(cls, clazz: Type): try: for subclass in clazz.__subclasses__(): yield from cls.get_subclasses(subclass) yield subclass except TypeError: pass