U
    h,$                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZmZm	Z	m
Z
mZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ G d
d deZdS )    )annotationsN)Path)AnyDictListOptionalTupleUnion)CallbackManagerForRetrieverRun)Document)	SecretStr)BaseRetriever)convert_to_secret_strget_from_dict_or_envpre_initc                   @  s  e Zd ZU dZded< dZded< G dd dZed<d	d
dddZe	d=d	dd dddZ
e	d>dd	d dddZedddddZd?ddddd
dddZd d d!d"d#Zd$d%d
d&d'd(Zd)d
d*d+d,Zd$d$d
d-d.d/Zd0d
d1d2d3Zd$d4dd5d6d7d8Zd$d
d9d:d;ZdS )@NeuralDBRetrieverz0Document retriever that uses ThirdAI's NeuralDB.r   thirdai_keyNr   dbc                   @  s   e Zd ZdZdZdS )zNeuralDBRetriever.ConfigZforbidTN)__name__
__module____qualname__extraZunderscore_attrs_are_private r   r   S/tmp/pip-unpacked-wheel-9gdii04g/langchain_community/retrievers/thirdai_neuraldb.pyConfig   s   r   zOptional[str]None)r   returnc                 C  sR   z0ddl m} tjd || p*td W n tk
rL   tdY nX d S )Nr   )	licensingzthirdai.neural_dbTHIRDAI_KEYz{Could not import thirdai python package and neuraldb dependencies. Please install it with `pip install thirdai[neural_db]`.)	thirdair   	importlibutil	find_specactivateosgetenvImportError)r   r   r   r   r   _verify_thirdai_library   s    z)NeuralDBRetriever._verify_thirdai_librarydict)r   model_kwargsr   c                 K  s*   t | ddlm} | ||jf |dS )a  
        Create a NeuralDBRetriever from scratch.

        To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
        API key, or pass ``thirdai_key`` as a named parameter.

        Example:
            .. code-block:: python

                from langchain_community.retrievers import NeuralDBRetriever

                retriever = NeuralDBRetriever.from_scratch(
                    thirdai_key="your-thirdai-key",
                )

                retriever.insert([
                    "/path/to/doc.pdf",
                    "/path/to/doc.docx",
                    "/path/to/doc.csv",
                ])

                documents = retriever.invoke("AI-driven music therapy")
        r   	neural_dbr   r   )r   r'   r   r+   NeuralDB)clsr   r)   ndbr   r   r   from_scratch*   s    
zNeuralDBRetriever.from_scratchzUnion[str, Path])
checkpointr   r   c                 C  s*   t | ddlm} | ||j|dS )a!  
        Create a NeuralDBRetriever with a base model from a saved checkpoint

        To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
        API key, or pass ``thirdai_key`` as a named parameter.

        Example:
            .. code-block:: python

                from langchain_community.retrievers import NeuralDBRetriever

                retriever = NeuralDBRetriever.from_checkpoint(
                    checkpoint="/path/to/checkpoint.ndb",
                    thirdai_key="your-thirdai-key",
                )

                retriever.insert([
                    "/path/to/doc.pdf",
                    "/path/to/doc.docx",
                    "/path/to/doc.csv",
                ])

                documents = retriever.invoke("AI-driven music therapy")
        r   r*   r,   )r   r'   r   r+   r-   from_checkpoint)r.   r1   r   r/   r   r   r   r2   L   s    
z!NeuralDBRetriever.from_checkpointr   )valuesr   c                 C  s   t t|dd|d< |S )z'Validate ThirdAI environment variables.r   r   )r   r   )r.   r3   r   r   r   validate_environmentso   s    z'NeuralDBRetriever.validate_environmentsTz	List[Any]bool)sourcestrain	fast_modekwargsr   c                 K  s(   |  |}| jjf |||d| dS )as  Inserts files / document sources into the retriever.

        Args:
            train: When True this means that the underlying model in the
            NeuralDB will undergo unsupervised pretraining on the inserted files.
            Defaults to True.
            fast_mode: Much faster insertion with a slight drop in performance.
            Defaults to True.
        )r6   r7   Zfast_approximationN)_preprocess_sourcesr   insert)selfr6   r7   r8   r9   r   r   r   r;   {   s    
zNeuralDBRetriever.insertlist)r6   r   c                 C  s   ddl m} |s|S g }|D ]}t|ts6|| q| drV||| q| drv||| q| dr||	| qt
d| dq|S )zChecks if the provided sources are string paths. If they are, convert
        to NeuralDB document objects.

        Args:
            sources: list of either string paths to PDF, DOCX or CSV files, or
            NeuralDB document objects.
        r   r*   z.pdfz.docxz.csvzCould not automatically load z. Only files with .pdf, .docx, or .csv extensions can be loaded automatically. For other formats, please use the appropriate document object from the ThirdAI library.)r   r+   
isinstancestrappendlowerendswithZPDFZDOCXZCSVRuntimeError)r<   r6   r/   Zpreprocessed_sourcesdocr   r   r   r:      s"    

z%NeuralDBRetriever._preprocess_sourcesr?   int)querydocument_idr   c                 C  s   | j || dS )a!  The retriever upweights the score of a document for a specific query.
        This is useful for fine-tuning the retriever to user behavior.

        Args:
            query: text to associate with `document_id`
            document_id: id of the document to associate query with.
        N)r   Ztext_to_result)r<   rF   rG   r   r   r   upvote   s    zNeuralDBRetriever.upvotezList[Tuple[str, int]])query_id_pairsr   c                 C  s   | j | dS )a  Given a batch of (query, document id) pairs, the retriever upweights
        the scores of the document for the corresponding queries.
        This is useful for fine-tuning the retriever to user behavior.

        Args:
            query_id_pairs: list of (query, document id) pairs. For each pair in
            this list, the model will upweight the document id for the query.
        N)r   Ztext_to_result_batch)r<   rI   r   r   r   upvote_batch   s    	zNeuralDBRetriever.upvote_batch)sourcetargetr   c                 C  s   | j || dS )a=  The retriever associates a source phrase with a target phrase.
        When the retriever sees the source phrase, it will also consider results
        that are relevant to the target phrase.

        Args:
            source: text to associate to `target`.
            target: text to associate `source` to.
        N)r   	associate)r<   rK   rL   r   r   r   rM      s    	zNeuralDBRetriever.associatezList[Tuple[str, str]])
text_pairsr   c                 C  s   | j | dS )a.  Given a batch of (source, target) pairs, the retriever associates
        each source phrase with the corresponding target phrase.

        Args:
            text_pairs: list of (source, target) text pairs. For each pair in
            this list, the source will be associated with the target.
        N)r   associate_batch)r<   rN   r   r   r   rO      s    z!NeuralDBRetriever.associate_batchr
   zList[Document])rF   run_managerr9   r   c              
   K  sn   z6d|krd|d< | j jf d|i|}dd |D W S  tk
rh } ztd| |W 5 d}~X Y nX dS )zRetrieve {top_k} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            top_k: The max number of context results to retrieve. Defaults to 10.
        Ztop_k
   rF   c                 S  s8   g | ]0}t |j|j|j|j|j|j|d ddqS )   )id
upvote_idsrK   metadatascorecontext)Zpage_contentrU   )r   textrS   rT   rK   rU   rV   rW   ).0refr   r   r   
<listcomp>   s   z=NeuralDBRetriever._get_relevant_documents.<locals>.<listcomp>z"Error while retrieving documents: N)r   search	Exception
ValueError)r<   rF   rP   r9   Z
referenceser   r   r   _get_relevant_documents   s    	z)NeuralDBRetriever._get_relevant_documents)pathr   c                 C  s   | j | dS )zSaves a NeuralDB instance to disk. Can be loaded into memory by
        calling NeuralDB.from_checkpoint(path)

        Args:
            path: path on disk to save the NeuralDB instance to.
        N)r   save)r<   ra   r   r   r   rb      s    zNeuralDBRetriever.save)N)N)N)TT)r   r   r   __doc____annotations__r   r   staticmethodr'   classmethodr0   r2   r   r4   r;   r:   rH   rJ   rM   rO   r`   rb   r   r   r   r   r      s0   
 ! "   

r   )
__future__r   r    r$   pathlibr   typingr   r   r   r   r   r	   Zlangchain_core.callbacksr
   Zlangchain_core.documentsr   Zlangchain_core.pydantic_v1r   Zlangchain_core.retrieversr   Zlangchain_core.utilsr   r   r   r   r   r   r   r   <module>   s    