crewplus — advanced capability tooling for chat models
"""
Capabilities module for custom tool support.

Provides a registry-based system for adding capabilities (web search, bash, etc.)
to chat models that don't have native support for these features.
"""
from .base import BaseCapability, CapabilityConfig
from .registry import CapabilityRegistry
from .web_search import TavilySearchCapability
from .bash import BashCapability

# Public API of the capabilities package.
__all__ = [
    "BaseCapability",
    "CapabilityConfig",
    "CapabilityRegistry",
    "TavilySearchCapability",
    "BashCapability",
]
| """ | ||
| Base classes for capability system. | ||
| """ | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Dict, Optional, List | ||
| from pydantic import BaseModel, Field | ||
class CapabilityConfig(BaseModel):
    """Configuration common to every capability.

    Extra keys are permitted (see ``Config.extra``) so capability-specific
    settings such as ``api_key`` or ``safe_mode`` can ride along without
    being declared on this model.
    """
    enabled: bool = Field(default=True, description="Whether this capability is enabled")
    type: str = Field(default="custom", description="Capability type: 'native' or 'custom'")

    class Config:
        # Pydantic model config: keep unknown fields instead of rejecting them.
        extra = "allow"  # Allow additional fields for specific capabilities
class BaseCapability(ABC):
    """
    Abstract base class for all capabilities.

    Capabilities are tools that can be added to chat models via function calling.
    Each capability provides:
    - A tool schema for the LLM to understand how to call it
    - An optional execute method for actual execution (if needed)

    Example:
        class MyCapability(BaseCapability):
            @property
            def name(self) -> str:
                return "my_tool"

            @property
            def description(self) -> str:
                return "Does something useful"

            def get_parameters_schema(self) -> Dict:
                return {
                    "type": "object",
                    "properties": {
                        "input": {"type": "string", "description": "The input"}
                    },
                    "required": ["input"]
                }
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the capability with configuration.

        Args:
            config: Configuration dictionary for this capability. A non-dict
                value is stored as-is (assumed to be a ready-made config object).
        """
        if isinstance(config, dict):
            # Copy the dict so update_config() never mutates the caller's
            # dictionary (it was previously stored by reference).
            self._raw_config: Dict[str, Any] = dict(config)
            self.config = CapabilityConfig(**self._raw_config)
        else:
            self._raw_config = config
            self.config = config

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Return the tool name for function calling.

        Must match regex: ^[a-zA-Z0-9_-]{1,64}$
        """

    @property
    @abstractmethod
    def description(self) -> str:
        """
        Return a detailed description of what this tool does.

        This helps the LLM understand when and how to use the tool.
        """

    @abstractmethod
    def get_parameters_schema(self) -> Dict[str, Any]:
        """
        Return the JSON Schema for the tool's parameters.

        Returns:
            Dict with "type": "object", "properties", and "required" keys
        """

    def get_tool_schema(self) -> Dict[str, Any]:
        """
        Return the complete OpenAI-compatible tool schema.

        Returns:
            Tool definition in OpenAI function calling format
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.get_parameters_schema()
            }
        }

    def get_langchain_tool_schema(self) -> Dict[str, Any]:
        """
        Return tool schema in LangChain/Claude format.

        Returns:
            Tool definition compatible with Claude's API
        """
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.get_parameters_schema()
        }

    async def execute(self, **kwargs) -> str:
        """
        Execute the capability with given arguments.

        Override this method to provide actual execution logic.
        Default implementation raises NotImplementedError.

        Args:
            **kwargs: Arguments matching the parameters schema

        Returns:
            String result of the execution

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(
            f"Capability '{self.name}' does not support execution. "
            "Override execute() method to add execution support."
        )

    def execute_sync(self, **kwargs) -> str:
        """
        Synchronous version of execute.

        Default implementation raises NotImplementedError.
        """
        raise NotImplementedError(
            f"Capability '{self.name}' does not support synchronous execution. "
            "Override execute_sync() method to add execution support."
        )

    def is_enabled(self) -> bool:
        """Check if this capability is enabled."""
        return self.config.enabled

    def get_config_value(self, key: str, default: Any = None) -> Any:
        """Get a raw configuration value by key."""
        return self._raw_config.get(key, default)

    def update_config(self, updates: Dict[str, Any]) -> None:
        """
        Update configuration with new values.

        Args:
            updates: Dictionary of config updates to apply
        """
        self._raw_config.update(updates)
        # Re-validate so self.config reflects the merged settings.
        self.config = CapabilityConfig(**self._raw_config)

    def clone_with_config(self, config_override: Dict[str, Any]) -> "BaseCapability":
        """
        Create a new instance with merged configuration.

        Args:
            config_override: Configuration values to override

        Returns:
            New capability instance with merged config
        """
        merged_config = {**self._raw_config, **config_override}
        return self.__class__(merged_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name='{self.name}', enabled={self.is_enabled()})"
| """ | ||
| Bash command capability for function calling simulation. | ||
| This capability allows the LLM to express bash commands via function calling. | ||
| It does NOT actually execute commands - it's designed for testing and simulation | ||
| of the function calling workflow. | ||
| """ | ||
| import logging | ||
| from typing import Any, Dict, List, Optional | ||
| from .base import BaseCapability | ||
class BashCapability(BaseCapability):
    """
    Bash command capability for function calling simulation.

    This capability provides the tool schema for bash commands, allowing the LLM
    to generate proper function calls with command and restart parameters.
    Execution is NOT performed - this is for function calling simulation only.

    Configuration Options:
        enabled (bool): Whether this capability is enabled (default: True)
        safe_mode (bool): If True, hints to LLM to avoid dangerous commands (default: True)
        allowed_commands (list): List of allowed command prefixes (optional)
        timeout (int): Suggested timeout in seconds (default: 30)
        working_directory (str): Suggested working directory (optional)

    Example:
        config = {
            "enabled": True,
            "safe_mode": True,
            "allowed_commands": ["ls", "cat", "grep", "find", "echo"],
            "timeout": 30
        }
        capability = BashCapability(config)
        schema = capability.get_tool_schema()

    Note:
        This capability is designed for testing the function calling workflow.
        It does NOT execute actual bash commands. For actual execution, the
        caller must handle the tool_calls returned by the LLM and execute
        them in a controlled environment.
    """

    # Single source of truth for substrings considered destructive.
    # Shared by validate_command() and get_execution_hints(); previously two
    # hand-maintained copies had drifted apart (the hints list was missing
    # "rm -rf /*" and "> /dev/sd").
    DANGEROUS_PATTERNS = [
        "rm -rf /",
        "rm -rf /*",
        ":(){:|:&};:",  # Fork bomb
        "mkfs",
        "dd if=",
        "> /dev/sd",
        "chmod -R 777 /",
        "wget | sh",
        "curl | sh",
    ]

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize from a config dict (see class docstring for options).

        Args:
            config: Configuration dictionary for this capability.
        """
        super().__init__(config)
        self.logger = logging.getLogger(__name__)
        # Configuration defaults
        self._safe_mode = config.get("safe_mode", True)
        self._allowed_commands = config.get("allowed_commands", [])
        self._timeout = config.get("timeout", 30)
        self._working_directory = config.get("working_directory")

    @property
    def name(self) -> str:
        """Tool name used in function-calling schemas."""
        return "bash"

    @property
    def description(self) -> str:
        """Tool description; extended with safety/allow-list hints from config."""
        base_desc = (
            "Execute a bash command in a persistent shell session. "
            "Use this for running system commands, scripts, and command-line operations. "
            "The shell session maintains state between commands (working directory, "
            "environment variables, created files)."
        )
        if self._safe_mode:
            base_desc += (
                " IMPORTANT: Avoid destructive commands like 'rm -rf', 'format', "
                "or commands that could harm the system."
            )
        if self._allowed_commands:
            allowed_str = ", ".join(self._allowed_commands)
            base_desc += f" Preferred commands: {allowed_str}."
        return base_desc

    def get_parameters_schema(self) -> Dict[str, Any]:
        """Return the JSON Schema for the bash tool's arguments."""
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": (
                        "The bash command to execute. Use standard Unix commands. "
                        "For multi-step operations, chain commands with && or ;."
                    )
                },
                "restart": {
                    "type": "boolean",
                    "description": (
                        "Set to true to restart the bash session, clearing all state "
                        "(working directory resets, environment variables cleared)."
                    ),
                    "default": False
                }
            },
            "required": []  # Neither is strictly required per Anthropic's bash tool spec
        }

    def validate_command(self, command: str) -> tuple[bool, Optional[str]]:
        """
        Validate a command against safety rules.

        Args:
            command: The command to validate

        Returns:
            Tuple of (is_valid, error_message)
        """
        if not command or not command.strip():
            return True, None  # Empty command is valid (might just be restart)
        # Dangerous substrings are only enforced in safe mode.
        if self._safe_mode:
            command_lower = command.lower()
            for pattern in self.DANGEROUS_PATTERNS:
                if pattern in command_lower:
                    return False, f"Command blocked: contains dangerous pattern '{pattern}'"
        # Check allowed commands if specified (first token is the base command).
        if self._allowed_commands:
            command_parts = command.strip().split()
            if command_parts:
                base_command = command_parts[0]
                if base_command not in self._allowed_commands:
                    return False, (
                        f"Command '{base_command}' not in allowed list. "
                        f"Allowed: {self._allowed_commands}"
                    )
        return True, None

    def execute_sync(self, command: str = "", restart: bool = False, **kwargs) -> str:
        """
        Simulate bash command execution (no actual execution).

        This method validates the command and returns a simulation message.
        It does NOT actually execute the command.

        Args:
            command: The bash command (not executed)
            restart: Whether to restart the session

        Returns:
            Simulation message describing what would happen
        """
        if restart:
            return "[SIMULATION] Bash session would be restarted. All state cleared."
        if not command:
            return "[SIMULATION] No command provided."
        # Validate command
        is_valid, error = self.validate_command(command)
        if not is_valid:
            return f"[SIMULATION] Command blocked: {error}"
        return (
            f"[SIMULATION] Would execute command: {command}\n"
            f"Note: Actual execution is disabled. This is a function calling simulation.\n"
            f"To enable execution, handle the tool_call in your application code."
        )

    async def execute(self, command: str = "", restart: bool = False, **kwargs) -> str:
        """
        Async version of execute (same simulation behavior).

        Args:
            command: The bash command (not executed)
            restart: Whether to restart the session

        Returns:
            Simulation message describing what would happen
        """
        return self.execute_sync(command=command, restart=restart, **kwargs)

    def get_execution_hints(self) -> Dict[str, Any]:
        """
        Get hints for actual command execution (for callers who want to execute).

        Returns:
            Dict with execution configuration hints
        """
        return {
            "safe_mode": self._safe_mode,
            "allowed_commands": self._allowed_commands,
            "timeout": self._timeout,
            "working_directory": self._working_directory,
            # Return a copy so callers cannot mutate the class-level list.
            "dangerous_patterns": list(self.DANGEROUS_PATTERNS),
        }
| """ | ||
| Capability registry for managing and instantiating capabilities. | ||
| """ | ||
| import logging | ||
| from typing import Any, Dict, List, Optional, Type, Union | ||
| from .base import BaseCapability | ||
class CapabilityRegistry:
    """
    Registry for managing capability types and instances.

    The registry maintains a mapping of capability names to their implementation
    classes, and provides methods to create and manage capability instances.

    Example:
        registry = CapabilityRegistry()

        # Register custom capability
        registry.register("my_tool", MyToolCapability)

        # Create capability from config
        cap = registry.create("web_search", {"enabled": True, "provider": "tavily"})

        # Get tool schema
        schema = cap.get_tool_schema()
    """

    # Class-level registry of capability types (shared across all instances).
    _capability_types: Dict[str, Type[BaseCapability]] = {}

    def __init__(self, logger: Optional[logging.Logger] = None):
        """
        Initialize the registry.

        Args:
            logger: Optional logger instance
        """
        self.logger = logger or logging.getLogger(__name__)
        self._instances: Dict[str, BaseCapability] = {}
        # Auto-register built-in capabilities
        self._register_builtins()

    def _register_builtins(self) -> None:
        """Register built-in capability types."""
        # Import here to avoid circular imports
        from .web_search import TavilySearchCapability
        from .bash import BashCapability
        self.register("web_search", TavilySearchCapability)
        self.register("tavily_search", TavilySearchCapability)
        self.register("bash", BashCapability)

    def register(self, name: str, capability_class: Type[BaseCapability]) -> None:
        """
        Register a capability type.

        Args:
            name: Name to register the capability under
            capability_class: The capability class to register

        Raises:
            TypeError: If capability_class is not a BaseCapability subclass.
        """
        # issubclass() itself raises TypeError on non-classes, so check
        # isinstance(…, type) first; report the offending value itself
        # (the old message printed type(cls), which is always <class 'type'>
        # for a class and therefore useless).
        if not (isinstance(capability_class, type)
                and issubclass(capability_class, BaseCapability)):
            raise TypeError(
                f"capability_class must be a subclass of BaseCapability, "
                f"got {capability_class!r}"
            )
        self._capability_types[name] = capability_class
        self.logger.debug(f"Registered capability type: {name} -> {capability_class.__name__}")

    def create(self, name: str, config: Dict[str, Any]) -> BaseCapability:
        """
        Create a capability instance from configuration.

        Args:
            name: Name of the capability type
            config: Configuration dictionary

        Returns:
            Configured capability instance

        Raises:
            ValueError: If capability type is not registered
        """
        if name not in self._capability_types:
            raise ValueError(
                f"Unknown capability type: '{name}'. "
                f"Registered types: {list(self._capability_types.keys())}"
            )
        capability_class = self._capability_types[name]
        instance = capability_class(config)
        self.logger.debug(f"Created capability instance: {instance}")
        return instance

    def create_from_config(self, capabilities_config: Dict[str, Any]) -> Dict[str, BaseCapability]:
        """
        Create multiple capability instances from a capabilities config dict.

        Args:
            capabilities_config: Dict mapping capability names to their configs
                Example: {
                    "web_search": {"enabled": True, "provider": "tavily"},
                    "bash": {"enabled": True, "safe_mode": True}
                }

        Returns:
            Dict mapping capability names to instances (only enabled ones)
        """
        instances = {}
        for name, config in capabilities_config.items():
            # Normalize config: bool/None/other scalars become {"enabled": …}
            if isinstance(config, bool):
                config = {"enabled": config}
            elif config is None:
                config = {"enabled": False}
            elif not isinstance(config, dict):
                # Defensive: any other scalar is treated as an on/off flag
                # (the old code crashed on .get() for such values).
                config = {"enabled": bool(config)}
            # Skip disabled capabilities
            if not config.get("enabled", True):
                self.logger.debug(f"Skipping disabled capability: {name}")
                continue
            try:
                instance = self.create(name, config)
                instances[name] = instance
            except ValueError as e:
                # Unknown types are logged and skipped, not fatal.
                self.logger.warning(f"Failed to create capability '{name}': {e}")
        return instances

    def get_tool_schemas(self, capabilities: Dict[str, BaseCapability]) -> List[Dict[str, Any]]:
        """
        Get OpenAI-compatible tool schemas for all capabilities.

        Args:
            capabilities: Dict of capability instances

        Returns:
            List of tool schemas (only for enabled capabilities)
        """
        return [
            cap.get_tool_schema()
            for cap in capabilities.values()
            if cap.is_enabled()
        ]

    def merge_capability_config(
        self,
        base: Dict[str, Any],
        override: Union[bool, Dict[str, Any], None]
    ) -> Dict[str, Any]:
        """
        Merge capability configuration with override.

        Merge rules:
        1. override=False -> {"enabled": False, ...base_params}
        2. override=True -> {"enabled": True, ...base_params}
        3. override=Dict -> {**base, **override}
        4. override=None -> base (no change)

        Args:
            base: Base configuration dict
            override: Override value (bool, dict, or None)

        Returns:
            Merged configuration dict (always a new dict; base is not mutated)
        """
        if override is None:
            return base.copy()
        if isinstance(override, bool):
            result = base.copy()
            result["enabled"] = override
            return result
        if isinstance(override, dict):
            result = base.copy()
            result.update(override)
            return result
        # Fallback: try to convert to bool
        result = base.copy()
        result["enabled"] = bool(override)
        return result

    def merge_capabilities_config(
        self,
        base_config: Dict[str, Any],
        runtime_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Merge base capabilities config with runtime overrides.

        Args:
            base_config: Base capabilities configuration from model config
            runtime_config: Runtime overrides from invoke config

        Returns:
            Merged capabilities configuration
        """
        result = {}
        # Consider every capability mentioned in either config.
        all_keys = set(base_config.keys()) | set(runtime_config.keys())
        for key in all_keys:
            base = base_config.get(key, {})
            override = runtime_config.get(key)
            # Normalize base config to a dict
            if isinstance(base, bool):
                base = {"enabled": base}
            elif base is None:
                base = {"enabled": False}
            result[key] = self.merge_capability_config(base, override)
        return result

    @classmethod
    def get_registered_types(cls) -> List[str]:
        """Get list of registered capability type names."""
        return list(cls._capability_types.keys())
| """ | ||
| Web search capability using Tavily API. | ||
| """ | ||
| import os | ||
| import json | ||
| import logging | ||
| from typing import Any, Dict, Optional | ||
| from .base import BaseCapability | ||
# Optional Tavily import: tavily-python is not a hard dependency.
try:
    from tavily import TavilyClient, AsyncTavilyClient
    TAVILY_AVAILABLE = True
except ImportError:
    # Fall back to None stubs so this module still imports; a clear
    # ImportError is raised lazily when a client is actually requested.
    TAVILY_AVAILABLE = False
    TavilyClient = None
    AsyncTavilyClient = None
class TavilySearchCapability(BaseCapability):
    """
    Web search capability using Tavily API.

    Tavily is an AI-optimized search API that provides clean, relevant results
    suitable for LLM consumption.

    Configuration Options:
        enabled (bool): Whether this capability is enabled (default: True)
        api_key (str): Tavily API key (can also use TAVILY_API_KEY env var)
        api_key_env (str): Environment variable name for API key (default: TAVILY_API_KEY)
        max_results (int): Maximum number of search results
        search_depth (str): Search depth - "basic" or "advanced"
        include_domains (list): List of domains to include in search
        exclude_domains (list): List of domains to exclude from search
        include_answer (bool): Include AI-generated answer
        include_raw_content (bool): Include raw HTML content

    Example:
        config = {
            "enabled": True,
            "api_key_env": "TAVILY_API_KEY",
            "max_results": 5,
            "search_depth": "basic"
        }
        capability = TavilySearchCapability(config)
        schema = capability.get_tool_schema()
    """

    # Config keys forwarded verbatim to TavilyClient.search() when present.
    _SEARCH_PARAM_KEYS = (
        "max_results",
        "search_depth",
        "include_domains",
        "exclude_domains",
        "include_answer",
        "include_raw_content",
    )

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize from a config dict (see class docstring for options).

        Args:
            config: Configuration dictionary for this capability.
        """
        super().__init__(config)
        self.logger = logging.getLogger(__name__)
        # Get API key from config or environment
        self._api_key = self._resolve_api_key()
        # Clients are created lazily on first use.
        self._sync_client: Optional[TavilyClient] = None
        self._async_client: Optional[AsyncTavilyClient] = None

    def _resolve_api_key(self) -> Optional[str]:
        """Resolve API key from config or environment."""
        # Direct api_key in config takes precedence
        if api_key := self._raw_config.get("api_key"):
            return api_key
        # Then check environment variable
        env_var = self._raw_config.get("api_key_env", "TAVILY_API_KEY")
        return os.getenv(env_var)

    def _get_sync_client(self) -> TavilyClient:
        """Get or create synchronous Tavily client.

        Raises:
            ImportError: If tavily-python is not installed.
            ValueError: If no API key could be resolved.
        """
        if not TAVILY_AVAILABLE:
            raise ImportError("tavily package is required. Install with: pip install tavily-python")
        if self._sync_client is None:
            if not self._api_key:
                raise ValueError(
                    "Tavily API key not found. Set TAVILY_API_KEY environment variable "
                    "or provide 'api_key' in capability config."
                )
            self._sync_client = TavilyClient(api_key=self._api_key)
        return self._sync_client

    def _get_async_client(self) -> AsyncTavilyClient:
        """Get or create asynchronous Tavily client.

        Raises:
            ImportError: If tavily-python is not installed.
            ValueError: If no API key could be resolved.
        """
        if not TAVILY_AVAILABLE:
            raise ImportError("tavily package is required. Install with: pip install tavily-python")
        if self._async_client is None:
            if not self._api_key:
                raise ValueError(
                    "Tavily API key not found. Set TAVILY_API_KEY environment variable "
                    "or provide 'api_key' in capability config."
                )
            self._async_client = AsyncTavilyClient(api_key=self._api_key)
        return self._async_client

    @property
    def name(self) -> str:
        """Tool name used in function-calling schemas."""
        return "web_search"

    @property
    def description(self) -> str:
        """Tool description shown to the LLM."""
        return (
            "Search the web for real-time information. Use this tool when you need "
            "current data such as weather, news, recent events, stock prices, or any "
            "information that may have changed after your knowledge cutoff. "
            "Provide a clear, specific search query for best results."
        )

    def get_parameters_schema(self) -> Dict[str, Any]:
        """Return the JSON Schema for the search tool's arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query. Be specific and include relevant context."
                }
            },
            "required": ["query"]
        }

    def _build_search_params(self, query: str) -> Dict[str, Any]:
        """Build search parameters from config.

        Uses ``is not None`` checks so explicitly-false values such as
        ``include_answer=False`` or ``max_results=0`` are still forwarded.
        (The previous truthiness checks silently dropped falsy overrides.)
        """
        params: Dict[str, Any] = {"query": query}
        for key in self._SEARCH_PARAM_KEYS:
            value = self._raw_config.get(key)
            if value is not None:
                params[key] = value
        return params

    def _format_results(self, response: Dict[str, Any]) -> str:
        """Format Tavily response for LLM consumption."""
        output_parts = []
        # Include AI answer if available
        if answer := response.get("answer"):
            output_parts.append(f"Summary: {answer}\n")
        # Format search results
        results = response.get("results", [])
        if results:
            output_parts.append("Search Results:")
            for i, result in enumerate(results, 1):
                title = result.get("title", "No title")
                url = result.get("url", "")
                content = result.get("content", "")
                output_parts.append(f"\n[{i}] {title}")
                output_parts.append(f"    URL: {url}")
                if content:
                    # Truncate long content
                    content_preview = content[:500] + "..." if len(content) > 500 else content
                    output_parts.append(f"    Content: {content_preview}")
        if not output_parts:
            return "No search results found."
        return "\n".join(output_parts)

    def execute_sync(self, query: str, **kwargs) -> str:
        """
        Execute web search synchronously.

        Args:
            query: The search query

        Returns:
            Formatted search results as string (errors are returned as text,
            not raised, so the LLM loop can continue)
        """
        try:
            client = self._get_sync_client()
            params = self._build_search_params(query)
            self.logger.debug(f"Executing Tavily search: {params}")
            response = client.search(**params)
            return self._format_results(response)
        except Exception as e:
            self.logger.error(f"Tavily search failed: {e}")
            return f"Search failed: {str(e)}"

    async def execute(self, query: str, **kwargs) -> str:
        """
        Execute web search asynchronously.

        Args:
            query: The search query

        Returns:
            Formatted search results as string (errors are returned as text,
            not raised, so the LLM loop can continue)
        """
        try:
            client = self._get_async_client()
            params = self._build_search_params(query)
            self.logger.debug(f"Executing async Tavily search: {params}")
            response = await client.search(**params)
            return self._format_results(response)
        except Exception as e:
            self.logger.error(f"Tavily async search failed: {e}")
            return f"Search failed: {str(e)}"
| """ | ||
| OAI Claude Chat Model - Claude via OpenAI-compatible interface with capability support. | ||
| This model wraps langchain_openai.ChatOpenAI to provide Claude access through | ||
| OpenAI-compatible endpoints (like ModelGate) with configurable capabilities | ||
| for web search, bash, and other tools via function calling. | ||
| """ | ||
| import os | ||
| import logging | ||
| from typing import Any, Dict, Iterator, List, Optional, AsyncIterator, Union | ||
| from langchain_openai.chat_models.base import ChatOpenAI | ||
| from langchain_core.messages import BaseMessage, AIMessage | ||
| from langchain_core.outputs import ChatResult | ||
| from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun | ||
| from pydantic import Field | ||
| from .capabilities import CapabilityRegistry, BaseCapability | ||
| from .tracing_manager import TracingManager | ||
| class OAIClaudeChatModel(ChatOpenAI): | ||
| """ | ||
| Claude model via OpenAI-compatible interface with capability support. | ||
| This model extends ChatOpenAI to work with OpenAI-compatible endpoints | ||
| (like ModelGate, LiteLLM, etc.) that proxy Claude models. It provides: | ||
| 1. **Pre-configured Capabilities**: Web search (Tavily), bash, etc. defined in config | ||
| 2. **Runtime Override**: Toggle or adjust capabilities per-request | ||
| 3. **Function Calling**: Returns tool_calls for caller to handle (no auto-execution) | ||
| **Key Differences from ClaudeChatModel:** | ||
| - Uses OpenAI-compatible API instead of native Anthropic SDK | ||
| - Capabilities are implemented via function calling, not native tools | ||
| - Better compatibility with proxy services that don't support Anthropic beta headers | ||
| **Configuration:** | ||
| In models_config.json: | ||
| ```json | ||
| { | ||
| "id": 72, | ||
| "provider": "custom-openai", | ||
| "type": "inference", | ||
| "deployment_name": "claude-opus-4-5@mg", | ||
| "base_url": "https://mg.aid.pub/v1", | ||
| "api_key": "sk-xxx", | ||
| "temperature": 0.5, | ||
| "max_tokens": 2048, | ||
| "capabilities": { | ||
| "web_search": { | ||
| "enabled": true, | ||
| "api_key_env": "TAVILY_API_KEY", | ||
| "max_results": 5 | ||
| }, | ||
| "bash": { | ||
| "enabled": true, | ||
| "safe_mode": true | ||
| } | ||
| } | ||
| } | ||
| ``` | ||
| **Runtime Override:** | ||
| ```python | ||
| # Disable web_search for this request | ||
| response = await model.ainvoke(messages, config={ | ||
| "configurable": { | ||
| "capabilities": {"web_search": False} | ||
| } | ||
| }) | ||
| # Override search parameters | ||
| response = await model.ainvoke(messages, config={ | ||
| "configurable": { | ||
| "capabilities": {"web_search": {"max_results": 10}} | ||
| } | ||
| }) | ||
| ``` | ||
| Attributes: | ||
| capabilities_config: Pre-configured capabilities from model config | ||
| logger: Optional logger instance | ||
| enable_tracing: Enable/disable tracing (auto-detect if None) | ||
| Example: | ||
| ```python | ||
| from crewplus.services import OAIClaudeChatModel | ||
| from langchain_core.messages import HumanMessage | ||
| # Initialize with capabilities | ||
| model = OAIClaudeChatModel( | ||
| model="claude-opus-4-5@mg", | ||
| base_url="https://mg.aid.pub/v1", | ||
| api_key="sk-xxx", | ||
| capabilities_config={ | ||
| "web_search": {"enabled": True, "api_key_env": "TAVILY_API_KEY"}, | ||
| "bash": {"enabled": True, "safe_mode": True} | ||
| } | ||
| ) | ||
| # Invoke - returns AIMessage with tool_calls if model wants to use tools | ||
| response = await model.ainvoke([HumanMessage(content="What's the weather in Tokyo?")]) | ||
| if response.tool_calls: | ||
| for tool_call in response.tool_calls: | ||
| print(f"Tool: {tool_call['name']}, Args: {tool_call['args']}") | ||
| ``` | ||
| """ | ||
| # Capability configuration | ||
| capabilities_config: Dict[str, Any] = Field( | ||
| default_factory=dict, | ||
| description="Pre-configured capabilities with their settings" | ||
| ) | ||
| # Tracing and logging | ||
| logger: Optional[logging.Logger] = Field( | ||
| default=None, | ||
| description="Optional logger instance", | ||
| exclude=True | ||
| ) | ||
| enable_tracing: Optional[bool] = Field( | ||
| default=None, | ||
| description="Enable tracing (auto-detect if None)" | ||
| ) | ||
| # Internal state | ||
| _capability_registry: Optional[CapabilityRegistry] = None | ||
| _capabilities: Optional[Dict[str, BaseCapability]] = None | ||
| _tracing_manager: Optional[TracingManager] = None | ||
    def __init__(self, **kwargs):
        """
        Initialize OAIClaudeChatModel.

        Args:
            **kwargs: Arguments passed to ChatOpenAI, plus:
                - capabilities_config: Dict of capability configurations
                - logger: Optional logger instance
                - enable_tracing: Optional bool to control tracing
        """
        super().__init__(**kwargs)
        # Create a default logger only if the caller did not supply one;
        # handler/level setup applies only to that auto-created logger.
        if self.logger is None:
            self.logger = logging.getLogger(
                f"{self.__class__.__module__}.{self.__class__.__name__}"
            )
            if not self.logger.handlers:
                self.logger.addHandler(logging.StreamHandler())
                self.logger.setLevel(logging.INFO)
        # Initialize capability registry and instantiate enabled capabilities
        # from capabilities_config.
        self._capability_registry = CapabilityRegistry(logger=self.logger)
        self._load_capabilities()
        # Initialize tracing manager (wraps this model instance).
        self._tracing_manager = TracingManager(self)
        self.logger.info(
            f"Initialized OAIClaudeChatModel (model={self.model_name}, "
            f"base_url={self.openai_api_base}, "
            f"capabilities={list(self._capabilities.keys())})"
        )
| def _load_capabilities(self) -> None: | ||
| """Load and initialize capabilities from config.""" | ||
| self._capabilities = self._capability_registry.create_from_config( | ||
| self.capabilities_config | ||
| ) | ||
| if self._capabilities: | ||
| self.logger.debug( | ||
| f"Loaded capabilities: {[c.name for c in self._capabilities.values()]}" | ||
| ) | ||
| def get_model_identifier(self) -> str: | ||
| """Return a string identifying this model for tracing and logging.""" | ||
| return f"{self.__class__.__name__} (model='{self.model_name}')" | ||
| def _get_effective_capabilities( | ||
| self, | ||
| runtime_config: Optional[Dict[str, Any]] = None | ||
| ) -> Dict[str, BaseCapability]: | ||
| """ | ||
| Get effective capabilities after applying runtime overrides. | ||
| Args: | ||
| runtime_config: Runtime capability overrides from config | ||
| Returns: | ||
| Dict of effective capability instances | ||
| """ | ||
| if not runtime_config: | ||
| return self._capabilities.copy() | ||
| # Merge configurations | ||
| merged_config = self._capability_registry.merge_capabilities_config( | ||
| self.capabilities_config, | ||
| runtime_config | ||
| ) | ||
| # Create new instances with merged config | ||
| return self._capability_registry.create_from_config(merged_config) | ||
| def _get_capability_tools( | ||
| self, | ||
| capabilities: Dict[str, BaseCapability] | ||
| ) -> List[Dict[str, Any]]: | ||
| """ | ||
| Get OpenAI-compatible tool schemas for capabilities. | ||
| Args: | ||
| capabilities: Dict of capability instances | ||
| Returns: | ||
| List of tool schemas | ||
| """ | ||
| return self._capability_registry.get_tool_schemas(capabilities) | ||
| def _extract_runtime_capabilities( | ||
| self, | ||
| config: Optional[Dict[str, Any]] | ||
| ) -> Optional[Dict[str, Any]]: | ||
| """ | ||
| Extract capability overrides from invoke config. | ||
| Args: | ||
| config: The config dict passed to invoke/ainvoke | ||
| Returns: | ||
| Capability overrides dict or None | ||
| """ | ||
| if not config: | ||
| return None | ||
| # Check configurable first (preferred location) | ||
| configurable = config.get("configurable", {}) | ||
| if caps := configurable.get("capabilities"): | ||
| return caps | ||
| # Fallback to metadata | ||
| metadata = config.get("metadata", {}) | ||
| return metadata.get("capabilities") | ||
| def _prepare_invoke_kwargs( | ||
| self, | ||
| config: Optional[Dict[str, Any]], | ||
| kwargs: Dict[str, Any] | ||
| ) -> Dict[str, Any]: | ||
| """ | ||
| Prepare kwargs for invoke, adding capability tools. | ||
| Args: | ||
| config: Config dict from invoke | ||
| kwargs: Original kwargs | ||
| Returns: | ||
| Updated kwargs with tools added | ||
| """ | ||
| result = kwargs.copy() | ||
| # Get runtime capability overrides | ||
| runtime_caps = self._extract_runtime_capabilities(config) | ||
| # Get effective capabilities | ||
| effective_caps = self._get_effective_capabilities(runtime_caps) | ||
| # Get tool schemas | ||
| capability_tools = self._get_capability_tools(effective_caps) | ||
| if capability_tools: | ||
| # Merge with any existing tools | ||
| existing_tools = result.get("tools", []) | ||
| result["tools"] = capability_tools + list(existing_tools) | ||
| self.logger.debug(f"Added {len(capability_tools)} capability tools") | ||
| return result | ||
| def bind_tools( | ||
| self, | ||
| tools: List[Any], | ||
| **kwargs: Any | ||
| ) -> "OAIClaudeChatModel": | ||
| """ | ||
| Bind additional tools to this model. | ||
| This extends the parent bind_tools to also include capability tools. | ||
| Args: | ||
| tools: List of tools to bind | ||
| **kwargs: Additional arguments | ||
| Returns: | ||
| New model instance with tools bound | ||
| """ | ||
| # Get capability tools | ||
| capability_tools = self._get_capability_tools(self._capabilities) | ||
| # Combine with user tools | ||
| all_tools = capability_tools + list(tools) | ||
| self.logger.debug( | ||
| f"Binding tools: {len(capability_tools)} capability + {len(tools)} user" | ||
| ) | ||
| return super().bind_tools(all_tools, **kwargs) | ||
| def invoke( | ||
| self, | ||
| input: Any, | ||
| config: Optional[Dict[str, Any]] = None, | ||
| **kwargs: Any | ||
| ) -> AIMessage: | ||
| """ | ||
| Invoke the model with capability support. | ||
| Args: | ||
| input: Input messages | ||
| config: Optional config with capability overrides | ||
| **kwargs: Additional arguments | ||
| Returns: | ||
| AIMessage, potentially with tool_calls | ||
| """ | ||
| # Add tracing callbacks | ||
| config = self._tracing_manager.add_sync_callbacks_to_config(config) | ||
| # Prepare kwargs with capability tools | ||
| kwargs = self._prepare_invoke_kwargs(config, kwargs) | ||
| return super().invoke(input, config=config, **kwargs) | ||
| async def ainvoke( | ||
| self, | ||
| input: Any, | ||
| config: Optional[Dict[str, Any]] = None, | ||
| **kwargs: Any | ||
| ) -> AIMessage: | ||
| """ | ||
| Async invoke the model with capability support. | ||
| Args: | ||
| input: Input messages | ||
| config: Optional config with capability overrides | ||
| **kwargs: Additional arguments | ||
| Returns: | ||
| AIMessage, potentially with tool_calls | ||
| """ | ||
| # Add tracing callbacks | ||
| config = self._tracing_manager.add_async_callbacks_to_config(config) | ||
| # Prepare kwargs with capability tools | ||
| kwargs = self._prepare_invoke_kwargs(config, kwargs) | ||
| return await super().ainvoke(input, config=config, **kwargs) | ||
| def stream( | ||
| self, | ||
| input: Any, | ||
| config: Optional[Dict[str, Any]] = None, | ||
| **kwargs: Any | ||
| ) -> Iterator[Any]: | ||
| """ | ||
| Stream the model response with capability support. | ||
| Args: | ||
| input: Input messages | ||
| config: Optional config with capability overrides | ||
| **kwargs: Additional arguments | ||
| Yields: | ||
| Response chunks | ||
| """ | ||
| # Add tracing callbacks | ||
| config = self._tracing_manager.add_sync_callbacks_to_config(config) | ||
| # Prepare kwargs with capability tools | ||
| kwargs = self._prepare_invoke_kwargs(config, kwargs) | ||
| yield from super().stream(input, config=config, **kwargs) | ||
| async def astream( | ||
| self, | ||
| input: Any, | ||
| config: Optional[Dict[str, Any]] = None, | ||
| **kwargs: Any | ||
| ) -> AsyncIterator[Any]: | ||
| """ | ||
| Async stream the model response with capability support. | ||
| Args: | ||
| input: Input messages | ||
| config: Optional config with capability overrides | ||
| **kwargs: Additional arguments | ||
| Yields: | ||
| Response chunks | ||
| """ | ||
| # Add tracing callbacks | ||
| config = self._tracing_manager.add_async_callbacks_to_config(config) | ||
| # Prepare kwargs with capability tools | ||
| kwargs = self._prepare_invoke_kwargs(config, kwargs) | ||
| async for chunk in super().astream(input, config=config, **kwargs): | ||
| yield chunk | ||
| def get_capability(self, name: str) -> Optional[BaseCapability]: | ||
| """ | ||
| Get a capability instance by name. | ||
| Args: | ||
| name: Capability name | ||
| Returns: | ||
| Capability instance or None | ||
| """ | ||
| return self._capabilities.get(name) | ||
| def list_capabilities(self) -> List[str]: | ||
| """ | ||
| List all enabled capability names. | ||
| Returns: | ||
| List of capability names | ||
| """ | ||
| return list(self._capabilities.keys()) | ||
| def execute_capability(self, name: str, **kwargs) -> str: | ||
| """ | ||
| Execute a capability synchronously. | ||
| Args: | ||
| name: Capability name | ||
| **kwargs: Arguments for the capability | ||
| Returns: | ||
| Execution result as string | ||
| Raises: | ||
| ValueError: If capability not found | ||
| """ | ||
| cap = self.get_capability(name) | ||
| if not cap: | ||
| raise ValueError(f"Capability '{name}' not found") | ||
| return cap.execute_sync(**kwargs) | ||
| async def aexecute_capability(self, name: str, **kwargs) -> str: | ||
| """ | ||
| Execute a capability asynchronously. | ||
| Args: | ||
| name: Capability name | ||
| **kwargs: Arguments for the capability | ||
| Returns: | ||
| Execution result as string | ||
| Raises: | ||
| ValueError: If capability not found | ||
| """ | ||
| cap = self.get_capability(name) | ||
| if not cap: | ||
| raise ValueError(f"Capability '{name}' not found") | ||
| return await cap.execute(**kwargs) |
| #!/usr/bin/env python3 | ||
| """ | ||
| Bash 工具测试脚本 | ||
| 测试通过 OpenAI 兼容接口和原生 Anthropic API 调用 Bash 工具。 | ||
| 根据 Anthropic 官方文档: | ||
| - bash_20250124 是 schema-less 工具类型 | ||
| - 支持 command 和 restart 两个参数 | ||
| - 工具维护持久会话状态 | ||
| 使用方法: | ||
| python test_bash_tools.py | ||
| 依赖安装: | ||
| pip install openai anthropic | ||
| """ | ||
| import json | ||
| import sys | ||
| import subprocess | ||
| import platform | ||
| import tempfile | ||
| import os | ||
| # ============================================================================== | ||
| # 配置区域 | ||
| # ============================================================================== | ||
| CONFIG = { | ||
| "base_url": "https://mg.aid.pub/claude-proxy", | ||
| "openai_base_url": "https://mg.aid.pub/v1", | ||
| "api_key": "sk-991aeb35-043b-4722-b448-3050ba240ac4", | ||
| "model_anthropic": "claude-opus-4-5", | ||
| "model_openai": "Claude-Opus-4.5", | ||
| } | ||
def print_separator(title: str):
    """Print a 70-character banner around *title*."""
    bar = "=" * 70
    print()
    print(bar)
    print(f" {title}")
    print(bar)
    print()
| # ============================================================================== | ||
| # 测试1: OpenAI 兼容接口 + Function Calling 模拟 Bash 工具 | ||
| # ============================================================================== | ||
def test_bash_tool_openai_compatible():
    """
    Test the bash tool through the OpenAI-compatible endpoint.

    Simulates Anthropic's bash tool via function calling:
    1. Define a ``bash`` tool in function-calling format.
    2. Ask the model to use it.
    3. Execute the returned command(s) locally (with a basic safety filter).
    4. Return every tool result to the model and print the final answer.

    Returns:
        True when the round-trip completes (with or without a tool call),
        False on error.
    """
    import openai
    print_separator("测试 Bash 工具 (OpenAI 兼容接口)")
    client = openai.OpenAI(
        api_key=CONFIG["api_key"],
        base_url=CONFIG["openai_base_url"]
    )
    # Bash tool definition (OpenAI function-calling format).
    bash_tool = {
        "type": "function",
        "function": {
            "name": "bash",
            "description": "Execute a bash command in a persistent shell session. Use this for running system commands, scripts, and command-line operations.",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "The bash command to execute"
                    },
                    "restart": {
                        "type": "boolean",
                        "description": "Set to true to restart the bash session"
                    }
                },
                "required": []
            }
        }
    }
    # Scenario: have the model list files with a bash command.
    test_prompt = "请列出当前目录下的所有 Python 文件(使用 bash 命令)"
    print(f"用户请求: {test_prompt}")
    print("-" * 40)
    try:
        # Round 1: let the model decide to call the bash tool.
        response = client.chat.completions.create(
            model=CONFIG["model_openai"],
            max_tokens=1024,
            messages=[
                {"role": "user", "content": test_prompt}
            ],
            tools=[bash_tool],
            tool_choice="auto"
        )
        message = response.choices[0].message
        if not message.tool_calls:
            print(f"模型直接响应 (未调用工具): {message.content}")
            return True
        print("模型请求调用 Bash 工具:")
        # BUG FIX: previously only tool_calls[0].id received a tool result,
        # and it got the output of the *last* executed call. Now each tool
        # call gets its own result message.
        tool_messages = []
        for tool_call in message.tool_calls:
            print(f"  - 工具: {tool_call.function.name}")
            print(f"  - 参数: {tool_call.function.arguments}")
            # Parse the JSON arguments defensively.
            try:
                args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
            except json.JSONDecodeError:
                args = {}
            command = args.get("command", "")
            restart = args.get("restart", False)
            if restart:
                print("  - 操作: 重启 bash 会话")
                result = "Bash session restarted"
            elif command:
                print(f"  - 执行命令: {command}")
                # Basic denylist of obviously destructive commands.
                dangerous_patterns = ['rm -rf /', 'format', ':(){:|:&};:', 'sudo rm', 'mkfs']
                is_safe = all(pattern not in command for pattern in dangerous_patterns)
                if not is_safe:
                    result = "Error: Command blocked for safety reasons"
                    print(f"  - 安全检查: 命令被阻止")
                else:
                    # Execute locally (cross-platform helper).
                    result = execute_bash_command(command)
                    print(f"  - 执行结果:\n{result[:500]}")
            else:
                result = "Error: No command provided"
            tool_messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result
            })
        # Round 2: return all tool results to the model.
        print("\n" + "-" * 40)
        print("将结果返回给模型...")
        messages = [
            {"role": "user", "content": test_prompt},
            message,
            *tool_messages,
        ]
        final_response = client.chat.completions.create(
            model=CONFIG["model_openai"],
            max_tokens=1024,
            messages=messages,
            tools=[bash_tool]
        )
        print(f"\n模型最终响应:\n{final_response.choices[0].message.content}")
        return True
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 测试2: 原生 Anthropic API + bash_20250124 工具 | ||
| # ============================================================================== | ||
def test_bash_tool_native_anthropic():
    """
    Test the bash_20250124 tool via the native Anthropic SDK.

    Per the official docs the bash tool is schema-less, so no input_schema
    is supplied — only the tool type and name:

        {
            "type": "bash_20250124",
            "name": "bash"
        }

    Returns:
        True when the round-trip completes, False when the SDK is missing,
        the proxy rejects the tool type, or any other error occurs.
    """
    print_separator("测试 Bash 工具 (原生 Anthropic API)")
    try:
        import anthropic
    except ImportError:
        print("需要安装 anthropic SDK: pip install anthropic")
        return False
    client = anthropic.Anthropic(
        api_key=CONFIG["api_key"],
        base_url=CONFIG["base_url"]
    )
    # bash_20250124 is schema-less: only type + name, no input_schema.
    bash_tool = {
        "type": "bash_20250124",
        "name": "bash"
    }
    test_prompt = "列出当前目录下的文件"
    print(f"用户请求: {test_prompt}")
    print(f"使用工具类型: bash_20250124")
    print("-" * 40)
    try:
        response = client.messages.create(
            model=CONFIG["model_anthropic"],
            max_tokens=1024,
            tools=[bash_tool],
            messages=[
                {"role": "user", "content": test_prompt}
            ]
        )
        print(f"响应 stop_reason: {response.stop_reason}")
        for content in response.content:
            if content.type == "text":
                print(f"文本内容: {content.text}")
            elif content.type == "tool_use":
                print(f"工具调用: {content.name}")
                print(f"工具输入: {json.dumps(content.input, ensure_ascii=False, indent=2)}")
                # Pull the command / restart flag out of the tool_use input.
                command = content.input.get("command", "")
                restart = content.input.get("restart", False)
                if restart:
                    result = "Bash session restarted"
                elif command:
                    result = execute_bash_command(command)
                    print(f"执行结果:\n{result[:500]}")
                    # Feed the tool_result back to Claude for a final answer.
                    final_response = client.messages.create(
                        model=CONFIG["model_anthropic"],
                        max_tokens=1024,
                        tools=[bash_tool],
                        messages=[
                            {"role": "user", "content": test_prompt},
                            {"role": "assistant", "content": response.content},
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "type": "tool_result",
                                        "tool_use_id": content.id,
                                        "content": result
                                    }
                                ]
                            }
                        ]
                    )
                    print("\n模型最终响应:")
                    for final_content in final_response.content:
                        if final_content.type == "text":
                            print(final_content.text)
                else:
                    result = "Error: No command provided"
        return True
    except Exception as e:
        error_str = str(e)
        # Distinguish "proxy doesn't support this tool type" from other failures.
        if "bash_20250124" in error_str or "unknown tool" in error_str.lower():
            print(f"API 不支持 bash_20250124 工具类型: {e}")
            print("这可能是因为代理服务不支持原生 Anthropic 工具类型")
            print("请使用 OpenAI 兼容接口的 Function Calling 方式")
            return False
        else:
            print(f"测试失败: {e}")
            import traceback
            traceback.print_exc()
            return False
| # ============================================================================== | ||
| # 测试3: 多步骤 Bash 命令 (模拟持久会话) | ||
| # ============================================================================== | ||
def test_bash_multi_step():
    """
    Multi-step bash execution test (simulated persistent session).

    Demonstrates keeping state across several commands by running them all
    in the same temporary working directory:
    1. Create a file.
    2. Read the file.
    3. Report file statistics.

    Returns:
        True (the test is considered successful if the loop completes).
    """
    import openai
    print_separator("测试多步骤 Bash 工具 (状态保持)")
    client = openai.OpenAI(
        api_key=CONFIG["api_key"],
        base_url=CONFIG["openai_base_url"]
    )
    bash_tool = {
        "type": "function",
        "function": {
            "name": "bash",
            "description": "Execute bash commands. Commands run in a persistent session that maintains state (working directory, environment variables, created files).",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "The bash command to execute"
                    }
                },
                "required": ["command"]
            }
        }
    }
    # A temporary directory stands in for the "persistent session":
    # every command runs with the same cwd/env so state carries over.
    with tempfile.TemporaryDirectory() as temp_dir:
        session_cwd = temp_dir
        session_env = os.environ.copy()
        test_prompt = """请执行以下操作:
1. 创建一个名为 hello.txt 的文件,内容为 "Hello from Claude"
2. 显示文件内容
3. 统计文件信息"""
        print(f"用户请求: {test_prompt}")
        print(f"工作目录: {session_cwd}")
        print("-" * 40)
        messages = [{"role": "user", "content": test_prompt}]
        # Multi-turn loop: keep executing tool calls until the model answers
        # in plain text or the iteration cap is hit.
        max_iterations = 6
        iteration = 0
        while iteration < max_iterations:
            iteration += 1
            print(f"\n--- 迭代 {iteration} ---")
            response = client.chat.completions.create(
                model=CONFIG["model_openai"],
                max_tokens=1024,
                messages=messages,
                tools=[bash_tool],
                tool_choice="auto"
            )
            message = response.choices[0].message
            messages.append(message)
            if message.tool_calls:
                for tool_call in message.tool_calls:
                    args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
                    command = args.get("command", "")
                    print(f"执行命令: {command}")
                    # Run in the shared session directory/environment.
                    result = execute_bash_command(command, cwd=session_cwd, env=session_env)
                    print(f"结果: {result[:300]}")
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": result
                    })
            else:
                # The model produced a final, text-only answer.
                print(f"\n模型最终响应:\n{message.content}")
                break
        if iteration >= max_iterations:
            print("达到最大迭代次数")
    return True
| # ============================================================================== | ||
| # 辅助函数 | ||
| # ============================================================================== | ||
def execute_bash_command(command: str, cwd: str = None, env: dict = None, timeout: int = 30) -> str:
    """
    Execute a shell command (cross-platform).

    On Windows, a few common Unix commands (ls/cat/wc/pwd) are rewritten to
    native equivalents before execution.

    Args:
        command: Command line to run (executed through the shell).
        cwd: Working directory for the subprocess, or None for the current one.
        env: Environment variable mapping, or None to inherit.
        timeout: Maximum runtime in seconds.

    Returns:
        Combined stdout + stderr, a placeholder string when the command
        produced no output, or an "Error: ..." string on timeout/failure.
    """
    try:
        # Cross-platform command translation.
        if platform.system() == "Windows":
            original_command = command
            stripped = command.strip()
            if stripped.startswith("ls"):
                command = command.replace("ls -la", "dir").replace("ls -l", "dir").replace("ls", "dir")
            elif stripped.startswith("cat "):
                command = command.replace("cat ", "type ")
            elif stripped.startswith("wc "):
                # wc does not exist on Windows; emulate with PowerShell.
                parts = command.split()
                if len(parts) >= 2:
                    filename = parts[-1]
                    # BUG FIX: previously interpolated a literal "(unknown)"
                    # placeholder instead of the target filename.
                    command = f'powershell -Command "Get-Content {filename} | Measure-Object -Line -Word -Character"'
            elif stripped.startswith("pwd"):
                command = "cd"
            elif ">" in command and stripped.startswith("echo"):
                # Windows echo redirection is close enough to POSIX; leave as-is.
                pass
            if command != original_command:
                print(f"  (命令已转换: {original_command} -> {command})")
        proc_result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=cwd,
            env=env
        )
        result = proc_result.stdout + proc_result.stderr
        if not result.strip():
            result = "(命令执行成功,无输出)"
        return result
    except subprocess.TimeoutExpired:
        return f"Error: Command timed out after {timeout} seconds"
    except Exception as e:
        return f"Error: {str(e)}"
def truncate_output(output: str, max_lines: int = 100) -> str:
    """Cap *output* at *max_lines* lines, appending a truncation notice."""
    lines = output.split('\n')
    if len(lines) <= max_lines:
        return output
    kept = '\n'.join(lines[:max_lines])
    return f"{kept}\n\n... Output truncated ({len(lines)} total lines) ..."
| # ============================================================================== | ||
| # 主函数 | ||
| # ============================================================================== | ||
def main():
    """Run all bash-tool tests, print a summary, and return an exit code."""
    print("\n" + "=" * 70)
    print(" Bash 工具测试脚本")
    print(" 测试配置:")
    print(f"  - Anthropic Base URL: {CONFIG['base_url']}")
    print(f"  - OpenAI Base URL: {CONFIG['openai_base_url']}")
    print(f"  - Model (Anthropic): {CONFIG['model_anthropic']}")
    print(f"  - Model (OpenAI): {CONFIG['model_openai']}")
    print(f"  - Platform: {platform.system()}")
    print("=" * 70)
    # Table-driven test execution: each entry is (label, callable).
    test_cases = [
        ("OpenAI兼容接口", test_bash_tool_openai_compatible),
        ("原生Anthropic API", test_bash_tool_native_anthropic),
        ("多步骤命令", test_bash_multi_step),
    ]
    results = {}
    for label, test_fn in test_cases:
        try:
            results[label] = test_fn()
        except Exception as e:
            print(f"测试异常: {e}")
            results[label] = False
    # Summary.
    print_separator("测试结果汇总")
    for test_name, passed in results.items():
        status = "✅ 通过" if passed else "❌ 失败"
        print(f"  {test_name}: {status}")
    print("\n" + "-" * 40)
    print("结论:")
    print("  1. OpenAI 兼容接口: 使用 Function Calling 模拟 bash 工具")
    print("  2. 原生 Anthropic API: 使用 bash_20250124 工具类型")
    print("  3. bash 工具是 schema-less 的,只需指定 type 和 name")
    print("  4. 支持 command 和 restart 两个参数")
    print("-" * 40)
    all_passed = all(results.values())
    print(f"\n总体结果: {'✅ 全部通过' if all_passed else '⚠️ 部分测试失败'}")
    return 0 if all_passed else 1
# Entry point: the aggregated test result becomes the process exit code.
if __name__ == "__main__":
    sys.exit(main())
| #!/usr/bin/env python3 | ||
| """ | ||
| Claude API 测试脚本 | ||
| 用于测试自定义 Claude 代理服务的 API 调用。 | ||
| 此脚本会分别测试两种认证方式: | ||
| 1. api_key 方式 (发送 x-api-key 头) | ||
| 2. auth_token 方式 (发送 Authorization: Bearer 头) | ||
| 使用方法: | ||
| python test_claude_api.py | ||
| 依赖安装: | ||
| pip install anthropic httpx | ||
| """ | ||
| import json | ||
| import sys | ||
| # ============================================================================== | ||
| # 配置区域 - 请根据实际情况修改 | ||
| # ============================================================================== | ||
| CONFIG = { | ||
| "provider": "custom", | ||
| "type": "inference", | ||
| "deployment_name": "claude-opus-4-56@mg", | ||
| "base_url": "https://mg.aid.pub/claude-proxy", | ||
| "api_key": "sk-0f7bfe7e-c9ff-461d-8813-b5d225385e88", | ||
| "temperature": 0.5, | ||
| "max_tokens": 20000, | ||
| "capabilities": { | ||
| "web_search": { | ||
| "enabled": True, | ||
| "version": "web-search-2025-03-05" | ||
| } | ||
| } | ||
| } | ||
| # 从 deployment_name 中提取模型名称 (去掉 @mg 后缀) | ||
| MODEL_NAME = CONFIG["deployment_name"].split("@")[0] | ||
| # ============================================================================== | ||
| # 测试函数 | ||
| # ============================================================================== | ||
def print_separator(title: str):
    """Print a 70-character banner around *title*."""
    bar = "=" * 70
    print(f"\n{bar}\n {title}\n{bar}\n")
def test_with_test_test():
    """
    Exercise ModelGate's Claude models through the OpenAI-compatible API.

    Notes:
        1. ModelGate's /anthropic/v1/ endpoint exposes no models, so the
           OpenAI-compatible /v1/ endpoint is used instead.
        2. ModelGate does not support the native Anthropic
           web_search_20250305 tool type.
        3. A web-search workflow can still be built with function calling.

    Returns:
        True when all requests succeed, False on any exception.
        (BUG FIX: previously returned None implicitly, which made main()
        report failure even when every request succeeded.)
    """
    import openai
    client = openai.OpenAI(
        # NOTE(review): hard-coded secret; prefer an environment variable.
        api_key="sk-991aeb35-043b-4722-b448-3050ba240ac4",
        base_url="https://mg.aid.pub/v1"
    )
    try:
        # List the models exposed by the proxy.
        print("=" * 50)
        print("可用模型列表:")
        print("-" * 40)
        models = client.models.list()
        for model in models.data:
            print(f"  - {model.id}")
        print("-" * 40)
        print()
        # Test 1: basic chat completion.
        print("=" * 50)
        print("测试1: 基础对话")
        print("-" * 40)
        response = client.chat.completions.create(
            model="Claude-Opus-4.5",
            max_tokens=1024,
            messages=[
                {"role": "user", "content": "你好,请用一句话介绍自己"}
            ]
        )
        print(f"响应: {response.choices[0].message.content}")
        print()
        # Test 2: function calling (simulated web-search workflow).
        print("=" * 50)
        print("测试2: Function Calling (web_search 工具)")
        print("-" * 40)
        web_search_tool = {
            "type": "function",
            "function": {
                "name": "web_search",
                "description": "Search the web for real-time information. Use this when you need current data like weather, news, or recent events.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        }
                    },
                    "required": ["query"]
                }
            }
        }
        response = client.chat.completions.create(
            model="Claude-Opus-4.5",
            max_tokens=1024,
            messages=[
                {"role": "user", "content": "上海今天天气怎么样?请搜索最新信息"}
            ],
            tools=[web_search_tool],
            tool_choice="any"  # force at least one tool call (Anthropic-style: auto/any/tool)
        )
        message = response.choices[0].message
        if message.tool_calls:
            print("模型请求调用工具:")
            for tool_call in message.tool_calls:
                print(f"  - 工具: {tool_call.function.name}")
                print(f"  - 参数: {tool_call.function.arguments}")
            # Parse the tool arguments defensively.
            try:
                args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
            except json.JSONDecodeError:
                args = {}
            # Fall back to a query derived from the user message.
            query = args.get("query", "上海今天天气")
            print(f"  - 搜索关键词: {query}")
            # Simulated search result (a real app would perform the search here).
            print()
            print("模拟返回搜索结果并继续对话...")
            mock_search_result = {
                "query": query,
                "results": [
                    {"title": "上海天气预报", "snippet": "2026年2月9日 上海:多云,气温 8-15°C,东北风3级"}
                ]
            }
            # Hand the tool result back to the model.
            messages = [
                {"role": "user", "content": "上海今天天气怎么样?请搜索最新信息"},
                message,
                {
                    "role": "tool",
                    "tool_call_id": message.tool_calls[0].id,
                    "content": json.dumps(mock_search_result, ensure_ascii=False)
                }
            ]
            final_response = client.chat.completions.create(
                model="Claude-Opus-4.5",
                max_tokens=1024,
                messages=messages,
                tools=[web_search_tool]
            )
            print(f"最终响应: {final_response.choices[0].message.content}")
        else:
            print(f"直接响应: {message.content}")
        print()
        print("=" * 50)
        print("测试完成!")
        print()
        print("结论:")
        print("  - ModelGate 支持 OpenAI 兼容的 Function Calling")
        print("  - 不支持原生 Anthropic web_search_20250305 工具")
        print("  - 可通过 Function Calling + 自定义搜索服务实现 web search")
        return True
    except Exception as e:
        print(f"测试失败: {e}")
        return False
def main():
    """Run all tests and return a process exit code (0 = all passed)."""
    header = "=" * 70
    print("\n" + header)
    print(" Claude API 测试脚本")
    print(" 测试配置:")
    print(f"  - Base URL: {CONFIG['base_url']}")
    print(f"  - Model: {MODEL_NAME}")
    print(f"  - API Key: {CONFIG['api_key'][:20]}...")
    print(header)
    results = {"test_with_test_test": test_with_test_test()}
    # Summary table.
    print_separator("测试结果汇总")
    for test_name, passed in results.items():
        print(f"  {test_name}: {'✅ 通过' if passed else '❌ 失败'}")
    all_passed = all(results.values())
    print(f"\n总体结果: {'✅ 全部通过' if all_passed else '❌ 部分失败'}")
    return 0 if all_passed else 1
# Entry point: the aggregated test result becomes the process exit code.
if __name__ == "__main__":
    sys.exit(main())
| #!/usr/bin/env python3 | ||
| """ | ||
| Claude API 测试脚本 (LangChain OpenAI 兼容接口) | ||
| 使用 langchain_openai.ChatOpenAI 通过 OpenAI 兼容接口调用 Claude 模型。 | ||
| 测试内容: | ||
| 1. ainvoke / astream - 基础异步对话 | ||
| 2. Web Search Tool - Function Calling + 自定义搜索服务 | ||
| 3. Bash Tool - Function Calling 格式 | ||
| 使用方法: | ||
| python test_claude_langchain_openai.py | ||
| 依赖安装: | ||
| pip install langchain-openai langchain-core | ||
| """ | ||
| import asyncio | ||
| import json | ||
| import sys | ||
| import subprocess | ||
| import platform | ||
| import tempfile | ||
| import os | ||
| from typing import List, Any | ||
| from langchain_openai.chat_models.base import ChatOpenAI | ||
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage | ||
| from langchain_core.tools import tool | ||
| # ============================================================================== | ||
| # 配置区域 | ||
| # ============================================================================== | ||
| CONFIG = { | ||
| "openai_base_url": "https://mg.aid.pub/v1", | ||
| "api_key": "sk-991aeb35-043b-4722-b448-3050ba240ac4", | ||
| "model": "claude-opus-4-5@mg", | ||
| "temperature": 0.5, | ||
| "max_tokens": 2048, | ||
| } | ||
def print_separator(title: str):
    """Print a 70-character banner around *title*."""
    edge = "=" * 70
    print("\n".join(["", edge, f" {title}", edge, ""]))
def get_chat_model() -> ChatOpenAI:
    """Build a ChatOpenAI client pointed at the OpenAI-compatible proxy."""
    params = dict(
        api_key=CONFIG["api_key"],
        base_url=CONFIG["openai_base_url"],
        model=CONFIG["model"],
        temperature=CONFIG["temperature"],
        max_tokens=CONFIG["max_tokens"],
    )
    return ChatOpenAI(**params)
| # ============================================================================== | ||
| # 工具定义 | ||
| # ============================================================================== | ||
@tool
def web_search(query: str) -> str:
    """
    Search the web for real-time information.
    Use this when you need current data like weather, news, or recent events.
    Args:
        query: The search query string
    """
    # Simulated search hit (a real implementation would call a search API).
    hit = {
        "title": f"搜索结果: {query}",
        "snippet": f"这是关于 '{query}' 的模拟搜索结果。2026年2月9日更新。",
        "url": "https://example.com/search"
    }
    payload = {"query": query, "results": [hit]}
    return json.dumps(payload, ensure_ascii=False)
@tool
def bash(command: str, restart: bool = False) -> str:
    """
    Execute a bash command in a persistent shell session.
    Use this for running system commands, scripts, and command-line operations.
    Args:
        command: The bash command to execute
        restart: Set to true to restart the bash session
    """
    # Guard clauses: restart request and empty command.
    if restart:
        return "Bash session restarted"
    if not command:
        return "Error: No command provided"
    # Denylist of obviously destructive commands.
    dangerous_patterns = ['rm -rf /', 'format', ':(){:|:&};:', 'sudo rm', 'mkfs']
    for pattern in dangerous_patterns:
        if pattern in command:
            return "Error: Command blocked for safety reasons"
    return execute_bash_command(command)
def execute_bash_command(command: str, cwd: str = None, timeout: int = 30) -> str:
    """Run *command* in a shell and return its combined stdout+stderr.

    On Windows a few common Unix commands (ls/cat/wc/pwd) are rewritten to
    cmd/PowerShell equivalents so the same prompts work cross-platform.

    Args:
        command: Shell command line to execute.
        cwd: Working directory for the subprocess (None = inherit current).
        timeout: Seconds to wait before killing the command.

    Returns:
        Combined stdout+stderr, a placeholder string when the command produced
        no output, or an "Error: ..." string on timeout/failure.
    """
    try:
        if platform.system() == "Windows":
            original_command = command
            stripped = command.strip()
            if stripped.startswith("ls"):
                command = command.replace("ls -la", "dir").replace("ls -l", "dir").replace("ls", "dir")
            elif stripped.startswith("cat "):
                command = command.replace("cat ", "type ")
            elif stripped.startswith("wc "):
                parts = command.split()
                if len(parts) >= 2:
                    filename = parts[-1]
                    # Bug fix: the computed filename was missing from the
                    # generated PowerShell pipeline, producing a broken command.
                    command = f'powershell -Command "Get-Content {filename} | Measure-Object -Line -Word -Character"'
            elif stripped.startswith("pwd"):
                command = "cd"
            if command != original_command:
                print(f" (命令已转换: {original_command} -> {command})")
        # shell=True is required here because the model emits full shell
        # command lines (pipes, redirects) rather than argv lists.
        proc_result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=cwd,
        )
        result = proc_result.stdout + proc_result.stderr
        if not result.strip():
            result = "(命令执行成功,无输出)"
        return result
    except subprocess.TimeoutExpired:
        return f"Error: Command timed out after {timeout} seconds"
    except Exception as e:
        return f"Error: {str(e)}"
| # ============================================================================== | ||
| # 测试1: ainvoke 基础异步调用 | ||
| # ============================================================================== | ||
async def test_ainvoke_basic():
    """Test basic async invocation (ainvoke) against the chat endpoint."""
    print_separator("测试1: ainvoke 基础异步调用")
    chat = get_chat_model()
    messages = [
        HumanMessage(content="你好,请用一句话介绍自己")
    ]
    try:
        print("发送请求...")
        response = await chat.ainvoke(messages)
        print(f"响应类型: {type(response).__name__}")
        print(f"响应内容: {response.content}")
        print(f"响应元数据: {response.response_metadata}")
        return True
    except Exception as e:
        # Failures are reported but not re-raised so the suite keeps running.
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 测试2: astream 异步流式调用 | ||
| # ============================================================================== | ||
async def test_astream_basic():
    """Test astream asynchronous streaming output."""
    print_separator("测试2: astream 异步流式调用")
    chat = get_chat_model()
    messages = [
        HumanMessage(content="请用3句话描述春天的美景")
    ]
    try:
        print("开始流式输出...")
        print("-" * 40)
        full_content = ""
        chunk_count = 0
        # Echo chunks as they arrive and accumulate content for final stats.
        async for chunk in chat.astream(messages):
            chunk_count += 1
            if chunk.content:
                print(chunk.content, end="", flush=True)
                full_content += chunk.content
        print("\n" + "-" * 40)
        print(f"流式输出完成,共 {chunk_count} 个 chunk")
        print(f"完整内容长度: {len(full_content)} 字符")
        return True
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 测试3: Web Search Tool (Function Calling) | ||
| # ============================================================================== | ||
async def test_web_search_tool():
    """Test the web_search tool via function calling: detect the tool call,
    execute it locally, and feed the result back for a final answer."""
    print_separator("测试3: Web Search Tool (Function Calling)")
    chat = get_chat_model()
    # Bind the tool so the model can request it through function calling.
    chat_with_tools = chat.bind_tools([web_search])
    messages = [
        HumanMessage(content="上海今天天气怎么样?请搜索最新信息")
    ]
    try:
        print("发送请求 (带 web_search 工具)...")
        response = await chat_with_tools.ainvoke(messages)
        print(f"响应类型: {type(response).__name__}")
        if response.tool_calls:
            print("模型请求调用工具:")
            for tc in response.tool_calls:
                print(f" - 工具名称: {tc['name']}")
                print(f" - 工具参数: {tc['args']}")
                print(f" - 工具调用ID: {tc['id']}")
            # Execute only the first requested tool call.
            tool_call = response.tool_calls[0]
            tool_result = web_search.invoke(tool_call['args'])
            print(f"\n工具执行结果: {tool_result}")
            # Return the tool output so the model can produce a final answer.
            print("\n将结果返回给模型...")
            messages_with_result = [
                HumanMessage(content="上海今天天气怎么样?请搜索最新信息"),
                response,
                ToolMessage(content=tool_result, tool_call_id=tool_call['id'])
            ]
            final_response = await chat_with_tools.ainvoke(messages_with_result)
            print(f"\n模型最终响应: {final_response.content}")
            return True
        else:
            print(f"模型直接响应 (未调用工具): {response.content}")
            return True
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 测试4: Web Search Tool (astream 流式) | ||
| # ============================================================================== | ||
async def test_web_search_tool_stream():
    """Test the web_search tool under streaming: accumulate tool-call chunks
    emitted across the stream, then execute the assembled call(s)."""
    print_separator("测试4: Web Search Tool (astream 流式)")
    chat = get_chat_model()
    chat_with_tools = chat.bind_tools([web_search])
    messages = [
        HumanMessage(content="请搜索北京今天的空气质量")
    ]
    try:
        print("发送流式请求 (带 web_search 工具)...")
        tool_calls_accumulated = []
        content_accumulated = ""
        async for chunk in chat_with_tools.astream(messages):
            if chunk.content:
                print(chunk.content, end="", flush=True)
                content_accumulated += chunk.content
            if chunk.tool_call_chunks:
                for tc_chunk in chunk.tool_call_chunks:
                    # Tool-call fragments arrive incrementally; merge by index.
                    # NOTE(review): recent langchain-core versions type
                    # tool_call_chunks as TypedDicts, where attribute access
                    # (tc_chunk.index) would fail — confirm against the
                    # installed version (dict access tc_chunk["index"] may be
                    # required).
                    if tc_chunk.index is not None:
                        while len(tool_calls_accumulated) <= tc_chunk.index:
                            tool_calls_accumulated.append({"name": "", "args": "", "id": ""})
                        if tc_chunk.name:
                            tool_calls_accumulated[tc_chunk.index]["name"] = tc_chunk.name
                        if tc_chunk.args:
                            # args stream in as partial JSON text; concatenate.
                            tool_calls_accumulated[tc_chunk.index]["args"] += tc_chunk.args
                        if tc_chunk.id:
                            tool_calls_accumulated[tc_chunk.index]["id"] = tc_chunk.id
        print()
        if tool_calls_accumulated:
            print("\n检测到工具调用:")
            for idx, tc in enumerate(tool_calls_accumulated):
                print(f" [{idx}] 工具: {tc['name']}, 参数: {tc['args']}, ID: {tc['id']}")
                # Execute each assembled call with its parsed JSON arguments.
                try:
                    args = json.loads(tc['args']) if tc['args'] else {}
                    tool_result = web_search.invoke(args)
                    print(f" 工具执行结果: {tool_result[:200]}...")
                except Exception as e:
                    print(f" 工具执行失败: {e}")
        return True
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 测试5: Bash Tool (Function Calling) | ||
| # ============================================================================== | ||
async def test_bash_tool():
    """Test the bash tool via function calling, feeding the command output
    back to the model for a final summary."""
    print_separator("测试5: Bash Tool (Function Calling)")
    chat = get_chat_model()
    chat_with_tools = chat.bind_tools([bash])
    messages = [
        HumanMessage(content="请列出当前目录下的所有文件")
    ]
    try:
        print("发送请求 (带 bash 工具)...")
        response = await chat_with_tools.ainvoke(messages)
        if response.tool_calls:
            print("模型请求调用工具:")
            for tc in response.tool_calls:
                print(f" - 工具名称: {tc['name']}")
                print(f" - 工具参数: {tc['args']}")
            # Execute only the first requested tool call.
            tool_call = response.tool_calls[0]
            tool_result = bash.invoke(tool_call['args'])
            print(f"\n工具执行结果:\n{tool_result[:500]}")
            # Return the command output so the model can summarize it.
            print("\n将结果返回给模型...")
            messages_with_result = [
                HumanMessage(content="请列出当前目录下的所有文件"),
                response,
                ToolMessage(content=tool_result, tool_call_id=tool_call['id'])
            ]
            final_response = await chat_with_tools.ainvoke(messages_with_result)
            print(f"\n模型最终响应: {final_response.content}")
            return True
        else:
            print(f"模型直接响应 (未调用工具): {response.content}")
            return True
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 测试6: 多步骤 Bash 命令 | ||
| # ============================================================================== | ||
async def test_bash_multi_step():
    """Test a multi-step agent loop: the model issues bash commands, each is
    executed in a temp directory, and results are fed back until the model
    answers (or max_iterations is reached)."""
    print_separator("测试6: 多步骤 Bash Tool")
    chat = get_chat_model()
    chat_with_tools = chat.bind_tools([bash])
    # Work inside a throwaway directory so commands cannot pollute the repo.
    with tempfile.TemporaryDirectory() as temp_dir:
        messages = [
            SystemMessage(content=f"当前工作目录是: {temp_dir}"),
            HumanMessage(content="""请执行以下操作:
1. 创建一个名为 hello.txt 的文件,内容为 "Hello from Claude"
2. 显示文件内容
3. 统计文件信息""")
        ]
        print(f"工作目录: {temp_dir}")
        print("-" * 40)
        max_iterations = 6
        iteration = 0
        try:
            while iteration < max_iterations:
                iteration += 1
                print(f"\n--- 迭代 {iteration} ---")
                response = await chat_with_tools.ainvoke(messages)
                messages.append(response)
                if response.tool_calls:
                    for tc in response.tool_calls:
                        print(f"执行命令: {tc['args']}")
                        # Run the requested command inside the temp directory.
                        command = tc['args'].get('command', '')
                        result = execute_bash_command(command, cwd=temp_dir)
                        print(f"结果: {result[:300]}")
                        messages.append(ToolMessage(
                            content=result,
                            tool_call_id=tc['id']
                        ))
                else:
                    # No tool call means the model produced its final answer.
                    print(f"\n模型最终响应:\n{response.content}")
                    break
            if iteration >= max_iterations:
                print("达到最大迭代次数")
            return True
        except Exception as e:
            print(f"测试失败: {e}")
            import traceback
            traceback.print_exc()
            return False
| # ============================================================================== | ||
| # 测试7: 同时使用多个工具 | ||
| # ============================================================================== | ||
async def test_multiple_tools():
    """Test binding web_search and bash simultaneously and dispatching every
    tool call the model requests by tool name."""
    print_separator("测试7: 多工具绑定测试")
    chat = get_chat_model()
    chat_with_tools = chat.bind_tools([web_search, bash])
    messages = [
        HumanMessage(content="请先搜索今天的新闻标题,然后用bash命令显示当前时间")
    ]
    try:
        print("发送请求 (绑定 web_search + bash 两个工具)...")
        response = await chat_with_tools.ainvoke(messages)
        if response.tool_calls:
            print(f"模型请求调用 {len(response.tool_calls)} 个工具:")
            for tc in response.tool_calls:
                print(f" - 工具: {tc['name']}, 参数: {tc['args']}")
            # Execute all requested tool calls, dispatching on the tool name.
            messages.append(response)
            for tc in response.tool_calls:
                if tc['name'] == 'web_search':
                    result = web_search.invoke(tc['args'])
                elif tc['name'] == 'bash':
                    result = bash.invoke(tc['args'])
                else:
                    result = f"Unknown tool: {tc['name']}"
                print(f"\n{tc['name']} 结果: {result[:200]}")
                messages.append(ToolMessage(content=result, tool_call_id=tc['id']))
            # Ask for a final answer now that every result is attached.
            final_response = await chat_with_tools.ainvoke(messages)
            print(f"\n模型最终响应: {final_response.content}")
            return True
        else:
            print(f"模型直接响应: {response.content}")
            return True
    except Exception as e:
        print(f"测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # 主函数 | ||
| # ============================================================================== | ||
async def run_all_tests():
    """Run every test case in order and print a PASS/FAIL summary.

    Returns:
        0 when all tests passed, 1 otherwise (suitable as a process exit code).
    """
    print("\n" + "=" * 70)
    print(" Claude API 测试脚本 (LangChain OpenAI 兼容接口)")
    print(" 测试配置:")
    print(f" - Base URL: {CONFIG['openai_base_url']}")
    print(f" - Model: {CONFIG['model']}")
    print(f" - Platform: {platform.system()}")
    print("=" * 70)
    # (result key, coroutine function) pairs — run sequentially; a crash in
    # one test is recorded as a failure instead of aborting the whole suite.
    # This replaces seven copy-pasted try/except blocks with identical output.
    test_cases = [
        ("ainvoke_basic", test_ainvoke_basic),
        ("astream_basic", test_astream_basic),
        ("web_search_tool", test_web_search_tool),
        ("web_search_stream", test_web_search_tool_stream),
        ("bash_tool", test_bash_tool),
        ("bash_multi_step", test_bash_multi_step),
        ("multiple_tools", test_multiple_tools),
    ]
    results = {}
    for name, test_fn in test_cases:
        try:
            results[name] = await test_fn()
        except Exception as e:
            print(f"测试 {name} 异常: {e}")
            results[name] = False
    # Summary table.
    print_separator("测试结果汇总")
    for test_name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f" {test_name}: {status}")
    print("\n" + "-" * 40)
    print("测试说明:")
    print(" 1. ainvoke/astream: 基础异步调用和流式输出")
    print(" 2. web_search: Function Calling + 自定义搜索服务")
    print(" 3. bash: Function Calling 格式执行系统命令")
    print(" 4. 多工具: 同时绑定多个工具供模型选择")
    print("-" * 40)
    all_passed = all(results.values())
    print(f"\n总体结果: {'ALL PASSED' if all_passed else 'SOME FAILED'}")
    return 0 if all_passed else 1
def main():
    """Synchronous entry point: drive the async test suite to completion."""
    exit_code = asyncio.run(run_all_tests())
    return exit_code


if __name__ == "__main__":
    sys.exit(main())
| #!/usr/bin/env python3 | ||
| """ | ||
| OAIClaudeChatModel Test Suite | ||
| Tests for the OpenAI-compatible Claude chat model with capability support. | ||
| Tests include: | ||
| 1. Basic ainvoke/astream | ||
| 2. Web Search capability (Tavily via function calling) | ||
| 3. Bash capability (function calling simulation) | ||
| 4. Runtime capability override | ||
| 5. Multiple capabilities combined | ||
| Usage: | ||
| python test_oai_claude_chat_model.py | ||
| Dependencies: | ||
| pip install langchain-openai langchain-core tavily-python | ||
| """ | ||
| import asyncio | ||
| import json | ||
| import sys | ||
| import os | ||
| import platform | ||
| from typing import Dict, Any, List | ||
| # Add parent path for imports | ||
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage | ||
| from crewplus.services import OAIClaudeChatModel | ||
| from crewplus.services.capabilities import CapabilityRegistry, TavilySearchCapability, BashCapability | ||
| # ============================================================================== | ||
| # Configuration | ||
| # ============================================================================== | ||
# Connection settings for the OAIClaudeChatModel under test.
# NOTE(review): the API key is hard-coded here — rotate it and load it from an
# environment variable before sharing/committing this file.
CONFIG = {
    "base_url": "https://mg.aid.pub/v1",
    "api_key": "sk-991aeb35-043b-4722-b448-3050ba240ac4",
    "model": "claude-opus-4-5@mg",
    "temperature": 0.5,
    "max_tokens": 2048,
}
def print_separator(title: str):
    """Print a separator line with title."""
    rule = "=" * 70
    print("\n" + rule)
    print(f" {title}")
    print(rule + "\n")
def get_model(capabilities_config: Dict[str, Any] = None) -> OAIClaudeChatModel:
    """Create a configured OAIClaudeChatModel instance.

    Args:
        capabilities_config: Per-capability settings, e.g.
            {"web_search": {"enabled": True, "max_results": 5}}.
            None is treated the same as an empty mapping (no capabilities).
    """
    return OAIClaudeChatModel(
        model=CONFIG["model"],
        api_key=CONFIG["api_key"],
        base_url=CONFIG["base_url"],
        temperature=CONFIG["temperature"],
        max_tokens=CONFIG["max_tokens"],
        capabilities_config=capabilities_config or {},
    )
| # ============================================================================== | ||
| # Test 1: Basic ainvoke | ||
| # ============================================================================== | ||
async def test_ainvoke_basic():
    """Test basic async invocation without capabilities."""
    print_separator("Test 1: Basic ainvoke (no capabilities)")
    model = get_model()
    prompt = [HumanMessage(content="Hello! Please introduce yourself in one sentence.")]
    try:
        print("Sending request...")
        reply = await model.ainvoke(prompt)
        print(f"Response type: {type(reply).__name__}")
        print(f"Content: {reply.content}")
        print(f"Tool calls: {reply.tool_calls if hasattr(reply, 'tool_calls') else 'None'}")
        return True
    except Exception as exc:
        print(f"Test failed: {exc}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 2: Basic astream | ||
| # ============================================================================== | ||
async def test_astream_basic():
    """Test basic async streaming without capabilities."""
    print_separator("Test 2: Basic astream (no capabilities)")
    model = get_model()
    prompt = [HumanMessage(content="Describe spring in 3 sentences.")]
    try:
        print("Starting stream...")
        print("-" * 40)
        pieces = []
        chunk_count = 0
        async for part in model.astream(prompt):
            chunk_count += 1
            if part.content:
                print(part.content, end="", flush=True)
                pieces.append(part.content)
        full_content = "".join(pieces)
        print("\n" + "-" * 40)
        print(f"Stream complete: {chunk_count} chunks, {len(full_content)} chars")
        return True
    except Exception as exc:
        print(f"Test failed: {exc}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 3: Web Search Capability | ||
| # ============================================================================== | ||
async def test_web_search_capability():
    """Test web search capability via function calling.

    Verifies the model emits a web_search tool call; no real search is run —
    a canned result is fed back and the final answer is printed.
    """
    print_separator("Test 3: Web Search Capability (Function Calling)")
    model = get_model(capabilities_config={
        "web_search": {
            "enabled": True,
            "api_key_env": "TAVILY_API_KEY",
            "max_results": 5
        }
    })
    messages = [HumanMessage(content="What is the current weather in Tokyo? Please search for this information.")]
    try:
        print(f"Enabled capabilities: {model.list_capabilities()}")
        print("Sending request with web_search tool...")
        response = await model.ainvoke(messages)
        print(f"\nResponse type: {type(response).__name__}")
        print(f"Content: {response.content[:200] if response.content else '(none)'}...")
        if response.tool_calls:
            print(f"\nTool calls detected: {len(response.tool_calls)}")
            for i, tc in enumerate(response.tool_calls):
                print(f" [{i}] Tool: {tc['name']}")
                print(f" Args: {tc['args']}")
                print(f" ID: {tc['id']}")
            # Simulate tool execution (mock response) — no real Tavily call.
            print("\nSimulating tool execution...")
            mock_result = json.dumps({
                "query": "current weather in Tokyo",
                "results": [
                    {"title": "Tokyo Weather", "content": "Currently 15C, partly cloudy, humidity 60%"}
                ]
            }, ensure_ascii=False)
            # Continue the conversation with the (mock) tool result attached.
            messages_with_result = [
                messages[0],
                response,
                ToolMessage(content=mock_result, tool_call_id=response.tool_calls[0]['id'])
            ]
            final_response = await model.ainvoke(messages_with_result)
            print(f"\nFinal response: {final_response.content}")
            return True
        else:
            print("No tool calls - model responded directly")
            return True
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 4: Bash Capability | ||
| # ============================================================================== | ||
async def test_bash_capability():
    """Test bash capability via function calling (simulation only).

    No real shell command is executed: the capability's execute_sync
    simulation is invoked, then a canned listing is fed back to the model.
    """
    print_separator("Test 4: Bash Capability (Function Calling Simulation)")
    model = get_model(capabilities_config={
        "bash": {
            "enabled": True,
            "safe_mode": True
        }
    })
    messages = [HumanMessage(content="Please list the files in the current directory using bash.")]
    try:
        print(f"Enabled capabilities: {model.list_capabilities()}")
        print("Sending request with bash tool...")
        response = await model.ainvoke(messages)
        print(f"\nResponse type: {type(response).__name__}")
        print(f"Content: {response.content[:200] if response.content else '(none)'}...")
        if response.tool_calls:
            print(f"\nTool calls detected: {len(response.tool_calls)}")
            for i, tc in enumerate(response.tool_calls):
                print(f" [{i}] Tool: {tc['name']}")
                print(f" Args: {tc['args']}")
                print(f" ID: {tc['id']}")
            # Use the capability's simulation instead of running a command.
            bash_cap = model.get_capability("bash")
            if bash_cap:
                args = response.tool_calls[0]['args']
                sim_result = bash_cap.execute_sync(**args)
                print(f"\nSimulation result: {sim_result}")
            # Continue the conversation with a canned directory listing.
            mock_result = "file1.txt\nfile2.py\nfolder1/\nREADME.md"
            messages_with_result = [
                messages[0],
                response,
                ToolMessage(content=mock_result, tool_call_id=response.tool_calls[0]['id'])
            ]
            final_response = await model.ainvoke(messages_with_result)
            print(f"\nFinal response: {final_response.content}")
            return True
        else:
            print("No tool calls - model responded directly")
            return True
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 5: Runtime Capability Override | ||
| # ============================================================================== | ||
async def test_runtime_capability_override():
    """Test runtime capability override via the `config` argument.

    Exercises three invocation modes: default capabilities, a capability
    disabled at call time, and a per-call parameter override.
    """
    # Model with both capabilities enabled
    model = get_model(capabilities_config={
        "web_search": {"enabled": True, "max_results": 5},
        "bash": {"enabled": True}
    })
    messages = [HumanMessage(content="Search for the latest Python version.")]
    try:
        print(f"Base capabilities: {model.list_capabilities()}")
        # Case 1: invoke with the model's default capabilities.
        print("\n--- Invoke with default (web_search enabled) ---")
        response1 = await model.ainvoke(messages)
        print(f"Tool calls: {[tc['name'] for tc in response1.tool_calls] if response1.tool_calls else 'None'}")
        # Case 2: disable web_search for this call only (False = off).
        print("\n--- Invoke with web_search disabled at runtime ---")
        response2 = await model.ainvoke(
            messages,
            config={"configurable": {"capabilities": {"web_search": False}}}
        )
        print(f"Tool calls: {[tc['name'] for tc in response2.tool_calls] if response2.tool_calls else 'None'}")
        print(f"Direct response: {response2.content[:100] if response2.content else '(none)'}...")
        # Case 3: override a single capability parameter for this call.
        print("\n--- Invoke with max_results override ---")
        response3 = await model.ainvoke(
            messages,
            config={"configurable": {"capabilities": {"web_search": {"max_results": 10}}}}
        )
        print(f"Tool calls: {[tc['name'] for tc in response3.tool_calls] if response3.tool_calls else 'None'}")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 6: Multiple Capabilities Combined | ||
| # ============================================================================== | ||
async def test_multiple_capabilities():
    """Test multiple capabilities combined.

    Only inspects which tool calls the model emits — tool results are not
    executed or fed back in this test.
    """
    print_separator("Test 6: Multiple Capabilities Combined")
    model = get_model(capabilities_config={
        "web_search": {"enabled": True, "max_results": 3},
        "bash": {"enabled": True, "safe_mode": True}
    })
    messages = [
        HumanMessage(content=(
            "I need to do two things:\n"
            "1. Search for the latest Node.js version\n"
            "2. Show me how to check the current Node version using bash"
        ))
    ]
    try:
        print(f"Enabled capabilities: {model.list_capabilities()}")
        print("Sending request with multiple capabilities...")
        response = await model.ainvoke(messages)
        print(f"\nResponse content: {response.content[:300] if response.content else '(none)'}...")
        if response.tool_calls:
            print(f"\nTool calls: {len(response.tool_calls)}")
            for i, tc in enumerate(response.tool_calls):
                print(f" [{i}] {tc['name']}: {tc['args']}")
        else:
            print("No tool calls in response")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 7: Capability Registry | ||
| # ============================================================================== | ||
async def test_capability_registry():
    """Test capability registry functionality: type registration, creation
    from config, tool-schema generation, and config-merge semantics."""
    print_separator("Test 7: Capability Registry")
    try:
        registry = CapabilityRegistry()
        # List the capability types the registry knows how to build.
        print(f"Registered capability types: {registry.get_registered_types()}")
        # Create capabilities from config; "disabled_cap" should be skipped.
        config = {
            "web_search": {"enabled": True, "max_results": 5},
            "bash": {"enabled": True, "safe_mode": True},
            "disabled_cap": {"enabled": False}
        }
        capabilities = registry.create_from_config(config)
        print(f"\nCreated capabilities: {list(capabilities.keys())}")
        # Generate OpenAI-style function-calling schemas for the capabilities.
        schemas = registry.get_tool_schemas(capabilities)
        print(f"\nTool schemas: {len(schemas)}")
        for schema in schemas:
            print(f" - {schema['function']['name']}: {schema['function']['description'][:50]}...")
        # Exercise merge logic: bool overrides toggle, dict overrides merge.
        base = {"enabled": True, "max_results": 5}
        merged_false = registry.merge_capability_config(base, False)
        print(f"\nMerge with False: {merged_false}")
        merged_true = registry.merge_capability_config(base, True)
        print(f"Merge with True: {merged_true}")
        merged_dict = registry.merge_capability_config(base, {"max_results": 10})
        print(f"Merge with dict: {merged_dict}")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Test 8: Stream with Capabilities | ||
| # ============================================================================== | ||
async def test_astream_with_capabilities():
    """Test streaming with capabilities enabled.

    The prompt should not require a tool, so this mostly checks that streaming
    still works (and reports any tool-call chunks) when tools are bound.
    """
    print_separator("Test 8: Stream with Capabilities")
    model = get_model(capabilities_config={
        "web_search": {"enabled": True},
        "bash": {"enabled": True}
    })
    messages = [HumanMessage(content="What's 2 + 2? Just give me the answer.")]
    try:
        print("Starting stream with capabilities...")
        print("-" * 40)
        full_content = ""
        tool_calls = []
        async for chunk in model.astream(messages):
            if chunk.content:
                print(chunk.content, end="", flush=True)
                full_content += chunk.content
            if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
                tool_calls.extend(chunk.tool_call_chunks)
        print("\n" + "-" * 40)
        print(f"Content: {len(full_content)} chars")
        print(f"Tool call chunks: {len(tool_calls)}")
        return True
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
| # ============================================================================== | ||
| # Main | ||
| # ============================================================================== | ||
async def run_all_tests():
    """Run every test case in order and print a PASS/FAIL summary.

    Returns:
        0 when all tests passed, 1 otherwise (suitable as a process exit code).
    """
    print("\n" + "=" * 70)
    print(" OAIClaudeChatModel Test Suite")
    print(" Configuration:")
    print(f" - Base URL: {CONFIG['base_url']}")
    print(f" - Model: {CONFIG['model']}")
    print(f" - Platform: {platform.system()}")
    print("=" * 70)
    # (result key, coroutine function) pairs — run sequentially; a crash in
    # one test is recorded as a failure instead of aborting the whole suite.
    # This replaces eight copy-pasted try/except blocks with identical output.
    test_cases = [
        ("ainvoke_basic", test_ainvoke_basic),
        ("astream_basic", test_astream_basic),
        ("web_search", test_web_search_capability),
        ("bash", test_bash_capability),
        ("runtime_override", test_runtime_capability_override),
        ("multiple_caps", test_multiple_capabilities),
        ("registry", test_capability_registry),
        ("astream_caps", test_astream_with_capabilities),
    ]
    results = {}
    for name, test_fn in test_cases:
        try:
            results[name] = await test_fn()
        except Exception as e:
            print(f"Test {name} exception: {e}")
            results[name] = False
    # Summary table.
    print_separator("Test Results Summary")
    for test_name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f" {test_name}: {status}")
    print("\n" + "-" * 40)
    print("Test Coverage:")
    print(" 1. Basic ainvoke/astream without capabilities")
    print(" 2. Web search via Tavily (function calling)")
    print(" 3. Bash command simulation (function calling)")
    print(" 4. Runtime capability toggle and parameter override")
    print(" 5. Multiple capabilities combined")
    print(" 6. Capability registry and merge logic")
    print("-" * 40)
    all_passed = all(results.values())
    print(f"\nOverall: {'ALL PASSED' if all_passed else 'SOME FAILED'}")
    return 0 if all_passed else 1
def main():
    """Main entry point."""
    exit_code = asyncio.run(run_all_tests())
    return exit_code


if __name__ == "__main__":
    sys.exit(main())
| from .gemini_chat_model import GeminiChatModel | ||
| from .claude_chat_model import ClaudeChatModel | ||
| from .oai_claude_chat_model import OAIClaudeChatModel | ||
| from .init_services import init_load_balancer, get_model_balancer | ||
@@ -9,2 +10,9 @@ from .model_load_balancer import ModelLoadBalancer | ||
| from .tracing_manager import TracingManager, TracingContext | ||
| from .capabilities import ( | ||
| BaseCapability, | ||
| CapabilityConfig, | ||
| CapabilityRegistry, | ||
| TavilySearchCapability, | ||
| BashCapability, | ||
| ) | ||
@@ -14,2 +22,3 @@ __all__ = [ | ||
| "ClaudeChatModel", | ||
| "OAIClaudeChatModel", | ||
| "init_load_balancer", | ||
@@ -24,3 +33,8 @@ "get_model_balancer", | ||
| "TracingManager", | ||
| "TracingContext" | ||
| "TracingContext", | ||
| "BaseCapability", | ||
| "CapabilityConfig", | ||
| "CapabilityRegistry", | ||
| "TavilySearchCapability", | ||
| "BashCapability", | ||
| ] |
@@ -29,10 +29,27 @@ import os | ||
| None. Azure OpenAI Chat Completions API does not have built-in tools like | ||
| web search or code execution. To add these capabilities: | ||
| - Use Azure OpenAI Assistants API with built-in tools | ||
| - Implement external tools via function calling | ||
| - Use third-party services for web search/code execution | ||
| None. Azure OpenAI Chat Completions API does not have built-in web search | ||
| or similar capabilities. | ||
| The `capabilities` configuration is not applicable to this model class. | ||
| Unlike Claude and Gemini which have platform-native web search: | ||
| - Claude: Built-in `web_search_20250305` tool, search executed by Anthropic | ||
| - Gemini: Built-in `GoogleSearch` tool, search executed by Google | ||
| - Azure OpenAI: **No built-in web search in Chat Completions API** | ||
| Microsoft does offer web search through other APIs, but none are compatible | ||
| with the Chat Completions API used here: | ||
| - Responses API `web_search_preview`: Only available on OpenAI direct, NOT on Azure | ||
| - Agents API `bing_grounding`: Requires `azure-ai-agents` SDK, incompatible with LangChain | ||
| - Chat Completions `data_sources`: Only supports Azure AI Search (your own index), not web search | ||
| If web search is needed for Azure models in the future, the feasible approach is | ||
| Function Calling + external search service (e.g. Bing Search API): | ||
| 1. Inject a `web_search` function definition into the request | ||
| 2. When the model returns a tool_call, execute the search externally | ||
| 3. Feed results back to the model for final response generation | ||
| This is fundamentally different from Claude/Gemini where the platform handles | ||
| the search internally in a single API call. | ||
| Revisit this when Microsoft enables `web_search_preview` on Azure OpenAI. | ||
| Ref: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/web-search | ||
| Attributes: | ||
@@ -39,0 +56,0 @@ logger (Optional[logging.Logger]): An optional logger instance. |
@@ -895,6 +895,7 @@ import base64 | ||
| # --- Basic Token Counts --- | ||
| input_tokens = getattr(usage_metadata, "prompt_token_count", 0) | ||
| output_tokens = getattr(usage_metadata, "candidates_token_count", 0) | ||
| thoughts_tokens = getattr(usage_metadata, "thoughts_token_count", 0) | ||
| total_tokens = getattr(usage_metadata, "total_token_count", 0) | ||
| # Note: getattr returns None if attribute exists but is None, so we need to coerce | ||
| input_tokens = getattr(usage_metadata, "prompt_token_count", 0) or 0 | ||
| output_tokens = getattr(usage_metadata, "candidates_token_count", 0) or 0 | ||
| thoughts_tokens = getattr(usage_metadata, "thoughts_token_count", 0) or 0 | ||
| total_tokens = getattr(usage_metadata, "total_token_count", 0) or 0 | ||
@@ -901,0 +902,0 @@ # In some cases, total_tokens is not provided, so we calculate it |
@@ -11,2 +11,3 @@ import json | ||
| from .azure_chat_model import TracedAzureChatOpenAI | ||
| from .oai_claude_chat_model import OAIClaudeChatModel | ||
@@ -342,3 +343,8 @@ | ||
| # Custom endpoint for Claude models (e.g., ModelGate, custom proxies) | ||
| # Uses native Anthropic SDK | ||
| return self._instantiate_custom_claude_model(model_config, disable_streaming) | ||
| elif provider == 'custom-openai': | ||
| # Custom OpenAI-compatible endpoint for Claude models | ||
| # Uses OpenAI-compatible API with capability support (web_search, bash, etc.) | ||
| return self._instantiate_oai_claude_model(model_config, disable_streaming) | ||
| else: | ||
@@ -398,2 +404,54 @@ self.logger.error(f"Unsupported provider: {provider}") | ||
def _instantiate_oai_claude_model(self, model_config: Dict, disable_streaming: bool = False):
    """
    Instantiate OAIClaudeChatModel for OpenAI-compatible endpoints.

    This uses the OpenAI-compatible API (via langchain_openai.ChatOpenAI) to
    access Claude models through proxy services like ModelGate, LiteLLM, etc.
    Supports custom capabilities (web_search, bash) via function calling.

    Args:
        model_config: Model configuration dict with required fields:
            - deployment_name: Full deployment name (e.g., "claude-opus-4-5@mg")
            - base_url: OpenAI-compatible endpoint URL (e.g., "https://mg.aid.pub/v1")
            - api_key: API key for the endpoint
            Optional fields: capabilities, temperature, max_tokens.
        disable_streaming: Whether to disable streaming

    Returns:
        Configured OAIClaudeChatModel instance

    Raises:
        KeyError: If a required field (deployment_name, api_key, base_url)
            is missing from model_config.
    """
    deployment_name = model_config['deployment_name']

    # For OpenAI-compatible endpoints the full deployment_name (including any
    # routing suffix such as "@mg") is passed through unchanged as the model id.
    model_name = deployment_name

    # Extract capabilities from config (no default version mapping for custom capabilities)
    capabilities = model_config.get('capabilities', {})

    # Log enabled capabilities. A capability may be declared as a bare bool or
    # as a config dict; for dicts, "enabled" defaults to True so the log agrees
    # with CapabilityConfig (which declares enabled=True by default).
    enabled_caps = [
        k for k, v in capabilities.items()
        if (isinstance(v, dict) and v.get("enabled", True))
        or (isinstance(v, bool) and v)
    ]
    if enabled_caps:
        self.logger.debug(f"OAI Claude model {model_name} capabilities: {enabled_caps}")

    # Build kwargs for OAIClaudeChatModel
    kwargs = {
        'model': model_name,
        'api_key': model_config['api_key'],
        'base_url': model_config['base_url'],
        'capabilities_config': capabilities,
    }

    # Add optional parameters only when explicitly configured, so the model
    # class falls back to its own defaults otherwise.
    if 'temperature' in model_config:
        kwargs['temperature'] = model_config['temperature']
    if 'max_tokens' in model_config:
        kwargs['max_tokens'] = model_config['max_tokens']
    if disable_streaming:
        kwargs['streaming'] = False

    return OAIClaudeChatModel(**kwargs)
| def _initialize_state(self): | ||
@@ -400,0 +458,0 @@ self.active_models = [] |
+2
-1
| Metadata-Version: 2.1 | ||
| Name: crewplus | ||
| Version: 0.2.97 | ||
| Version: 0.2.99 | ||
| Summary: Base services for CrewPlus AI applications | ||
@@ -18,2 +18,3 @@ Author-Email: Tim Liu <tim@opsmateai.com> | ||
| Requires-Dist: langfuse<4.0.0,>=3.1.3 | ||
| Requires-Dist: pymilvus<=2.6.7,>=2.5.7 | ||
| Description-Content-Type: text/markdown | ||
@@ -20,0 +21,0 @@ |
+2
-1
@@ -9,3 +9,3 @@ [build-system] | ||
| name = "crewplus" | ||
| version = "0.2.97" | ||
| version = "0.2.99" | ||
| description = "Base services for CrewPlus AI applications" | ||
@@ -24,2 +24,3 @@ authors = [ | ||
| "langfuse (>=3.1.3,<4.0.0)", | ||
| "pymilvus (>=2.5.7,<=2.6.7)", | ||
| ] | ||
@@ -26,0 +27,0 @@ |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
508061
28.53%46
27.78%9376
39.38%