HEX
Server: Apache
System: Linux scp1.abinfocom.com 5.4.0-216-generic #236-Ubuntu SMP Fri Apr 11 19:53:21 UTC 2025 x86_64
User: confeduphaar (1010)
PHP: 8.1.33
Disabled: exec,passthru,shell_exec,system
Upload Files
File: //lib/mysqlsh/lib/python3.8/site-packages/oci/addons/adk/agent_client.py
# coding: utf-8
# Copyright (c) 2016, 2025, Oracle and/or its affiliates.  All rights reserved.
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.

"""
ADK Agent Client
"""

from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import oci
from oci.constants import HEADER_REQUEST_ID
from oci.exceptions import ClientError, ServiceError
from oci.generative_ai_agent import GenerativeAiAgentClient
from oci.generative_ai_agent.models import (
    CreateToolDetails,
    Function,
    FunctionCallingToolConfig,
    KnowledgeBaseConfig,
    LlmConfig,
    LlmCustomization,
    RagToolConfig,
    UpdateAgentDetails,
    UpdateToolDetails,
)
from oci.generative_ai_agent_runtime import GenerativeAiAgentRuntimeClient
from oci.generative_ai_agent_runtime.models import (
    ChatDetails,
    ChatResult,
    CreateSessionDetails,
    FunctionCallingPerformedAction,
    Session,
)
from oci.response import Response
from oci.util import formatted_flat_dict, to_dict
from oci.pagination import list_call_get_all_results

from oci.addons.adk.agent_error import AgentError, UserError
from oci.addons.adk.auth.factory import AuthFactory, OCIAuthType
from oci.addons.adk.constants import (
    DEFAULT_SESSION_NAME,
    DEFAULT_TIMEOUT,
    DESC_SUFFIX,
    FREEFORM_TAGS,
)
from oci.addons.adk.logger import default_logger as logger
from oci.addons.adk.run.types import PerformedAction
from oci.addons.adk.tool import FunctionTool
from oci.addons.adk.tool.prebuilt import AgenticRagTool
from oci.addons.adk.util import build_custom_function_params, get_region_endpoint


class AgentClient:
    """Client for interacting with OCI Agent endpoints."""

    def __init__(
        self,
        auth_type: Optional[str] = "api_key",
        config: Optional[str] = oci.config.DEFAULT_LOCATION,
        profile: Optional[str] = oci.config.DEFAULT_PROFILE,
        token: Optional[str] = None,
        region: Optional[str] = None,
        runtime_endpoint: Optional[str] = None,
        management_endpoint: Optional[str] = None,
        timeout: Optional[float | Tuple[float, float]] = DEFAULT_TIMEOUT,
        debug: Optional[bool] = False,
    ) -> None:
        """
        Initialize the agent client.

        Args:
            auth_type: The authentication type, Must be
                one of: 'api_key', 'security_token', 'instance_principal',
                'resource_principal', 'instance_obo_user'
            config: The config file path for the agent client
            profile: The profile to use for authentication
            token: The token to use for authentication
            region: The region to use for authentication
                If provided, runtime_endpoint and management_endpoint will be
                constructed automatically unless explicitly specified
            runtime_endpoint: The runtime endpoint for OCI Generative AI Agent
            management_endpoint: The management endpoint for OCI Generative AI Agent
            timeout: The timeout for the agent client
            debug: Whether to enable debug logging

        Raises:
            AgentError: If the agent endpoint id is unknown
        """
        # use auth module
        self.auth = AuthFactory.create_oci_auth(
            auth_type=OCIAuthType(auth_type),
            config_path=config,
            profile=profile,
            token=token,
            region=region,
        )

        # endpoint
        self.runtime_endpoint = self.get_runtime_endpoint(runtime_endpoint, region)
        self.management_endpoint = self.get_management_endpoint(
            management_endpoint, region
        )

        # init Agent Runtime Client
        self._rt_client = GenerativeAiAgentRuntimeClient(
            config=self.auth.get_config(),
            signer=self.auth.get_auth_credentials(),
            service_endpoint=self.runtime_endpoint,
            timeout=timeout,
            retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY,
        )
        # init Agent client
        self._mgmt_client = GenerativeAiAgentClient(
            config=self.auth.get_config(),
            signer=self.auth.get_auth_credentials(),
            service_endpoint=self.management_endpoint,
            timeout=timeout,
            retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY,
        )
        self.debug = debug
        self._handlers: Dict[str, Callable] = {}

    def create_session(
        self,
        agent_endpoint_id: str,
        display_name: str | None = None,
        description: str | None = None,
    ) -> str:
        """
        Create a new agent session.

        Args:
            agent_endpoint_id: The agent endpoint id
            display_name: Name for the session
            description: Description for the session

        Returns:
            The session ID

        Raises:
            AgentError: If session creation fails
        """
        create_session_details = CreateSessionDetails()
        create_session_details.display_name = (
            display_name if display_name else DEFAULT_SESSION_NAME
        )
        create_session_details.description = (
            description if description else DEFAULT_SESSION_NAME + DESC_SUFFIX
        )

        logger.debug(f"Creating session with details: {create_session_details}")

        create_session_response: Response = self._invoke(
            method=self._rt_client.create_session,
            success_message="Session creation succeeded",
            error_message="Failed to create session",
            debug=self.debug,
            create_session_details=create_session_details,
            agent_endpoint_id=agent_endpoint_id,
        )
        response_data: Session = create_session_response.data
        return response_data.id

    def delete_session(self, agent_endpoint_id: str, session_id: str) -> None:
        """
        Delete an existing agent session.

        Args:
            agent_endpoint_id: The agent endpoint id
            session_id: The session ID to delete

        Returns:
            None

        Raises:
            AgentError: If session deletion fails
        """
        logger.debug(f"Deleting session: {session_id}")
        self._invoke(
            method=self._rt_client.delete_session,
            success_message="Session deletion succeeded",
            error_message="Failed to delete session",
            debug=self.debug,
            session_id=session_id,
            agent_endpoint_id=agent_endpoint_id,
        )
        return

    def chat(
        self,
        agent_endpoint_id: str,
        session_id: str,
        user_message: Optional[str] = None,
        performed_actions: Optional[List[PerformedAction]] = None,
    ) -> Dict[str, Any]:
        """
        Invoke a chat session with the agent.

        Args:
            agent_endpoint_id: The agent endpoint id
            session_id: The session ID to use for the chat
            user_message: The message from the user
            performed_actions: Actions that have been performed

        Returns:
            The response from the agent

        Raises:
            AgentError: If the invocation fails
        """
        chat_details = ChatDetails()
        fc_performed_actions = []
        if performed_actions:
            fc_performed_actions = [
                FunctionCallingPerformedAction(
                    action_id=action.action_id,
                    performed_action_type=action.performed_action_type,
                    function_call_output=action.function_call_output,
                )
                for action in performed_actions
            ]
        chat_details.performed_actions = fc_performed_actions
        chat_details.user_message = user_message
        chat_details.should_stream = False
        chat_details.session_id = session_id

        logger.debug(f"Invoking chat endpoint with data: {chat_details}")

        chat_response: Response = self._invoke(
            method=self._rt_client.chat,
            success_message="Chat succeeded",
            error_message="Failed to invoke chat endpoint",
            debug=self.debug,
            chat_details=chat_details,
            agent_endpoint_id=agent_endpoint_id,
        )

        response_data: ChatResult = chat_response.data
        return to_dict(response_data)

    def update_agent(
        self,
        agent_id: str,
        name: str | None = None,
        description: str | None = None,
        instructions: str | None = None,
    ) -> None:
        """
        Update the agent.

        Args:
            agent_id: The ID of the agent to update
            name: The new name for the agent
            description: The new description for the agent
            instructions: The new instructions for the agent

        Returns:
            None

        Raises:
            AgentError: If updating the instructions fails
        """
        update_agent_details = UpdateAgentDetails(freeform_tags=FREEFORM_TAGS)
        if name is not None:
            update_agent_details.display_name = name
        if description is not None:
            update_agent_details.description = description
        if instructions is not None:
            llm_customization = LlmCustomization(
                instruction=instructions,
            )
            update_agent_details.llm_config = LlmConfig(
                routing_llm_customization=llm_customization,
            )
        self._invoke(
            method=self._mgmt_client.update_agent,
            success_message="Agent updated successfully",
            error_message="Failed to update agent",
            debug=self.debug,
            update_agent_details=update_agent_details,
            agent_id=agent_id,
        )

    def find_tools(self, compartment_id: str, agent_id: str) -> List[Dict[str, Any]]:
        """
        Find tools associated with the agent.

        Args:
            compartment_id: The compartment ID
            agent_id: The agent ID

        Returns:
            List of tools

        Raises:
            AgentError: If finding tools fails
        """
        # use pagination to eagerly get all records
        find_tools_response: Response = self._invoke(
            method=list_call_get_all_results,
            success_message="Tools retrieval succeeded",
            error_message="Failed to find tools",
            debug=self.debug,
            list_func_ref=self._mgmt_client.list_tools,
            compartment_id=compartment_id,
            agent_id=agent_id,
        )

        return [to_dict(tool) for tool in find_tools_response.data]

    def delete_tool(self, tool_id: str) -> None:
        """
        Delete a tool by ID.

        Args:
            tool_id: The ID of the tool to delete

        Returns:
            None

        Raises:
            AgentError: If tool deletion fails
        """
        self._invoke(
            method=self._mgmt_client.delete_tool,
            success_message="Tool deletion succeeded",
            error_message="Failed to delete tool",
            debug=self.debug,
            tool_id=tool_id,
        )

    def add_function_tool(
        self,
        function: FunctionTool,
        compartment_id: str,
        agent_id: str,
    ) -> Dict[str, Any]:
        """
        Add a function tool to the agent.

        Args:
            function: The function tool to add

        Returns:
            Dict[str, Any]: The added tool as a dict with its attributes

        Raises:
            AgentError: If adding the tool fails
        """
        tool_config = FunctionCallingToolConfig(
            function=Function(
                name=function.name,
                description=function.description,
                parameters=build_custom_function_params(function.parameters),
            )
        )
        create_tool_details = CreateToolDetails(
            display_name=function.name,
            description=function.description,
            compartment_id=compartment_id,
            agent_id=agent_id,
            tool_config=tool_config,
            freeform_tags=FREEFORM_TAGS,
        )

        add_tool_response: Response = self._invoke(
            method=self._mgmt_client.create_tool,
            success_message="Tool addition succeeded",
            error_message="Failed to add tool",
            debug=self.debug,
            create_tool_details=create_tool_details,
        )

        return to_dict(add_tool_response.data)

    def get_tool(self, tool_id: str) -> Dict[str, Any]:
        """
        Get a tool by ID.

        Args:
            tool_id: The ID of the tool to get

        Returns:
            Dict[str, Any]: The tool as a dict with its attributes

        Raises:
            AgentError: If getting the tool fails
        """
        get_tool_response: Response = self._invoke(
            method=self._mgmt_client.get_tool,
            success_message="Tool retrieval succeeded",
            error_message="Failed to get tool",
            debug=self.debug,
            tool_id=tool_id,
        )

        return to_dict(get_tool_response.data)

    def add_rag_tool(
        self,
        rag_tool: AgenticRagTool,
        compartment_id: str,
        agent_id: str,
    ) -> Dict[str, Any]:
        """
        Add an agentic RAG tool to the agent.

        Args:
            rag_tool: The RAG tool to add
            compartment_id: The compartment ID
            agent_id: The agent ID

        Returns:
            Dict[str, Any]: The added tool as a dict with its attributes

        Raises:
            AgentError: If adding the tool fails
        """
        tool_config = RagToolConfig(
            knowledge_base_configs=[
                KnowledgeBaseConfig(knowledge_base_id=kb_id)
                for kb_id in rag_tool.knowledge_base_ids
            ]
        )
        create_tool_details = CreateToolDetails(
            compartment_id=compartment_id,
            agent_id=agent_id,
            display_name=rag_tool.name,
            description=rag_tool.description,
            tool_config=tool_config,
            freeform_tags=FREEFORM_TAGS,
        )

        add_tool_response: Response = self._invoke(
            method=self._mgmt_client.create_tool,
            success_message="RAG tool addition succeeded",
            error_message="""
            Failed to add RAG tool
            probably because a remote RAG tool not configured by ADK already exists
            """,
            debug=self.debug,
            create_tool_details=create_tool_details,
        )

        return to_dict(add_tool_response.data)

    def update_function_tool(
        self,
        tool_id: str,
        function: FunctionTool,
    ) -> None:
        """
        Update a function tool by ID.

        Args:
            tool_id: The ID of the tool to update
            function: The function tool to update

        Returns:
            None

        Raises:
            AgentError: If updating the tool fails
        """
        tool_config = FunctionCallingToolConfig(
            function=Function(
                name=function.name,
                description=function.description,
                parameters=build_custom_function_params(function.parameters),
            )
        )
        update_tool_details = UpdateToolDetails(
            display_name=function.name,
            description=function.description,
            tool_config=tool_config,
            freeform_tags=FREEFORM_TAGS,
        )

        self._invoke(
            method=self._mgmt_client.update_tool,
            success_message="Tool update succeeded",
            error_message="Failed to update tool",
            debug=self.debug,
            tool_id=tool_id,
            update_tool_details=update_tool_details,
        )

    def update_rag_tool(
        self,
        tool_id: str,
        rag_tool: AgenticRagTool,
    ) -> None:
        """
        Update a rag tool by ID.

        Args:
            tool_id: The ID of the tool to update
            rag_tool: The rag tool to update

        Returns:
            None

        Raises:
            AgentError: If updating the tool fails
        """
        tool_config = RagToolConfig(
            knowledge_base_configs=[
                KnowledgeBaseConfig(knowledge_base_id=kb_id)
                for kb_id in rag_tool.knowledge_base_ids
            ]
        )
        update_tool_details = UpdateToolDetails(
            display_name=rag_tool.name,
            description=rag_tool.description,
            tool_config=tool_config,
            freeform_tags=FREEFORM_TAGS,
        )

        self._invoke(
            method=self._mgmt_client.update_tool,
            success_message="RAG tool update succeeded",
            error_message="Failed to update RAG tool",
            debug=self.debug,
            tool_id=tool_id,
            update_tool_details=update_tool_details,
        )

    def get_agent(self, agent_id: str) -> Dict[str, Any]:
        """
        Get the agent details.

        Args:
            agent_id: The ID of the agent

        Returns:
            The agent details

        Raises:
            AgentError: If getting the agent fails
        """

        get_agent_response: Response = self._invoke(
            method=self._mgmt_client.get_agent,
            success_message="Agent retrieval succeeded",
            error_message="Failed to get agent",
            debug=self.debug,
            agent_id=agent_id,
        )

        return to_dict(get_agent_response.data)

    def get_agent_endpoint_details(self, agent_endpoint_id: str) -> Dict[str, Any]:
        """
        Get the agent endpoint details.

        Args:
            agent_endpoint_id: The agent endpoint id

        Returns:
            The agent endpoint details

        Raises:
            AgentError: If getting the agent endpoint fails
        """
        get_agent_endpoint_response: Response = self._invoke(
            method=self._mgmt_client.get_agent_endpoint,
            success_message="Agent endpoint retrieval succeeded",
            error_message="Failed to get agent endpoint",
            debug=self.debug,
            agent_endpoint_id=agent_endpoint_id,
        )

        return to_dict(get_agent_endpoint_response.data)

    def get_runtime_endpoint(
        self, endpoint_url: Optional[str], region: Optional[str]
    ) -> str:
        """
        Get the runtime endpoint.

        Args:
            endpoint_url: The runtime endpoint
            region: The region

        Returns:
            The runtime endpoint
        """
        return self._get_service_endpoint(endpoint_url, region, "agent-runtime")

    def get_management_endpoint(
        self, endpoint_url: Optional[str], region: Optional[str]
    ) -> str:
        """
        Get the management endpoint.

        Args:
            endpoint_url: The management endpoint
            region: The region

        Returns:
            The management endpoint
        """
        return self._get_service_endpoint(endpoint_url, region, "agent")

    def _get_service_endpoint(
        self,
        endpoint_url: Optional[str] = None,
        region: Optional[str] = None,
        endpoint_type: Literal["agent-runtime", "agent"] = "agent-runtime",
    ) -> str:
        """
        Set the service endpoint.

        Args:
            endpoint_url: The service endpoint
            region: The region
            endpoint_type: The endpoint type

        Returns:
            The service endpoint
        """
        # use endpoint if provided, otherwise use region
        if endpoint_url:
            return endpoint_url
        elif region:
            return get_region_endpoint(region, endpoint_type)
        else:
            raise UserError(
                message="At least one of endpoint_url or region must be provided"
            )

    def _invoke(
        self,
        method: Callable,
        success_message: str,
        error_message: str,
        debug: Optional[bool],
        **kwargs,
    ) -> Response:
        """
        Invoke a method.

        Args:
            method: The method to invoke
            success_message: The success message for debug purposes
            error_message: The error message
            debug: Whether to enable debug logging
            **kwargs: Additional keyword arguments to pass in method

        Returns:
            The response
        """
        response: Response | None = None
        try:
            response = method(**kwargs)
            if response is None:
                raise UserError(
                    message=error_message,
                    response=response,
                )

            if response.status not in (200, 202):
                raise AgentError(message=error_message, response=response)

            logger.debug(
                f"{success_message} response: {formatted_flat_dict(response.data)}"
            )
            logger.debug(f"{success_message} {HEADER_REQUEST_ID}: {response.request_id}")

            return response

        except ServiceError as e:
            raise AgentError(
                message=f"{error_message}: {str(e)}",
                response=response,
            ) from e

        except (ValueError, ClientError) as e:
            raise UserError(
                message=f"{error_message}: {str(e)}",
                response=response,
            ) from e