# Source code for praval.providers.cohere

"""
Cohere provider implementation for Praval framework.

Provides integration with Cohere's chat models through their
Chat API with support for conversation history.
"""

import os
from typing import List, Dict, Any, Optional

import cohere
from ..core.exceptions import ProviderError


class CohereProvider:
    """
    Cohere provider for LLM interactions.

    Handles communication with Cohere's chat models through the Chat API
    with conversation history support. The API key is read from the
    ``COHERE_API_KEY`` environment variable when the provider is created.
    """

    # Conversation roles that are forwarded to Cohere, mapped to the role
    # labels used in the ``chat_history`` entries built by this class.
    # Messages with any other role (e.g. "system") never enter the history.
    _ROLE_LABELS = {"user": "USER", "assistant": "CHATBOT"}

    def __init__(self, config):
        """
        Initialize Cohere provider.

        Args:
            config: AgentConfig object with provider settings. Its
                ``temperature`` and ``max_tokens`` attributes are read
                later by :meth:`generate`.

        Raises:
            ProviderError: If ``COHERE_API_KEY`` is unset/empty or the
                Cohere client cannot be constructed.
        """
        self.config = config

        try:
            api_key = os.getenv("COHERE_API_KEY")
            if not api_key:
                # NOTE: this error is itself caught by the handler below and
                # re-raised with the "Failed to initialize" prefix.
                raise ProviderError("COHERE_API_KEY environment variable not set")
            self.client = cohere.Client(api_key)
        except Exception as e:
            raise ProviderError(f"Failed to initialize Cohere client: {e}") from e

    def generate(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """
        Generate a response using Cohere's Chat API.

        Args:
            messages: Conversation history as list of message dictionaries
                with ``role`` and ``content`` keys.
            tools: Optional list of available tools. Accepted for interface
                compatibility but currently ignored by this provider.

        Returns:
            Generated response text, or an empty string when the response
            object has no ``text`` attribute.

        Raises:
            ProviderError: If preparing the request or the API call fails.
        """
        try:
            # Cohere takes the newest user turn separately from the history.
            current_message, chat_history = self._prepare_chat_format(messages)

            call_params = {
                "message": current_message,
                "temperature": self.config.temperature,
                "max_tokens": self.config.max_tokens,
            }

            # Optional parameters are only sent when they carry content.
            if chat_history:
                call_params["chat_history"] = chat_history

            # The system message is passed to Cohere as the preamble.
            system_message = self._extract_system_message(messages)
            if system_message:
                call_params["preamble"] = system_message

            response = self.client.chat(**call_params)

            return response.text if hasattr(response, "text") else ""

        except Exception as e:
            raise ProviderError(f"Cohere API error: {e}") from e

    def _prepare_chat_format(self, messages: List[Dict[str, str]]) -> tuple[str, List[Dict[str, str]]]:
        """
        Prepare messages in Cohere's chat format.

        Cohere expects the current user message separately from chat history.
        Only ``user`` and ``assistant`` messages take part; everything else
        (including system messages) is dropped here.

        Args:
            messages: List of conversation messages

        Returns:
            Tuple of (current_message, chat_history):

            * no user/assistant messages -> ``("", [])``
            * last message is from the user -> its content is the current
              message and all earlier messages form the history
            * last message is from the assistant -> the current message is
              ``"Please continue."`` and every message forms the history
        """
        conversation = [
            msg for msg in messages if msg.get("role") in self._ROLE_LABELS
        ]
        if not conversation:
            return "", []

        newest = conversation[-1]
        if newest.get("role") == "user":
            current_message = newest.get("content", "")
            earlier = conversation[:-1]
        else:
            # Nothing new from the user: ask the model to carry on and keep
            # the whole exchange, including the last reply, as history.
            current_message = "Please continue."
            earlier = conversation

        chat_history = [
            {
                "role": self._ROLE_LABELS[msg["role"]],
                "message": msg.get("content", ""),
            }
            for msg in earlier
        ]
        return current_message, chat_history

    def _extract_system_message(self, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Extract system message from conversation messages.

        Args:
            messages: List of conversation messages

        Returns:
            Content of the first message whose role is ``system`` (an empty
            string if that message has no ``content`` key), or None when no
            system message is present.
        """
        for message in messages:
            if message.get("role") == "system":
                return message.get("content", "")
        return None