"""written under MIT Licence, Michael Feil 2023."""

import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator

__all__ = ["InfinityEmbeddingsLocal"]

logger = getLogger(__name__)


class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity
    This class deploys a local Infinity instance to embed text.
    The class requires async usage.

    Infinity is a class to interact with Embedding Models on https://github.com/michaelfeil/infinity


    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal
            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await engine.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"

    revision: Optional[str] = None
    "Model version, the commit hash from huggingface"

    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"

    backend: str = "torch"
    "Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)"
    " or 'optimum' for onnx/tensorrt"

    model_warmup: bool = True
    "Warmup the model with the max batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    # LLM call kwargs
    class Config:
        extra = "forbid"

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""

        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            )
        logger.debug(f"Using InfinityEmbeddingsLocal with kwargs {values}")

        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> None:
        """start the background worker.
        recommended usage is with the async with statement.

        async with InfinityEmbeddingsLocal(
            model="BAAI/bge-small-en-v1.5",
            revision=None,
            device="cpu",
        ) as embedder:
            embeddings = await engine.aembed_documents(["text1", "text2"])
        """
        await self.engine.__aenter__()

    async def __aexit__(self, *args: Any) -> None:
        """stop the background worker,
        required to free references to the pytorch model."""
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        if not self.engine.running:
            logger.warning(
                "Starting Infinity engine on the fly. This is not recommended."
                "Please start the engine before using it."
            )
            async with self:
                # spawning threadpool for multithreaded encode, tokenization
                embeddings, _ = await self.engine.embed(texts)
            # stopping threadpool on exit
            logger.warning("Stopped infinity engine after usage.")
        else:
            embeddings, _ = await self.engine.embed(texts)
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        This method is async only.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """ """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))
