OLS-1379: Add oc tools #2216

Open · wants to merge 1 commit into main
6 changes: 5 additions & 1 deletion ols/app/endpoints/ols.py
@@ -100,6 +100,8 @@ def conversation_request(
processed_request.conversation_id,
llm_request,
processed_request.previous_input,
streaming=False,
user_token=processed_request.user_token,
)

processed_request.timestamps["generate response"] = time.time()
@@ -361,6 +363,7 @@ def generate_response(
llm_request: LLMRequest,
previous_input: list[CacheEntry],
streaming: bool = False,
user_token: Optional[str] = None,
) -> Union[SummarizerResponse, Generator]:
"""Generate response based on validation result, previous input, and model output.

@@ -369,6 +372,7 @@
llm_request: The request containing a query.
previous_input: The history of the conversation (if available).
streaming: The flag indicating if the response should be streamed.
user_token: The user token used for authorization.

Returns:
SummarizerResponse or Generator, depending on the streaming flag.
@@ -386,7 +390,7 @@
llm_request.query, config.rag_index, history
)
response = docs_summarizer.create_response(
llm_request.query, config.rag_index, history
llm_request.query, config.rag_index, history, user_token
)
logger.debug("%s Generated response: %s", conversation_id, response)
return response
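Since user_token defaults to None, existing callers of generate_response() that do not pass it are unaffected; only this non-streaming path forwards the token. A simplified view of the updated call, mirroring the diff above with the surrounding endpoint code elided:

summarizer_response = generate_response(
    processed_request.conversation_id,
    llm_request,
    processed_request.previous_input,
    streaming=False,
    user_token=processed_request.user_token,  # forwarded so the 'oc' tools can authenticate
)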
108 changes: 59 additions & 49 deletions ols/src/query_helpers/docs_summarizer.py
@@ -4,10 +4,9 @@
from typing import Any, AsyncGenerator, Optional

from langchain.globals import set_debug
from langchain.llms.base import LLM
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables.base import RunnableBinding
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools.base import BaseTool
from llama_index.core import VectorStoreIndex

from ols import config
@@ -16,7 +15,8 @@
from ols.constants import MAX_ITERATIONS, RAG_CONTENT_LIMIT, GenericLLMParameters
from ols.src.prompts.prompt_generator import GeneratePrompt
from ols.src.query_helpers.query_helper import QueryHelper
from ols.src.tools.tools import tools_map
from ols.src.tools.oc_cli import log_to_oc
from ols.src.tools.tools import default_tools, execute_tool_calls, oc_tools
from ols.utils.token_handler import TokenHandler

logger = logging.getLogger(__name__)
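The new ols/src/tools modules imported above are not part of this excerpt. From their usage below, default_tools and oc_tools are registries mapping tool names to LangChain BaseTool objects, while execute_tool_calls and log_to_oc are helpers used by the summarizer. A hypothetical sketch of the registry shape, for orientation only (the tool shown is an illustrative stub, not one of the PR's actual tools):

from langchain_core.tools import tool


@tool
def oc_get(kind: str, name: str, namespace: str) -> str:
    """Fetch a single resource with 'oc get' (illustrative stub only)."""
    return f"(stub) oc get {kind} {name} -n {namespace}"


default_tools = {}             # tools available whenever introspection is enabled
oc_tools = {"oc_get": oc_get}  # tools added only after 'oc' authentication succeeds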
@@ -131,37 +131,70 @@ def _prepare_prompt(

def _invoke_llm(
self,
llm: LLM | RunnableBinding,
prompt: PromptTemplate | ChatPromptTemplate,
prompt_input: dict,
messages: list[ChatPromptTemplate],
llm_input_values: dict[str, str],
tools_map: Optional[dict] = None,
is_final_round: bool = False,
) -> tuple[AIMessage, TokenCounter]:
"""Invoke LLM to get response."""
"""Invoke LLM with optional tools."""
llm = (
self.bare_llm
if is_final_round or not tools_map
else self.bare_llm.bind_tools(tools_map.values())
)

with TokenMetricUpdater(
llm=self.bare_llm,
provider=self.provider_config.type,
model=self.model,
) as generic_token_counter:
# Create chain using runnables
chain = prompt | llm
# Get model response
# langchain magic for chaining
chain = messages | llm # type: ignore
out = chain.invoke(
input=prompt_input,
input=llm_input_values,
config={"callbacks": [generic_token_counter]},
)
return out, generic_token_counter.token_counter
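For readers less familiar with the LangChain pattern used here: bind_tools() wraps the chat model in a RunnableBinding that advertises the tool schemas to the provider, and the prompt | llm pipe builds the runnable chain that invoke() runs. A minimal self-contained sketch of the same pattern outside this codebase (the model and tool are illustrative, not part of the PR):

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # any chat model with tool-calling support


@tool
def get_pod_count(namespace: str) -> str:
    """Return the number of pods in a namespace (illustrative stub)."""
    return "3"


prompt = ChatPromptTemplate.from_messages(
    [("system", "You are an OpenShift assistant."), ("human", "{query}")]
)
llm = ChatOpenAI(model="gpt-4o-mini")
llm_with_tools = llm.bind_tools([get_pod_count])  # returns a RunnableBinding

chain = prompt | llm_with_tools                   # same chaining as in _invoke_llm
ai_msg = chain.invoke({"query": "How many pods are in namespace demo?"})
print(ai_msg.tool_calls)                          # structured tool calls, if any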

def _get_available_tools(self, user_token: Optional[str]) -> dict[str, BaseTool]:
"""Get available tools based on introspection and user token."""
if not self._introspection_enabled:
return {}

logger.info("Introspection enabled - using default tools selection")

tools_map = default_tools

if user_token and log_to_oc(user_token):
logger.info(
"Successfully authenticated to 'oc' CLI with user token "
"- adding 'oc' tools"
)
# TODO: when we are adding additional tools, ensure we are
# not overwriting the existing tools - currently depends on
# the tool name
tools_map = {**tools_map, **oc_tools}

return tools_map
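The ols/src/tools/oc_cli.py module is also outside this excerpt; judging by its use above, log_to_oc(user_token) tries to authenticate the oc CLI with the token and reports whether that succeeded. A rough sketch of one way such a check could be implemented, where the server discovery and error handling are assumptions rather than the PR's actual code:

import logging
import os
import subprocess
from typing import Optional

logger = logging.getLogger(__name__)


def log_to_oc(token: str, server: Optional[str] = None) -> bool:
    """Attempt 'oc login' with the given token; return True on success (hypothetical sketch)."""
    server = server or os.environ.get("KUBERNETES_SERVICE_HOST", "")
    cmd = ["oc", "login", f"--token={token}"]
    if server:
        cmd.append(f"--server=https://{server}")
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=30, check=False
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        logger.warning("'oc login' could not be executed: %s", exc)
        return False
    return result.returncode == 0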

def create_response(
self,
query: str,
vector_index: Optional[VectorStoreIndex] = None,
history: Optional[list[str]] = None,
user_token: Optional[str] = None,
) -> SummarizerResponse:
"""Create a response for the given query based on the provided conversation context."""
final_prompt, llm_input_values, rag_chunks, truncated = self._prepare_prompt(
query, vector_index, history
)

messages = final_prompt.copy()
messages = final_prompt.model_copy()

# TODO: for the specific tools type (oc) we need specific additional
# context (user_token) to get the tools, we need to think how to make
# it more generic to avoid low-level code changes with new tools type
tools_map = self._get_available_tools(user_token)

# TODO: Tune system prompt
# TODO: Handle context for each iteration
@@ -170,51 +203,28 @@ def create_response(
# TODO: Add tool info to transcript
for i in range(MAX_ITERATIONS):

# Force model to give final response (by not sending any tool),
# when introspection is disabled or max iteration is reached.
if (not self._introspection_enabled) or (i == MAX_ITERATIONS - 1):
# TODO: Modify the sys instruction when max iteration is reached.
out, token_counter = self._invoke_llm(
self.bare_llm, messages, llm_input_values
)
response = out.content
break

llm_with_tools = self.bare_llm.bind_tools(tools_map.values())
# Force llm to give final response when introspection is disabled
# or max iteration is reached
is_final_round = (not self._introspection_enabled) or (
i == MAX_ITERATIONS - 1
)
out, token_counter = self._invoke_llm(
llm_with_tools, messages, llm_input_values
messages, llm_input_values, tools_map, is_final_round
)

# Check if model is ready with final response
# if (not ai_msg.tool_calls) and (ai_msg.content):
if out.response_metadata["finish_reason"] == "stop":
if is_final_round or out.response_metadata["finish_reason"] == "stop":
response = out.content
break

# Before we can add tool output to messages,
# we need to add complete model response
# Before we can add tool output to messages, we need to add
# complete model response
messages.append(out)

# TODO: Parallelization of tool execution
# Iterate through tool calls, Model may support parallel tool calls
for tool_call in out.tool_calls:

# Fetch tool name and required args
tool_name = tool_call["name"].lower()
tool_args = tool_call["args"]
# Execute tool/function
try:
tool_output = tools_map[tool_name].invoke(tool_args)
except Exception:
tool_output = f"error while executing {tool_name}"

logger.debug(
"tool name: %s\ntool args: %s\ntool_output: %s",
tool_name,
tool_args,
tool_output,
)
messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))
# TODO: explicit check for {"finish_reason": "tool_call"}?
tool_calls_messages = execute_tool_calls(tools_map, out.tool_calls)
messages.extend(tool_calls_messages)

return SummarizerResponse(response, rag_chunks, truncated, token_counter)
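The per-call execution that previously lived inline in this loop now sits behind execute_tool_calls() in ols/src/tools/tools.py, which is not shown in this excerpt. Judging by the removed code, it resolves each call against the tools map, invokes the tool, and wraps each result in a ToolMessage; a sketch along those lines, with the exact signature assumed:

import logging

from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool

logger = logging.getLogger(__name__)


def execute_tool_calls(
    tools_map: dict[str, BaseTool], tool_calls: list[dict]
) -> list[ToolMessage]:
    """Execute the model's tool calls and wrap each output in a ToolMessage (sketch)."""
    messages = []
    for tool_call in tool_calls:
        tool_name = tool_call["name"].lower()
        tool_args = tool_call["args"]
        try:
            tool_output = tools_map[tool_name].invoke(tool_args)
        except Exception:
            logger.exception("Error while executing tool %s", tool_name)
            tool_output = f"error while executing {tool_name}"
        logger.debug(
            "tool name: %s\ntool args: %s\ntool output: %s",
            tool_name,
            tool_args,
            tool_output,
        )
        messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))
    return messages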

2 changes: 1 addition & 1 deletion ols/src/rag_index/index_loader.py
@@ -1,4 +1,4 @@
# type: ignore # noqa: PGH003
# type: ignore
"""Module for loading index."""

import logging