U
    hCP                     @   sx  d dl Z d dlZd dlZd dlmZmZ d dlmZmZm	Z	m
Z
mZmZ d dlZd dlmZ d dlmZ d dlmZmZmZmZmZ dgZG dd	 d	eeZe	eef ed
ddZe	eef ed
ddZe	eef ed
ddZG dd deZG dd deZedddZ edddZ!edddZ"ee#dddZ$eee# edd d!Z%eed"d#d$Z&G d%d deZ'dS )&    N)ABCabstractmethod)AnyCallableDictListMappingOptional)CallbackManagerForLLMRun)LLM)	BaseModelFieldPrivateAttrroot_validator	validator
Databricksc                   @   s   e Zd ZU dZeed< eed< eeeedddZeeddd	Zeeed
ddZ	e
deeedef  edddZeedddZdS )_DatabricksClientBasez0A base JSON API client that talks to Databricks.api_url	api_token)methodurlrequestreturnc                 C   sH   dd| j  i}tj||||d}|js@td|j d|j | S )NAuthorizationzBearer )r   r   headersjsonzHTTP z error: )r   requestsr   ok
ValueErrorstatus_codetextr   )selfr   r   r   r   response r#   G/tmp/pip-unpacked-wheel-9gdii04g/langchain_community/llms/databricks.pyr      s       z_DatabricksClientBase.request)r   r   c                 C   s   |  d|d S )NGETr   )r!   r   r#   r#   r$   _get%   s    z_DatabricksClientBase._get)r   r   r   c                 C   s   |  d||S )NPOSTr&   )r!   r   r   r#   r#   r$   _post(   s    z_DatabricksClientBase._postN.r   transform_output_fnr   c                 C   s   d S Nr#   )r!   r   r+   r#   r#   r$   post+   s    z_DatabricksClientBase.postr   c                 C   s   dS )NFr#   r!   r#   r#   r$   llm0   s    z_DatabricksClientBase.llm)N)__name__
__module____qualname____doc__str__annotations__r   r   r'   r)   r   r	   r   r-   propertyboolr0   r#   r#   r#   r$   r      s   

  r   )r"   r   c                 C   s   | d d d S )Nchoicesr   r    r#   r"   r#   r#   r$   _transform_completions5   s    r;   c                 C   s   | d d d S )N
candidatesr   r    r#   r:   r#   r#   r$   _transform_llama2_chat9   s    r=   c                 C   s   | d d d d S )Nr9   r   messagecontentr#   r:   r#   r#   r$   _transform_chat=   s    r@   c                       s   e Zd ZU dZeed< eed< eed< dZeed< dZe	ed< dZ
ee ed	< ed
 fddZee	dddZeddeeef eeef dddZdeeedef  edddZ  ZS ) _DatabricksServingEndpointClientz:An API client that talks to a Databricks serving endpoint.hostendpoint_namedatabricks_uriNclientFexternal_or_foundationtaskdatac              
      s   t  jf | zddlm} || j| _W n, tk
rV } ztd|W 5 d }~X Y nX | j| j}|	dd
 dk| _| jd kr|	d| _d S )Nr   )get_deploy_clientzMFailed to create the client. Please install mlflow with `pip install mlflow`.Zendpoint_type )Zexternal_modelZfoundation_model_apirG   )super__init__Zmlflow.deploymentsrJ   rD   rE   ImportErrorZget_endpointrC   getlowerrF   rG   )r!   rI   rJ   eendpoint	__class__r#   r$   rM   K   s    
z)_DatabricksServingEndpointClient.__init__r.   c                 C   s
   | j dkS )N)llm/v1/chatllm/v1/completionsllama2/chat)rG   r/   r#   r#   r$   r0   `   s    z$_DatabricksServingEndpointClient.llmTprevaluesr   c                 C   s6   d|kr2|d }|d }d| d| d}||d< |S )Nr   rB   rC   https://z/serving-endpoints/z/invocationsr#   )clsr[   rB   rC   r   r#   r#   r$   set_api_urld   s    z,_DatabricksServingEndpointClient.set_api_url.r*   c                 C   s   | j rL| jj| j|d}|r$||S | jdkr6t|S | jdkrHt|S |S d|gi}| jj| j|d}|d }t|tr|d n|}| jdkrt	|S |r||S |S d S )N)rR   inputsrU   rV   Zdataframe_recordsZpredictionsr   rW   )
rF   rE   ZpredictrC   rG   r@   r;   
isinstancelistr=   )r!   r   r+   respZwrapped_requestr"   predspredr#   r#   r$   r-   m   s&    


 
z%_DatabricksServingEndpointClient.post)N)r1   r2   r3   r4   r5   r6   rE   r   rF   r8   rG   r	   rM   r7   r0   r   r   r^   r   r-   __classcell__r#   r#   rS   r$   rA   A   s$   
"	  rA   c                   @   st   e Zd ZU dZeed< eed< eed< eddeeef eeef ddd	Z	dee
edef  edddZd
S )#_DatabricksClusterDriverProxyClientzBAn API client that talks to a Databricks cluster driver proxy app.rB   
cluster_idcluster_driver_portTrX   rZ   c                 C   sB   d|kr>|d }|d }|d }d| d| d| }||d< |S )Nr   rB   rg   rh   r\   z/driver-proxy-api/o/0//r#   )r]   r[   rB   rg   portr   r#   r#   r$   r^      s    z/_DatabricksClusterDriverProxyClient.set_api_urlN.r*   c                 C   s   |  | j|}|r||S |S r,   )r)   r   )r!   r   r+   rb   r#   r#   r$   r-      s    z(_DatabricksClusterDriverProxyClient.post)N)r1   r2   r3   r4   r5   r6   r   r   r   r^   r	   r   r-   r#   r#   r#   r$   rf      s   
"
  rf   r.   c                  C   s6   zddl m}  |  W S  tk
r0   tdY nX dS )zgGet the notebook REPL context if running inside a Databricks notebook.
    Returns None otherwise.
    r   get_contextzBCannot access dbruntime, not running inside a Databricks notebook.N)Z!dbruntime.databricks_repl_contextrl   rN   rk   r#   r#   r$   get_repl_context   s    rm   c               
   C   sr   t d} | sXzt j} | s$tdW n0 tk
rV } ztd| W 5 d}~X Y nX | ddd} | S )z{Get the default Databricks workspace hostname.
    Raises an error if the hostname cannot be automatically determined.
    ZDATABRICKS_HOSTz(context doesn't contain browserHostName.zshost was not set and cannot be automatically inferred. Set environment variable 'DATABRICKS_HOST'. Received error: Nr\   zhttp://ri   )osgetenvrm   ZbrowserHostNamer   	Exceptionlstriprstrip)rB   rQ   r#   r#   r$   get_default_host   s    
rs   c               
   C   s`   t d } r| S zt j} | s(tdW n0 tk
rZ } ztd| W 5 d}~X Y nX | S )z{Get the default Databricks personal access token.
    Raises an error if the token cannot be automatically determined.
    ZDATABRICKS_TOKENz!context doesn't contain apiToken.zyapi_token was not set and cannot be automatically inferred. Set environment variable 'DATABRICKS_TOKEN'. Received error: N)rn   ro   rm   ZapiTokenr   rp   )r   rQ   r#   r#   r$   get_default_api_token   s    rt   )rI   r   c                 C   s"   t | tsdS d}tt|| S )zJChecks if a data is a valid hexadecimal string using a regular expression.Fz^[0-9a-fA-F]+$)r`   r5   r8   rematch)rI   patternr#   r#   r$   _is_hex_string   s    
rx   )rI   allow_dangerous_deserializationr   c              
   C   s   |st dzddl}W n0 tk
rH } zt d| W 5 d}~X Y nX z|t| W S  tk
r } zt d| W 5 d}~X Y nX dS )z3Loads a pickled function from a hexadecimal string.aW  This code relies on the pickle module. You will need to set allow_dangerous_deserialization=True if you want to opt-in to allow deserialization of data using pickle.Data can be compromised by a malicious actor if not handled properly to include a malicious payload that when deserialized with pickle can execute arbitrary code on your machine.r   N*Please install cloudpickle>=2.0.0. Error: zFFailed to load the pickled function from a hexadecimal string. Error: )r   cloudpicklerp   loadsbytesfromhex)rI   ry   r{   rQ   r#   r#   r$    _load_pickled_fn_from_hex_string   s    
 r   )fnr   c              
   C   s   zddl }W n0 tk
r< } ztd| W 5 d}~X Y nX z||  W S  tk
r~ } ztd| W 5 d}~X Y nX dS )z6Pickles a function and returns the hexadecimal string.r   Nrz   zFailed to pickle the function: )r{   rp   r   dumpshex)r   r{   rQ   r#   r#   r$   _pickle_fn_to_hex_string   s     r   c                       sb  e Zd ZU dZeedZeed< ee	dZ
eed< dZee ed< dZee ed< dZee ed< dZeeeef  ed	< dZee ed
< dZeedef  ed< dZeed< dZeed< dZeed< dZeee  ed< dZee ed< eedZeeef ed< dZee ed< dZ e!ed< e" Z#e$ed< G dd dZ%e&eeef dddZ'e(ddd eeeef ee d!d"d#Z)e(ddd eeeef ee d!d$d%Z*e(d	dd eeeef  eeeef  d&d'd(Z+ed) fd*d+Z,e&eeef dd,d-Z-e&e.eef dd.d/Z/e&edd0d1Z0d5eeee  ee1 eed2d3d4Z2  Z3S )6r   a	  Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected model
      signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response from the
      endpoint is automatically transformed to the expected format unless
      ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using HTTP POST method
      with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server listen to
      the driver IP address or simply ``0.0.0.0`` instead of localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}`

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra params,
    you can use `transform_input_fn` and `transform_output_fn` to apply necessary
    transformations before and after the query.
    )default_factoryrB   r   NrC   rg   rh   model_kwargstransform_input_fn.r+   
databricksrD   g        temperature   nstop
max_tokensextra_paramsrG   Fry   _clientc                   @   s   e Zd ZdZdZdS )zDatabricks.ConfigZforbidTN)r1   r2   r3   extraZunderscore_attrs_are_privater#   r#   r#   r$   Config  s   r   r.   c                 C   s6   | j | jd}| jr| j|d< | jd k	r2| j|d< |S )N)r   r   r   r   )r   r   r   r   )r!   paramsr#   r#   r$   _llm_params  s    


zDatabricks._llm_paramsT)always)vr[   r   c              
   C   s~   |r|d rt dnd|d r"d S |r*|S zt j }r>|W S t dW n0 tk
rx } zt d| W 5 d }~X Y nX d S )NrC   z-Cannot set both endpoint_name and cluster_id.z"Context doesn't contain clusterId.zuNeither endpoint_name nor cluster_id was set. And the cluster_id cannot be automatically determined. Received error: )r   rm   Z	clusterIdrp   )r]   r   r[   rQ   r#   r#   r$   set_cluster_id  s    
zDatabricks.set_cluster_idc                 C   sX   |r|d rt dn>|d r"d S |d kr4t dn t|dkrPt d| n|S d S )NrC   z6Cannot set both endpoint_name and cluster_driver_port.z<Must set cluster_driver_port to connect to a cluster driver.r   zInvalid cluster_driver_port: )r   int)r]   r   r[   r#   r#   r$   set_cluster_driver_port  s    
z"Databricks.set_cluster_driver_port)r   r   c                 C   s(   |r$d|kst dd|ks$t d|S )Npromptz*model_kwargs must not contain key 'prompt'r   z(model_kwargs must not contain key 'stop')AssertionError)r]   r   r#   r#   r$   set_model_kwargs  s    zDatabricks.set_model_kwargsrH   c                    s   d|kr.t |d r.t|d |dd|d< d|kr\t |d r\t|d |dd|d< t jf | | jd k	r| jd k	rtdn| jd k	rt	dt
 | jrt| j| j| j| j| jd| _n0| jr| jrt| j| j| j| jd| _ntd	d S )
Nr   ry   )rI   ry   r+   z.Cannot set both extra_params and extra_params.z<model_kwargs is deprecated. Please use extra_params instead.)rB   r   rC   rD   rG   )rB   r   rg   rh   zDMust specify either endpoint_name or cluster_id/cluster_driver_port.)rx   r   rO   rL   rM   r   r   r   warningswarnDeprecationWarningrC   rA   rB   r   rD   rG   r   rg   rh   rf   )r!   rI   rS   r#   r$   rM     sT    






zDatabricks.__init__c                 C   sb   | j | j| j| j| j| j| j| j| j| j	| j
| j| jdkr>dnt| j| jdkrTdnt| jdS )zReturn default params.N)rB   rC   rg   rh   rD   r   r   r   r   r   r   rG   r   r+   )rB   rC   rg   rh   rD   r   r   r   r   r   r   rG   r   r   r+   r/   r#   r#   r$   _default_params  s&    zDatabricks._default_paramsc                 C   s   | j S r,   )r   r/   r#   r#   r$   _identifying_params  s    zDatabricks._identifying_paramsc                 C   s   dS )zReturn type of llm.r   r#   r/   r#   r#   r$   	_llm_type  s    zDatabricks._llm_type)r   r   run_managerkwargsr   c                 K   sh   d|i}| j jr|| j || jp*| j || |rD||d< | jrV| jf |}| j j|| jdS )zAQueries the LLM endpoint with the given prompt and stop sequence.r   r   )r+   )	r   r0   updater   r   r   r   r-   r+   )r!   r   r   r   r   r   r#   r#   r$   _call  s    
zDatabricks._call)NN)4r1   r2   r3   r4   r   rs   rB   r5   r6   rt   r   rC   r	   rg   rh   r   r   r   r   r   r+   rD   r   floatr   r   r   r   r   dictr   rG   ry   r8   r   r   r   r   r7   r   r   r   r   r   rM   r   r   r   r   r
   r   re   r#   r#   rS   r$   r     sV   
1

	
 
 
*.  
)(rn   ru   r   abcr   r   typingr   r   r   r   r   r	   r   Zlangchain_core.callbacksr
   Zlangchain_core.language_modelsr   Zlangchain_core.pydantic_v1r   r   r   r   r   __all__r   r5   r;   r=   r@   rA   rf   rm   rs   rt   r8   rx   r   r   r   r#   r#   r#   r$   <module>   s2     H	 