"""
OpenAI provider implementation for Praval framework.
Provides integration with OpenAI's Chat Completions API with support
for conversation history, tool calling, and streaming responses.
"""
import os
from typing import List, Dict, Any, Optional
import openai
from ..core.exceptions import (
HITLConfigurationError,
InterventionRequired,
ProviderError,
)
from ..hitl.runtime import HITLRuntime
def _redact_secrets(message: str) -> str:
    """Mask provider API keys that appear verbatim in *message*.

    The keys are read from the ``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY`` and
    ``COHERE_API_KEY`` environment variables at call time, and every
    occurrence of a configured key is replaced with ``***``.

    Args:
        message: Text (for example an exception message) that may embed a key.

    Returns:
        The message with all configured keys masked. A falsy message
        (``""`` or ``None``) is returned unchanged.
    """
    if not message:
        return message

    env_names = ("OPENAI_API_KEY", "ANTHROPIC_API_KEY", "COHERE_API_KEY")
    # A set drops unset/empty variables and keys shared by several variables.
    configured = {value for value in map(os.getenv, env_names) if value}

    redacted = message
    # Longest first: if one key is a substring of another, masking the short
    # one first would break up the long one and leave its tail readable.
    for secret in sorted(configured, key=len, reverse=True):
        redacted = redacted.replace(secret, "***")
    return redacted
class OpenAIProvider:
    """
    OpenAI provider for LLM interactions.

    Handles communication with OpenAI's GPT models through the
    Chat Completions API with support for tools and conversation history.
    """

    # Model used for the initial completion and for the follow-up request that
    # is sent after tools have run; kept in one place so the two cannot drift.
    _MODEL = "gpt-3.5-turbo"

    # Python type names accepted in tool parameter specs -> JSON schema types.
    _JSON_SCHEMA_TYPES = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
        "List": "array",
        "Dict": "object",
    }

    def __init__(self, config):
        """
        Initialize OpenAI provider.

        Args:
            config: AgentConfig object with provider settings. This class
                reads ``config.temperature`` and ``config.max_tokens``.

        Raises:
            ProviderError: If ``OPENAI_API_KEY`` is not set or the OpenAI
                client cannot be constructed.
        """
        self.config = config
        try:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                # Raised inside the try on purpose: the handler below re-wraps
                # it, so callers always see the "Failed to initialize" prefix.
                raise ProviderError("OPENAI_API_KEY environment variable not set")
            self.client = openai.OpenAI(api_key=api_key)
        except Exception as e:
            raise ProviderError(
                f"Failed to initialize OpenAI client: {_redact_secrets(str(e))}"
            ) from e

    def generate(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None,
        hitl_context: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Generate a response using OpenAI's Chat Completions API.

        Args:
            messages: Conversation history as list of message dictionaries
            tools: Optional list of available tools for function calling
            hitl_context: Optional run metadata for HITL gating/resume

        Returns:
            Generated response as a string. An empty string is returned when
            the API yields no choice or no message content.

        Raises:
            ProviderError: If the API call fails for any other reason.
            InterventionRequired: Propagated unchanged from tool handling.
            HITLConfigurationError: Propagated unchanged from tool handling.
        """
        try:
            call_params = {
                "model": self._MODEL,
                "messages": messages,
                "temperature": self.config.temperature,
                "max_tokens": self.config.max_tokens,
            }

            if tools:
                formatted_tools = self._format_tools_for_openai(tools)
                # Only advertise tools when at least one had a usable spec.
                if formatted_tools:
                    call_params["tools"] = formatted_tools
                    call_params["tool_choice"] = "auto"

            response = self.client.chat.completions.create(**call_params)

            if not (response.choices and response.choices[0].message):
                return ""

            message = response.choices[0].message
            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                return self._handle_tool_calls(
                    tool_calls,
                    tools,
                    messages,
                    hitl_context=hitl_context,
                )
            return message.content or ""
        except (InterventionRequired, HITLConfigurationError):
            # HITL control-flow exceptions must reach the caller untouched.
            raise
        except Exception as e:
            raise ProviderError(f"OpenAI API error: {_redact_secrets(str(e))}") from e

    def _format_tools_for_openai(
        self, tools: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Format tools for OpenAI's function calling format.

        Args:
            tools: Tool specs. Each usable spec has a ``"function"`` callable
                and a ``"description"``; an optional ``"parameters"`` mapping
                maps a parameter name to a dict with optional ``"type"``
                (default ``"str"``) and ``"required"`` (default ``False``).

        Returns:
            One ``{"type": "function", "function": {...}}`` entry per usable
            spec. Specs without ``"function"`` or ``"description"`` are skipped.
        """
        formatted_tools = []
        for tool in tools:
            if "function" not in tool or "description" not in tool:
                continue

            properties: Dict[str, Any] = {}
            required: List[str] = []
            for param_name, param_info in tool.get("parameters", {}).items():
                python_type = param_info.get("type", "str")
                properties[param_name] = {
                    "type": self._python_type_to_json_schema(python_type)
                }
                if param_info.get("required", False):
                    required.append(param_name)

            formatted_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool["function"].__name__,
                        "description": tool["description"],
                        "parameters": {
                            "type": "object",
                            "properties": properties,
                            "required": required,
                        },
                    },
                }
            )
        return formatted_tools

    def _python_type_to_json_schema(self, python_type: str) -> str:
        """
        Convert Python type annotation to JSON schema type.

        Args:
            python_type: Type name such as ``"int"`` or ``"List"``.

        Returns:
            The matching JSON schema type, or ``"string"`` for unknown names.
        """
        return self._JSON_SCHEMA_TYPES.get(python_type, "string")

    def _build_runtime(
        self, hitl_context: Optional[Dict[str, Any]]
    ) -> Optional["HITLRuntime"]:
        """
        Build a HITL runtime from run metadata, if enough of it is present.

        Args:
            hitl_context: Run metadata. ``run_id``, ``agent_name`` and
                ``provider_name`` are required; ``enabled`` (default
                ``False``), ``db_path`` and ``trace_id`` are optional.

        Returns:
            A ``HITLRuntime``, or ``None`` when the context is empty or any
            required field is missing or falsy.
        """
        if not hitl_context:
            return None

        run_id = hitl_context.get("run_id")
        agent_name = hitl_context.get("agent_name")
        provider_name = hitl_context.get("provider_name")
        if not (run_id and agent_name and provider_name):
            return None

        return HITLRuntime(
            run_id=run_id,
            agent_name=agent_name,
            provider_name=provider_name,
            hitl_enabled=bool(hitl_context.get("enabled", False)),
            db_path=hitl_context.get("db_path"),
            trace_id=hitl_context.get("trace_id"),
        )

    def _serialize_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, Any]]:
        """
        Convert tool calls into plain dictionaries.

        Args:
            tool_calls: Tool calls given either as dicts or as objects with
                ``type``, ``id`` and ``function`` (``name``/``arguments``)
                attributes.

        Returns:
            One dict per call whose type is ``"function"``; calls of any
            other type are dropped. Missing arguments default to ``"{}"``.
        """
        serialized: List[Dict[str, Any]] = []
        for tool_call in tool_calls:
            if isinstance(tool_call, dict):
                call_type = tool_call.get("type")
                tool_id = tool_call.get("id")
                function_obj = tool_call.get("function", {})
                function_name = function_obj.get("name")
                arguments = function_obj.get("arguments", "{}")
            else:
                call_type = getattr(tool_call, "type", None)
                tool_id = getattr(tool_call, "id", None)
                function_obj = getattr(tool_call, "function", None)
                function_name = getattr(function_obj, "name", None)
                arguments = getattr(function_obj, "arguments", "{}")

            if call_type != "function":
                continue

            serialized.append(
                {
                    "id": tool_id,
                    "type": "function",
                    "function": {
                        "name": function_name,
                        "arguments": arguments,
                    },
                }
            )
        return serialized

    @staticmethod
    def _invoke_tool(
        tool_map: Dict[str, Dict[str, Any]], function_name: str, raw_args: str
    ) -> str:
        """Run one tool directly (no HITL runtime) and return its result as text.

        Failures are reported as text rather than raised so that the model
        can see them in the follow-up request.
        """
        import json  # function-scope import: json is not imported at module level

        tool_def = tool_map.get(function_name)
        if tool_def is None:
            return f"Unknown function: {function_name}"
        try:
            # An empty argument string means "no arguments"; json.loads("")
            # would otherwise fail and turn a valid no-arg call into an error.
            args = json.loads(raw_args) if raw_args else {}
            return str(tool_def["function"](**args))
        except Exception as e:
            return f"Error: {str(e)}"

    def _execute_tool_calls(
        self,
        *,
        serialized_tool_calls: List[Dict[str, Any]],
        available_tools: List[Dict[str, Any]],
        original_messages: List[Dict[str, str]],
        hitl_context: Optional[Dict[str, Any]],
        start_index: int = 0,
        existing_tool_messages: Optional[List[Dict[str, str]]] = None,
        resume_intervention: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, str]]:
        """
        Execute tool calls in order and collect their ``tool`` messages.

        Without a HITL runtime each tool is called directly. With a runtime,
        the call at ``start_index`` is resolved through
        ``runtime.execute_with_decision`` when ``resume_intervention`` is
        given; every other call goes through ``runtime.execute_or_interrupt``
        together with a snapshot of the state reached so far.

        Args:
            serialized_tool_calls: Calls as produced by ``_serialize_tool_calls``.
            available_tools: Tool specs; ``None`` is treated as no tools.
            original_messages: Conversation that led to the tool calls.
            hitl_context: Run metadata used to build the HITL runtime.
            start_index: Index of the first call to execute.
            existing_tool_messages: Tool messages produced before
                ``start_index``; they are copied, not mutated.
            resume_intervention: Intervention being resumed, if any.

        Returns:
            The existing tool messages followed by one message per executed call.
        """
        runtime = self._build_runtime(hitl_context)
        tools = available_tools or []
        tool_messages = list(existing_tool_messages or [])
        tool_map = {
            tool["function"].__name__: tool for tool in tools if "function" in tool
        }

        for idx in range(start_index, len(serialized_tool_calls)):
            tool_call = serialized_tool_calls[idx]
            function_name = tool_call["function"]["name"]
            arguments = tool_call["function"]["arguments"]
            tool_call_id = tool_call["id"]

            if runtime is None:
                result_content = self._invoke_tool(tool_map, function_name, arguments)
            elif idx == start_index and resume_intervention is not None:
                # Only the first call of a resumed run carries a decision.
                result_content = runtime.execute_with_decision(
                    intervention=resume_intervention,
                    available_tools=tools,
                )
            else:
                # Snapshot of progress so far, handed to the runtime.
                # NOTE(review): presumably used to resume after an
                # interruption - confirm against HITLRuntime.
                continuation_state = {
                    "schema": "openai_tool_v1",
                    "original_messages": original_messages,
                    "tool_calls": serialized_tool_calls,
                    "current_index": idx,
                    "tool_messages": list(tool_messages),
                }
                result_content = runtime.execute_or_interrupt(
                    tool_call_id=tool_call_id,
                    function_name=function_name,
                    raw_args=arguments,
                    available_tools=tools,
                    continuation_state=continuation_state,
                )

            tool_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "content": result_content,
                }
            )
        return tool_messages

    def _follow_up_response(
        self,
        *,
        serialized_tool_calls: List[Dict[str, Any]],
        original_messages: List[Dict[str, str]],
        tool_messages: List[Dict[str, str]],
    ) -> str:
        """
        Ask the model for a final answer once tool results are available.

        Args:
            serialized_tool_calls: The assistant's tool calls.
            original_messages: Conversation before the tool calls; not mutated.
            tool_messages: Results produced by ``_execute_tool_calls``.

        Returns:
            The model's reply text, or a placeholder when the reply has no
            choice or message. If the request fails, the raw tool results
            joined by newlines are returned instead (best-effort fallback).
        """
        extended_messages = [
            *original_messages,
            {"role": "assistant", "tool_calls": serialized_tool_calls},
            *tool_messages,
        ]
        try:
            follow_up_response = self.client.chat.completions.create(
                model=self._MODEL,
                messages=extended_messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
            )
            if follow_up_response.choices and follow_up_response.choices[0].message:
                return follow_up_response.choices[0].message.content or ""
            return "No response generated after tool execution"
        except Exception:
            # Deliberate best effort: surface the tool output rather than fail.
            # str() keeps the fallback from raising on non-string content.
            return "\n".join(str(msg["content"]) for msg in tool_messages)

    def _handle_tool_calls(
        self,
        tool_calls: List[Any],
        available_tools: List[Dict[str, Any]],
        original_messages: List[Dict[str, str]],
        hitl_context: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Handle OpenAI tool/function calls with optional HITL interception.

        Args:
            tool_calls: Tool calls from the assistant message.
            available_tools: Tool specs; ``None`` is treated as no tools.
            original_messages: Conversation that led to the tool calls.
            hitl_context: Optional run metadata for HITL gating/resume.

        Returns:
            The text of the follow-up response generated from the tool results.
        """
        serialized_tool_calls = self._serialize_tool_calls(tool_calls)
        tool_messages = self._execute_tool_calls(
            serialized_tool_calls=serialized_tool_calls,
            available_tools=available_tools or [],
            original_messages=original_messages,
            hitl_context=hitl_context,
        )
        return self._follow_up_response(
            serialized_tool_calls=serialized_tool_calls,
            original_messages=original_messages,
            tool_messages=tool_messages,
        )