from __future__ import annotations as _annotations

from collections.abc import Iterable, Iterator, Mapping
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from datetime import date, time, timedelta
from decimal import Decimal
from enum import Enum
from typing import Any, Literal
from uuid import UUID
from xml.etree import ElementTree

from pydantic import BaseModel
from pydantic_core import PydanticSerializationError

__all__ = ('format_as_xml',)

from pydantic.fields import ComputedFieldInfo, FieldInfo


def format_as_xml(
    obj: Any,
    root_tag: str | None = None,
    item_tag: str = 'item',
    none_str: str = 'null',
    indent: str | None = '  ',
    include_field_info: Literal['once'] | bool = False,
) -> str:
    """Format a Python object as XML.

    This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
    rather than JSON etc.

    Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `Decimal`, `date`, `datetime`, `time`, `timedelta`,
    `UUID`, `Enum`, `Mapping`, `Iterable`, `dataclass`, and `BaseModel`.

    Args:
        obj: Python Object to serialize to XML.
        root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
        item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
            for dataclasses and Pydantic models.
        none_str: String to use for `None` values.
        indent: Indentation string to use for pretty printing.
        include_field_info: Whether to include attributes like Pydantic `Field` attributes and dataclasses `field()`
            `metadata` as XML attributes. In both cases the allowed `Field` attributes and `field()` metadata keys are
            `title` and `description`. If a field is repeated in the data (e.g. in a list) by setting `once`
            the attributes are included only in the first occurrence of an XML element relative to the same field.

    Returns:
        XML representation of the object.

    Example:
    ```python {title="format_as_xml_example.py" lint="skip"}
    from pydantic_ai import format_as_xml

    print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
    '''
    <user>
      <name>John</name>
      <height>6</height>
      <weight>200</weight>
    </user>
    '''
    ```
    """
    el = _ToXml(
        data=obj,
        item_tag=item_tag,
        none_str=none_str,
        include_field_info=include_field_info,
    ).to_xml(root_tag)
    if root_tag is None and el.text is None:
        join = '' if indent is None else '\n'
        return join.join(_rootless_xml_elements(el, indent))
    else:
        if indent is not None:
            ElementTree.indent(el, space=indent)
        return ElementTree.tostring(el, encoding='unicode')


@dataclass
class _ToXml:
    data: Any
    item_tag: str
    none_str: str
    include_field_info: Literal['once'] | bool
    # a map of Pydantic and dataclasses Field paths to their metadata:
    # a field unique string representation and its class
    _fields_info: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] = field(
        default_factory=dict[str, tuple[str, FieldInfo | ComputedFieldInfo]]
    )
    # keep track of fields we have extracted attributes from
    _included_fields: set[str] = field(default_factory=set[str])
    # keep track of class names for dataclasses and Pydantic models, that occur in lists
    _element_names: dict[str, str] = field(default_factory=dict[str, str])
    # flag for parsing dataclasses and Pydantic models once
    _is_info_extracted: bool = False
    _FIELD_ATTRIBUTES = ('title', 'description')

    def to_xml(self, tag: str | None = None) -> ElementTree.Element:
        return self._to_xml(value=self.data, path='', tag=tag)

    def _to_xml(self, value: Any, path: str, tag: str | None = None) -> ElementTree.Element:
        element = self._create_element(self.item_tag if tag is None else tag, path)
        if self._set_scalar_text(element, value):
            return element
        if isinstance(value, Mapping):
            if tag is None and path in self._element_names:
                element.tag = self._element_names[path]
            self._mapping_to_xml(element, value, path)  # pyright: ignore[reportUnknownArgumentType]
            return element
        if is_dataclass(value) and not isinstance(value, type):
            self._init_structure_info()
            if tag is None:
                element.tag = value.__class__.__name__
            self._mapping_to_xml(element, asdict(value), path)
            return element
        if isinstance(value, BaseModel):
            self._init_structure_info()
            if tag is None:
                element.tag = value.__class__.__name__
            # by dumping the model we loose all metadata in nested data structures,
            # but we have collected it when called _init_structure_info
            try:
                mapping = value.model_dump(mode='json')
            except PydanticSerializationError as e:
                raise TypeError(f'Unsupported type for XML formatting: {e}') from e
            self._mapping_to_xml(element, mapping, path)
            return element
        if isinstance(value, Iterable):
            for n, item in enumerate(value):  # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
                element.append(self._to_xml(value=item, path=f'{path}.[{n}]' if path else f'[{n}]'))
            return element
        raise TypeError(f'Unsupported type for XML formatting: {type(value)}')

    def _set_scalar_text(self, element: ElementTree.Element, value: Any) -> bool:
        """Set element.text for scalar types. Return True if handled, False otherwise."""
        if value is None:
            element.text = self.none_str
        elif isinstance(value, str):
            element.text = value.value if isinstance(value, Enum) else value
        elif isinstance(value, bytes | bytearray):
            element.text = value.decode(errors='ignore')
        elif isinstance(value, bool | int | float | Enum):
            element.text = str(value)
        elif isinstance(value, date | time):
            element.text = value.isoformat()
        elif isinstance(value, timedelta):
            element.text = str(value)
        elif isinstance(value, Decimal):
            element.text = str(value)
        elif isinstance(value, UUID):
            element.text = str(value)
        else:
            return False
        return True

    def _create_element(self, tag: str, path: str) -> ElementTree.Element:
        element = ElementTree.Element(tag)
        if path in self._fields_info:
            field_repr, field_info = self._fields_info[path]
            if self.include_field_info and self.include_field_info != 'once' or field_repr not in self._included_fields:
                field_attributes = self._extract_attributes(field_info)
                for k, v in field_attributes.items():
                    element.set(k, v)
                self._included_fields.add(field_repr)
        return element

    def _init_structure_info(self):
        """Create maps with all data information (fields info and class names), if not already created."""
        if not self._is_info_extracted:
            self._parse_data_structures(self.data)
            self._is_info_extracted = True

    def _mapping_to_xml(
        self,
        element: ElementTree.Element,
        mapping: Mapping[Any, Any],
        path: str = '',
    ) -> None:
        for key, value in mapping.items():
            if isinstance(key, int):
                key = str(key)
            elif not isinstance(key, str):
                raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
            element.append(self._to_xml(value=value, path=f'{path}.{key}' if path else key, tag=key))

    def _parse_data_structures(
        self,
        value: Any,
        path: str = '',
    ):
        """Parse data structures as dataclasses or Pydantic models to extract element names and attributes."""
        if value is None or isinstance(value, (str | int | float | date | time | timedelta | bytearray | bytes | bool)):
            return
        elif isinstance(value, Mapping):
            for k, v in value.items():  # pyright: ignore[reportUnknownVariableType]
                self._parse_data_structures(v, f'{path}.{k}' if path else f'{k}')
        elif is_dataclass(value) and not isinstance(value, type):
            self._element_names[path] = value.__class__.__name__
            for field in fields(value):
                new_path = f'{path}.{field.name}' if path else field.name
                if self.include_field_info and field.metadata:
                    attributes = {k: v for k, v in field.metadata.items() if k in self._FIELD_ATTRIBUTES}
                    if attributes:
                        field_repr = f'{value.__class__.__name__}.{field.name}'
                        self._fields_info[new_path] = (field_repr, FieldInfo(**attributes))
                self._parse_data_structures(getattr(value, field.name), new_path)
        elif isinstance(value, BaseModel):
            self._element_names[path] = value.__class__.__name__
            for model_fields in (value.__class__.model_fields, value.__class__.model_computed_fields):
                for field, info in model_fields.items():
                    new_path = f'{path}.{field}' if path else field
                    if self.include_field_info and (isinstance(info, ComputedFieldInfo) or not info.exclude):
                        field_repr = f'{value.__class__.__name__}.{field}'
                        self._fields_info[new_path] = (field_repr, info)
                    self._parse_data_structures(getattr(value, field), new_path)
        elif isinstance(value, Iterable):
            for n, item in enumerate(value):  # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
                new_path = f'{path}.[{n}]' if path else f'[{n}]'
                self._parse_data_structures(item, new_path)

    @classmethod
    def _extract_attributes(cls, info: FieldInfo | ComputedFieldInfo) -> dict[str, str]:
        return {attr: str(value) for attr in cls._FIELD_ATTRIBUTES if (value := getattr(info, attr, None)) is not None}


def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
    for sub_element in root:
        if indent is not None:
            ElementTree.indent(sub_element, space=indent)
        yield ElementTree.tostring(sub_element, encoding='unicode')
