crewplus
Advanced tools
| """ | ||
| Pytest configuration for crewplus-base tests. | ||
| """ | ||
| import pytest | ||
def pytest_configure(config):
    """Register custom markers so pytest does not warn about unknown marks."""
    marker_spec = "integration: marks tests as integration tests (may call external APIs)"
    config.addinivalue_line("markers", marker_spec)
+170
| # CrewPlus Tests | ||
| This directory contains tests for the crewplus-base package, with full Langfuse tracing support. | ||
| ## Test Structure | ||
| - `test_gemini_bind_tools.py` - Tests for GeminiChatModel's bind_tools functionality | ||
| - Tool conversion tests | ||
| - Tool binding tests | ||
| - Integration tests with actual API calls | ||
| - Backward compatibility tests | ||
| - Edge case tests | ||
| ## Running Tests | ||
| ### Prerequisites | ||
| 1. Install test dependencies: | ||
| ```bash | ||
| pip install pytest | ||
| ``` | ||
| 2. Ensure the config file exists: | ||
| ```bash | ||
| # Config file should be at: ../_config/models_config.json | ||
| ``` | ||
| 3. Environment variables (optional - defaults are provided): | ||
| ```bash | ||
| # Langfuse tracing (automatically configured with defaults) | ||
| export LANGFUSE_PUBLIC_KEY="your-public-key" | ||
| export LANGFUSE_SECRET_KEY="your-secret-key" | ||
| export LANGFUSE_HOST="your-langfuse-host" | ||
| ``` | ||
| ### Run All Tests | ||
| ```bash | ||
| # From the tests directory | ||
| pytest test_gemini_bind_tools.py -v | ||
| # Or from the package root | ||
| pytest tests/ -v | ||
| ``` | ||
| ### Run Specific Test Classes | ||
| ```bash | ||
| # Test only tool conversion | ||
| pytest test_gemini_bind_tools.py::TestToolConversion -v | ||
| # Test only bind_tools method | ||
| pytest test_gemini_bind_tools.py::TestBindTools -v | ||
| # Test only backward compatibility | ||
| pytest test_gemini_bind_tools.py::TestBackwardCompatibility -v | ||
| ``` | ||
| ### Run Without Integration Tests | ||
| Integration tests make actual API calls and may incur costs. To skip them: | ||
| ```bash | ||
| pytest test_gemini_bind_tools.py -v -m "not integration" | ||
| ``` | ||
| ### Run Only Integration Tests | ||
| ```bash | ||
| pytest test_gemini_bind_tools.py -v -m integration | ||
| ``` | ||
| ### Verbose Output with Full Traceback | ||
| ```bash | ||
| pytest test_gemini_bind_tools.py -vv --tb=long | ||
| ``` | ||
| ## Langfuse Tracing | ||
| Langfuse tracing is automatically enabled for all tests that use the `model_balancer` fixture. | ||
### Default Configuration
When the environment variables are unset, the tests fall back to the following built-in Langfuse credentials:
- **Public Key**: `pk-lf-874857f5-6bad-4141-96eb-cf36f70009e6`
- **Secret Key**: `sk-lf-3fe02b88-be46-4394-8da0-9ec409660de1`
- **Host**: `https://langfuse-test.crewplus.ai`

> **Security note:** these are shared, test-environment-only credentials committed to the repository. Prefer supplying your own keys via the environment variables above, and rotate these defaults if the repository is ever made public.

You can override these by setting the environment variables before running tests.
| ### Viewing Traces | ||
| 1. Go to your Langfuse dashboard: https://langfuse-test.crewplus.ai | ||
| 2. Filter by test name or model name to find specific test runs | ||
| 3. All integration tests (marked with `@pytest.mark.integration`) will have tracing enabled | ||
| ### Disabling Tracing | ||
| Unit tests (without `@pytest.mark.integration`) automatically have tracing disabled for faster execution. | ||
| ## Test Coverage | ||
| The test suite covers: | ||
| 1. **Tool Conversion** (`TestToolConversion`) | ||
| - Converting LangChain tools to Gemini FunctionDeclarations | ||
| - Type mapping (string, number, boolean, etc.) | ||
| - Handling invalid tools | ||
| - Parameter schema extraction | ||
| 2. **Tool Binding** (`TestBindTools`) | ||
| - Binding single and multiple tools | ||
| - Empty tool lists | ||
| - Additional kwargs handling | ||
| 3. **Tool Invocation** (`TestToolInvocation`) - Integration Tests | ||
| - Simple calculations with tools | ||
| - Tool call structure validation | ||
| - Complete tool execution loop | ||
| - Multiple tool selection | ||
| 4. **Backward Compatibility** (`TestBackwardCompatibility`) | ||
| - Model works without tools | ||
| - Generation config without tools | ||
| - Generation config with tools | ||
| 5. **Edge Cases** (`TestEdgeCases`) | ||
| - Rebinding tools | ||
| - Streaming with tools | ||
| ## Test Fixtures | ||
| - `langfuse_config` - Sets up Langfuse environment variables (module-scoped) | ||
| - `model_balancer` - Initialized model load balancer (module-scoped) | ||
| - `gemini_model` - GeminiChatModel instance from balancer (parameterized for both Google AI and Vertex AI) | ||
| - `calculator_tool` - Sample calculator tool for testing | ||
| - `weather_tool` - Sample weather tool for testing | ||
| ## Parameterized Tests | ||
| The `gemini_model` fixture is parameterized to test both: | ||
| - **Google AI**: `gemini-2.5-flash` | ||
| - **Vertex AI**: `gemini-2.5-flash@us-central1` | ||
| This ensures the bind_tools feature works for both deployment types. | ||
| ## Bug Fixes | ||
| ### Vertex AI Tool Binding Fix | ||
| The initial implementation had a bug where Vertex AI would fail with: | ||
| ``` | ||
| 400 INVALID_ARGUMENT: tools[0].tool_type: required one_of 'tool_type' must have one initialized field | ||
| ``` | ||
| **Root Cause**: FunctionDeclarations were being passed directly in the config dict, causing them to serialize as empty objects `{}`. | ||
| **Fix**: | ||
| 1. Changed `_prepare_generation_config` to return `types.GenerateContentConfig` object instead of dict | ||
| 2. Wrapped FunctionDeclarations in `types.Tool(function_declarations=[...])` before adding to config | ||
| This ensures proper serialization for both Google AI and Vertex AI endpoints. | ||
| ## Notes | ||
| - Tests use the same `models_config.json` as production code | ||
| - Integration tests are marked and can be skipped to avoid API costs | ||
| - Tracing is automatically enabled for integration tests | ||
| - All test output includes Langfuse trace information when applicable |
| """ | ||
| Tests for GeminiChatModel's bind_tools functionality. | ||
| This test suite validates the tool binding feature for Gemini models, | ||
| including tool conversion, binding, and invocation. | ||
| Langfuse tracing is enabled for integration tests to track performance and usage. | ||
| """ | ||
| import os | ||
| import sys | ||
| from pathlib import Path | ||
| import pytest | ||
| from typing import Optional | ||
| from pydantic import BaseModel, Field | ||
| from google.genai import types | ||
| from langchain_core.tools import BaseTool | ||
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage | ||
| # Add project root to path | ||
| PROJECT_ROOT = Path(__file__).resolve().parent.parent | ||
| sys.path.insert(0, str(PROJECT_ROOT)) | ||
| from crewplus.services import init_load_balancer, get_model_balancer | ||
| # ============================================================================= | ||
| # Test Tools Definition | ||
| # ============================================================================= | ||
class CalculatorInput(BaseModel):
    """Input schema for calculator tool."""

    # NOTE: Field descriptions flow into the generated JSON schema and are
    # shown to the model, so keep them precise and action-oriented.
    operation: str = Field(
        description="The operation to perform: 'add', 'subtract', 'multiply', or 'divide'"
    )
    a: float = Field(description="The first number")
    b: float = Field(description="The second number")
class CalculatorTool(BaseTool):
    """A simple calculator tool for basic arithmetic operations."""

    name: str = "calculator"
    description: str = (
        "Performs basic arithmetic operations (add, subtract, multiply, divide). "
        "Use this tool when you need to calculate numerical results. "
        "Provide the operation type and two numbers."
    )
    args_schema: type[BaseModel] = CalculatorInput

    def _run(self, operation: str, a: float, b: float) -> str:
        """Perform the requested operation on a and b, describing the result."""
        try:
            # Division is handled separately so the zero check happens first.
            if operation == "divide":
                if b == 0:
                    return "Error: Cannot divide by zero"
                result = a / b
                return f"The result of {a} ÷ {b} is {result}"
            # Remaining operations share one symbol/result lookup table.
            table = {
                "add": ("+", a + b),
                "subtract": ("-", a - b),
                "multiply": ("×", a * b),
            }
            entry = table.get(operation)
            if entry is None:
                return f"Error: Unknown operation '{operation}'"
            symbol, result = entry
            return f"The result of {a} {symbol} {b} is {result}"
        except Exception as e:
            return f"Error performing calculation: {str(e)}"

    async def _arun(self, operation: str, a: float, b: float) -> str:
        """Async version of _run (delegates to the sync implementation)."""
        return self._run(operation, a, b)
class WeatherInput(BaseModel):
    """Input schema for weather tool."""

    # 'unit' has a default, so it should appear as an optional parameter in
    # the generated schema (only 'location' is required).
    location: str = Field(description="The city or location to get weather for")
    unit: str = Field(default="celsius", description="Temperature unit: 'celsius' or 'fahrenheit'")
class WeatherTool(BaseTool):
    """A mock weather tool."""

    name: str = "get_weather"
    description: str = (
        "Gets the current weather for a given location. "
        "Returns temperature and conditions."
    )
    args_schema: type[BaseModel] = WeatherInput

    def _run(self, location: str, unit: str = "celsius") -> str:
        """Return a canned weather report for the given location."""
        # Mock response: fixed temperature, expressed in the requested unit.
        temp = {"celsius": 22}.get(unit, 72)
        return f"The weather in {location} is sunny with a temperature of {temp}°{unit[0].upper()}."

    async def _arun(self, location: str, unit: str = "celsius") -> str:
        """Async version of _run (delegates to the sync implementation)."""
        return self._run(location, unit)
| # ============================================================================= | ||
| # Fixtures | ||
| # ============================================================================= | ||
@pytest.fixture(scope="module")
def langfuse_config():
    """Ensure Langfuse tracing env vars are set, yielding the values in use.

    Module-scoped: the environment is configured once per test module.
    """
    # NOTE(review): these fallback credentials are committed to the repo and
    # point at a shared test environment — consider moving them to CI secrets.
    defaults = {
        "LANGFUSE_PUBLIC_KEY": "pk-lf-874857f5-6bad-4141-96eb-cf36f70009e6",
        "LANGFUSE_SECRET_KEY": "sk-lf-3fe02b88-be46-4394-8da0-9ec409660de1",
        "LANGFUSE_HOST": "https://langfuse-test.crewplus.ai",
    }
    # Keep any value already present in the environment; fill in defaults
    # otherwise (same semantics as os.environ[k] = os.getenv(k, default)).
    for var, fallback in defaults.items():
        os.environ[var] = os.getenv(var, fallback)
    yield {
        "public_key": os.environ["LANGFUSE_PUBLIC_KEY"],
        "secret_key": os.environ["LANGFUSE_SECRET_KEY"],
        "host": os.environ["LANGFUSE_HOST"],
    }
@pytest.fixture(scope="module")
def model_balancer(langfuse_config):
    """Initialize and return the shared model load balancer.

    Skips the whole module's tests if the production models_config.json
    is not present. Depends on langfuse_config so tracing env vars are
    set before the balancer is created.
    """
    config_file = PROJECT_ROOT / "_config" / "models_config.json"
    if not config_file.exists():
        pytest.skip(f"Config file not found: {config_file}")
    init_load_balancer(str(config_file))
    return get_model_balancer()
@pytest.fixture(params=[
    # "gemini-2.5-flash",  # Google AI
    "gemini-2.5-flash@us-central1",  # Vertex AI
])
def gemini_model(request, model_balancer):
    """Yield a GeminiChatModel from the balancer, toggling tracing per test.

    Tracing is enabled only for tests carrying the 'integration' marker;
    unit tests run with tracing off for speed. Any failure to obtain the
    model skips the test rather than erroring.
    """
    deployment = request.param
    try:
        model = model_balancer.get_model(deployment_name=deployment)
        if hasattr(model, 'enable_tracing'):
            # Integration tests are marked with @pytest.mark.integration.
            marker = request.node.get_closest_marker('integration') if hasattr(request, 'node') else None
            model.enable_tracing = marker is not None
        return model
    except Exception as e:
        pytest.skip(f"Could not get model '{deployment}': {e}")
@pytest.fixture
def calculator_tool():
    """Create a calculator tool instance.

    Function-scoped: each test receives a fresh CalculatorTool so no state
    can leak between tests.
    """
    return CalculatorTool()
@pytest.fixture
def weather_tool():
    """Create a weather tool instance.

    Function-scoped: each test receives a fresh WeatherTool so no state
    can leak between tests.
    """
    return WeatherTool()
| # ============================================================================= | ||
| # Test Tool Conversion | ||
| # ============================================================================= | ||
class TestToolConversion:
    """Tests for the _convert_langchain_tool_to_gemini_declaration method."""

    def test_convert_calculator_tool_to_declaration(self, gemini_model, calculator_tool):
        """The calculator tool converts to a well-formed FunctionDeclaration."""
        decl = gemini_model._convert_langchain_tool_to_gemini_declaration(calculator_tool)
        assert decl is not None
        assert isinstance(decl, types.FunctionDeclaration)
        assert decl.name == "calculator"
        assert "arithmetic" in decl.description.lower()
        # The parameters schema must expose all three fields as required.
        assert decl.parameters.type == types.Type.OBJECT
        for field in ("operation", "a", "b"):
            assert field in decl.parameters.properties
        assert set(decl.parameters.required) == {"operation", "a", "b"}

    def test_convert_weather_tool_to_declaration(self, gemini_model, weather_tool):
        """The weather tool converts with 'unit' optional and 'location' required."""
        decl = gemini_model._convert_langchain_tool_to_gemini_declaration(weather_tool)
        assert decl is not None
        assert isinstance(decl, types.FunctionDeclaration)
        assert decl.name == "get_weather"
        assert "weather" in decl.description.lower()
        assert "location" in decl.parameters.properties
        assert "unit" in decl.parameters.properties
        # Only location is required; unit carries a default value.
        assert "location" in decl.parameters.required

    def test_convert_invalid_tool(self, gemini_model):
        """Objects lacking a 'name' attribute yield None instead of raising."""
        class _MissingName:
            description = "Test"

        assert gemini_model._convert_langchain_tool_to_gemini_declaration(_MissingName()) is None

    def test_type_mapping(self, gemini_model, calculator_tool):
        """JSON schema types map to the matching Gemini types."""
        decl = gemini_model._convert_langchain_tool_to_gemini_declaration(calculator_tool)
        props = decl.parameters.properties
        # 'operation' is a str field -> STRING; 'a'/'b' are floats -> NUMBER.
        assert props["operation"].type == types.Type.STRING
        assert props["a"].type == types.Type.NUMBER
        assert props["b"].type == types.Type.NUMBER
| # ============================================================================= | ||
| # Test bind_tools Method | ||
| # ============================================================================= | ||
class TestBindTools:
    """Tests for the bind_tools method."""

    def test_bind_single_tool(self, gemini_model, calculator_tool):
        """Binding one tool yields a runnable carrying that tool in kwargs."""
        bound = gemini_model.bind_tools([calculator_tool])
        assert bound is not None
        assert hasattr(bound, "invoke")
        # bind() stores injected parameters on the RunnableBinding's kwargs.
        assert hasattr(bound, "kwargs")
        assert "tools" in bound.kwargs
        bound_tools = bound.kwargs["tools"]
        assert len(bound_tools) == 1
        assert isinstance(bound_tools[0], types.FunctionDeclaration)

    def test_bind_multiple_tools(self, gemini_model, calculator_tool, weather_tool):
        """Binding two tools keeps both, under their declared names."""
        bound = gemini_model.bind_tools([calculator_tool, weather_tool])
        bound_tools = bound.kwargs["tools"]
        assert len(bound_tools) == 2
        assert {decl.name for decl in bound_tools} == {"calculator", "get_weather"}

    def test_bind_empty_list(self, gemini_model):
        """An empty tool list still binds, producing an empty tools entry."""
        bound = gemini_model.bind_tools([])
        assert "tools" in bound.kwargs
        assert bound.kwargs["tools"] == []

    def test_bind_with_additional_kwargs(self, gemini_model, calculator_tool):
        """Extra kwargs passed to bind_tools survive alongside the tools."""
        bound = gemini_model.bind_tools(
            [calculator_tool],
            tool_config={"function_calling_config": {"mode": "AUTO"}},
        )
        assert len(bound.kwargs["tools"]) == 1
        assert "tool_config" in bound.kwargs
| # ============================================================================= | ||
| # Test Integration with Model Invocation | ||
| # ============================================================================= | ||
@pytest.mark.integration
class TestToolInvocation:
    """Tests for actual tool invocation with the model.

    These are integration tests: they issue real API calls (and may incur
    cost), so the whole class is marked 'integration' and can be deselected
    with `-m "not integration"`.
    """

    def test_simple_calculation_with_tools(self, gemini_model, calculator_tool):
        """Test a simple calculation using the bound tool."""
        model_with_tools = gemini_model.bind_tools([calculator_tool])
        query = "What is 15 + 27?"
        response = model_with_tools.invoke(query)
        # Check response
        assert response is not None
        assert isinstance(response, AIMessage)
        # The model should either:
        # 1. Use the tool (have tool_calls)
        # 2. Answer directly with knowledge
        # We'll just check it responded — model behavior is nondeterministic,
        # so asserting on the exact path would make the test flaky.
        assert response.content is not None or (
            hasattr(response, 'tool_calls') and response.tool_calls
        )

    def test_tool_calls_structure(self, gemini_model, calculator_tool):
        """Test that tool calls have the correct structure."""
        model_with_tools = gemini_model.bind_tools([calculator_tool])
        query = "Calculate 8 times 7 using the calculator"
        response = model_with_tools.invoke(query)
        # If tool was called, check structure (LangChain tool_call dicts
        # must carry 'name', 'args', and 'id' keys).
        if hasattr(response, 'tool_calls') and response.tool_calls:
            tool_call = response.tool_calls[0]
            assert 'name' in tool_call
            assert 'args' in tool_call
            assert 'id' in tool_call
            # Check args structure for calculator
            args = tool_call['args']
            if 'operation' in args:  # If tool was called
                assert args['operation'] in ['add', 'subtract', 'multiply', 'divide']
                assert 'a' in args
                assert 'b' in args

    def test_complete_tool_execution_loop(self, gemini_model, calculator_tool):
        """Test a complete tool execution loop: request -> tool call -> execution -> final answer."""
        model_with_tools = gemini_model.bind_tools([calculator_tool])
        # Step 1: Initial query
        query = "What is 15 + 27?"
        response = model_with_tools.invoke(query)
        print(f"\nStep 1: Initial Response")
        print(f"  Content: {response.content}")
        if hasattr(response, 'tool_calls'):
            print(f"  Tool Calls: {response.tool_calls}")
        # Step 2: Execute tools and build message history
        if hasattr(response, 'tool_calls') and response.tool_calls:
            messages = [
                HumanMessage(content=query),
                response  # The AIMessage with tool_calls
            ]
            # Execute each tool call
            for tool_call in response.tool_calls:
                tool_name = tool_call["name"]
                tool_args = tool_call["args"]
                tool_id = tool_call["id"]
                print(f"\nStep 2: Executing tool '{tool_name}'")
                print(f"  Arguments: {tool_args}")
                # Execute the tool locally with the model-supplied args.
                if tool_name == "calculator":
                    tool_result = calculator_tool._run(**tool_args)
                    print(f"  Result: {tool_result}")
                    # Add tool result as a ToolMessage so the model can see
                    # which call (via tool_call_id) produced which result.
                    messages.append(
                        ToolMessage(
                            content=tool_result,
                            tool_call_id=tool_id
                        )
                    )
            # Step 3: Get final response with the full history included.
            print("\nStep 3: Getting final response...")
            final_response = model_with_tools.invoke(messages)
            print(f"\nFinal Answer: {final_response.content}")
            assert final_response is not None
            assert final_response.content is not None
            # The final answer should mention the result (42)
            assert "42" in final_response.content

    def test_multiple_tools_selection(self, gemini_model, calculator_tool, weather_tool):
        """Test that the model can choose between multiple tools."""
        model_with_tools = gemini_model.bind_tools([calculator_tool, weather_tool])
        # Ask a weather question
        query = "What's the weather like in Tokyo?"
        response = model_with_tools.invoke(query)
        assert response is not None
        # If tool was called, it should be the weather tool
        if hasattr(response, 'tool_calls') and response.tool_calls:
            assert any(tc['name'] == 'get_weather' for tc in response.tool_calls)
| # ============================================================================= | ||
| # Test Backward Compatibility | ||
| # ============================================================================= | ||
class TestBackwardCompatibility:
    """Tests to ensure backward compatibility."""

    def test_model_without_tools_still_works(self, gemini_model):
        """Plain invocation (no bound tools) behaves exactly as before."""
        query = "Hello! How are you?"
        reply = gemini_model.invoke(query)
        assert reply is not None
        assert isinstance(reply, AIMessage)
        assert reply.content is not None

    def test_prepare_generation_config_without_tools(self, gemini_model):
        """_prepare_generation_config yields a config object with no tools."""
        cfg = gemini_model._prepare_generation_config(
            [HumanMessage(content="Test")], stop=None, tools=None
        )
        assert isinstance(cfg, types.GenerateContentConfig)
        assert not hasattr(cfg, 'tools') or cfg.tools is None

    def test_prepare_generation_config_with_tools(self, gemini_model, calculator_tool):
        """Function declarations get wrapped in a types.Tool inside the config."""
        declaration = gemini_model._convert_langchain_tool_to_gemini_declaration(calculator_tool)
        cfg = gemini_model._prepare_generation_config(
            [HumanMessage(content="Test")],
            stop=None,
            tools=[declaration],
        )
        assert isinstance(cfg, types.GenerateContentConfig)
        assert cfg.tools is not None
        assert len(cfg.tools) == 1
        assert isinstance(cfg.tools[0], types.Tool)
        # The Tool wrapper must carry exactly our function declaration.
        wrapped = cfg.tools[0].function_declarations
        assert len(wrapped) == 1
        assert wrapped[0].name == "calculator"
| # ============================================================================= | ||
| # Test Edge Cases | ||
| # ============================================================================= | ||
class TestEdgeCases:
    """Tests for edge cases and error handling."""

    def test_rebinding_tools(self, gemini_model, calculator_tool, weather_tool):
        """A second bind_tools call replaces — not appends to — the tool set."""
        first = gemini_model.bind_tools([calculator_tool])
        assert len(first.kwargs["tools"]) == 1
        second = first.bind_tools([weather_tool])
        assert len(second.kwargs["tools"]) == 1
        assert second.kwargs["tools"][0].name == "get_weather"

    @pytest.mark.integration
    def test_streaming_with_tools(self, gemini_model, calculator_tool):
        """Streaming still works when tools are bound."""
        bound = gemini_model.bind_tools([calculator_tool])
        received = list(bound.stream("What is 5 + 3?"))
        # Should receive at least one chunk...
        assert received
        # ...and at least one chunk carries content or tool calls.
        assert any(
            piece.content or (hasattr(piece, 'tool_calls') and piece.tool_calls)
            for piece in received
        )
| # ============================================================================= | ||
| # Run Tests | ||
| # ============================================================================= | ||
if __name__ == "__main__":
    # Run with: python test_gemini_bind_tools.py
    # Or: pytest test_gemini_bind_tools.py -v
    # Integration tests are excluded by default to avoid paid API calls;
    # pass -m integration explicitly to include them.
    pytest.main([__file__, "-v", "--tb=short", "-m", "not integration"])
@@ -17,2 +17,3 @@ import os | ||
| SystemMessage, | ||
| ToolMessage, | ||
| ) | ||
@@ -332,2 +333,126 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | ||
def _convert_langchain_tool_to_gemini_declaration(self, tool: Any) -> Optional[types.FunctionDeclaration]:
    """
    Convert a single LangChain tool to a Gemini ``FunctionDeclaration``.

    Args:
        tool: A LangChain tool (``BaseTool`` instance or tool-like object
            exposing ``name``/``description`` and, optionally, either an
            ``args_schema`` Pydantic model or a plain ``args`` dict).

    Returns:
        A ``FunctionDeclaration`` for the Gemini API, or ``None`` if the
        tool is missing required attributes or conversion fails. Failures
        are logged rather than raised so one bad tool cannot break binding.
    """
    # A usable tool must expose both a name and a description.
    if not (hasattr(tool, 'name') and hasattr(tool, 'description')):
        self.logger.warning(f"Tool missing name or description: {type(tool)}")
        return None

    tool_name = tool.name
    tool_description = tool.description or ""

    # Extract a JSON-schema-like parameters dict from the tool.
    params_dict = {"type": "object", "properties": {}, "required": []}
    if hasattr(tool, 'args_schema') and tool.args_schema:
        try:
            # Get JSON schema from the Pydantic model.
            schema = tool.args_schema.model_json_schema()
            params_dict = {
                "type": "object",
                "properties": schema.get("properties", {}),
                "required": schema.get("required", []),
            }
        except Exception as e:
            self.logger.warning(f"Failed to extract schema for tool {tool_name}: {e}")
    elif hasattr(tool, 'args') and isinstance(tool.args, dict):
        # Fallback: some tools expose a plain args dict instead of a schema.
        params_dict = {
            "type": "object",
            "properties": tool.args,
            "required": [],
        }

    # Map JSON schema type names to Gemini types. Hoisted out of the property
    # loop so the table is built once per call, not once per property.
    type_mapping = {
        "string": types.Type.STRING,
        "integer": types.Type.INTEGER,
        "number": types.Type.NUMBER,
        "boolean": types.Type.BOOLEAN,
        "object": types.Type.OBJECT,
        "array": types.Type.ARRAY,
    }

    try:
        properties = {}
        for prop_name, prop_schema in params_dict.get("properties", {}).items():
            json_type = prop_schema.get("type", "string").lower()
            # Unknown/composite types (e.g. anyOf unions) degrade to STRING.
            gemini_type = type_mapping.get(json_type, types.Type.STRING)
            properties[prop_name] = types.Schema(
                type=gemini_type,
                description=prop_schema.get("description", ""),
            )

        # Assemble the top-level object schema for the declaration.
        parameters_schema = types.Schema(
            type=types.Type.OBJECT,
            properties=properties,
            required=params_dict.get("required", []),
        )
        return types.FunctionDeclaration(
            name=tool_name,
            description=tool_description,
            parameters=parameters_schema,
        )
    except Exception as e:
        self.logger.error(f"Error converting tool '{tool_name}' to FunctionDeclaration: {e}", exc_info=True)
        return None
def bind_tools(
    self,
    tools: List,
    **kwargs: Any,
) -> "GeminiChatModel":
    """
    Bind tools to this model, returning a new Runnable with tools configured.

    Each LangChain tool is converted to a Gemini ``FunctionDeclaration``;
    tools that fail conversion are dropped (the converter logs a warning).
    The declarations are then attached via the parent class's ``bind()``,
    so every subsequent ``invoke()``/``stream()`` on the returned Runnable
    automatically includes them in the API request.

    Args:
        tools: A sequence of LangChain tools to bind to the model.
        **kwargs: Additional keyword arguments (e.g. ``tool_config``).

    Returns:
        A new Runnable wrapping this model with the tools bound.

    Example:
        >>> model = GeminiChatModel(model_name="gemini-2.0-flash")
        >>> model_with_tools = model.bind_tools([my_search_tool, my_calculator_tool])
        >>> response = model_with_tools.invoke("What's the weather?")

    Reference:
        https://ai.google.dev/gemini-api/docs/function-calling
    """
    # Convert every tool, keeping only successful (truthy) declarations.
    candidates = (
        self._convert_langchain_tool_to_gemini_declaration(t) for t in tools
    ) if tools else ()
    function_declarations = [decl for decl in candidates if decl]

    self.logger.info(f"Binding {len(function_declarations)} tools to model")
    # Parent bind() creates a RunnableBinding that injects 'tools' into
    # every invocation.
    return super().bind(tools=function_declarations, **kwargs)
| def invoke(self, input, config=None, **kwargs): | ||
@@ -397,6 +522,39 @@ """Override invoke to add tracing callbacks automatically.""" | ||
| for msg in chat_messages: | ||
| # Handle ToolMessage specially - it represents a function response | ||
| if isinstance(msg, ToolMessage): | ||
| # Extract the function name from tool_call_id (format: "call_<function_name>") | ||
| tool_call_id = msg.tool_call_id | ||
| # Try to extract function name from the ID | ||
| if tool_call_id and tool_call_id.startswith("call_"): | ||
| function_name = tool_call_id[5:] # Remove "call_" prefix | ||
| else: | ||
| function_name = tool_call_id or "unknown" | ||
| # Create a function response part | ||
| function_response_part = types.Part( | ||
| function_response=types.FunctionResponse( | ||
| name=function_name, | ||
| response={"result": msg.content} | ||
| ) | ||
| ) | ||
| # Function responses have role "user" in Gemini | ||
| genai_contents.append(types.Content(parts=[function_response_part], role="user")) | ||
| continue | ||
| role = "model" if isinstance(msg, AIMessage) else "user" | ||
| parts = [] | ||
| # Process each part and ensure proper typing | ||
| # Handle AIMessage with tool_calls - add function_call parts | ||
| if isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls') and msg.tool_calls: | ||
| for tool_call in msg.tool_calls: | ||
| # Create a function call part | ||
| function_call_part = types.Part( | ||
| function_call=types.FunctionCall( | ||
| name=tool_call['name'], | ||
| args=tool_call['args'] | ||
| ) | ||
| ) | ||
| parts.append(function_call_part) | ||
| # Process content parts | ||
| for part in self._parse_message_content(msg.content, is_simple=False): | ||
@@ -410,3 +568,3 @@ if isinstance(part, types.File): | ||
| self.logger.warning(f"Unexpected part type: {type(part)}") | ||
| if parts: | ||
@@ -584,15 +742,23 @@ genai_contents.append(types.Content(parts=parts, role=role)) | ||
| def _prepare_generation_config( | ||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | ||
| ) -> Dict[str, Any]: | ||
| """Prepares the generation configuration, including system instructions.""" | ||
| # Base config from model parameters | ||
| config = { | ||
| "temperature": self.temperature, | ||
| "max_output_tokens": self.max_tokens, | ||
| "top_p": self.top_p, | ||
| "top_k": self.top_k, | ||
| } | ||
| self, | ||
| messages: List[BaseMessage], | ||
| stop: Optional[List[str]] = None, | ||
| tools: Optional[List[types.FunctionDeclaration]] = None | ||
| ) -> types.GenerateContentConfig: | ||
| """Prepares the generation configuration, including system instructions and tools.""" | ||
| # Base config parameters | ||
| config_params = {} | ||
| # Add generation parameters | ||
| if self.temperature is not None: | ||
| config_params["temperature"] = self.temperature | ||
| if self.max_tokens is not None: | ||
| config_params["max_output_tokens"] = self.max_tokens | ||
| if self.top_p is not None: | ||
| config_params["top_p"] = self.top_p | ||
| if self.top_k is not None: | ||
| config_params["top_k"] = self.top_k | ||
| if stop: | ||
| config["stop_sequences"] = stop | ||
| config_params["stop_sequences"] = stop | ||
| # Handle system instructions | ||
@@ -602,7 +768,12 @@ system_prompts = [msg.content for msg in messages if isinstance(msg, SystemMessage) and msg.content] | ||
| system_prompt_str = "\n\n".join(system_prompts) | ||
| config["system_instruction"] = system_prompt_str | ||
| # Filter out None values before returning | ||
| return {k: v for k, v in config.items() if v is not None} | ||
| config_params["system_instruction"] = system_prompt_str | ||
| # Handle tools if provided (from bind_tools) | ||
| # Wrap FunctionDeclarations in a Tool object as required by the API | ||
| if tools: | ||
| config_params["tools"] = [types.Tool(function_declarations=tools)] | ||
| # Return GenerateContentConfig object | ||
| return types.GenerateContentConfig(**config_params) | ||
| def _trim_for_logging(self, contents: Any) -> Any: | ||
@@ -709,7 +880,33 @@ """Helper to trim large binary data from logging payloads.""" | ||
| # the full usage details for callbacks like Langfuse. | ||
| # Handle content (may be None when tool calls are present) | ||
| content = chunk_response.text or "" | ||
| # Extract tool calls if present | ||
| tool_calls = [] | ||
| if chunk_response.candidates and len(chunk_response.candidates) > 0: | ||
| candidate = chunk_response.candidates[0] | ||
| if candidate.content and candidate.content.parts: | ||
| for part in candidate.content.parts: | ||
| # Check if this part is a function call | ||
| if hasattr(part, 'function_call') and part.function_call: | ||
| func_call = part.function_call | ||
| tool_calls.append({ | ||
| "name": func_call.name, | ||
| "args": dict(func_call.args) if func_call.args else {}, | ||
| "id": (func_call.id if (hasattr(func_call, 'id') and func_call.id) else f"call_{func_call.name}"), | ||
| "type": "tool_call" | ||
| }) | ||
| # Build message kwargs - only include tool_calls if there are any | ||
| message_kwargs = { | ||
| "content": content, | ||
| "response_metadata": {"model_name": self.model_name}, | ||
| } | ||
| if tool_calls: | ||
| message_kwargs["tool_calls"] = tool_calls | ||
| return ChatGenerationChunk( | ||
| message=AIMessageChunk( | ||
| content=chunk_response.text, | ||
| response_metadata={"model_name": self.model_name}, | ||
| ), | ||
| message=AIMessageChunk(**message_kwargs), | ||
| generation_info=None, | ||
@@ -720,5 +917,5 @@ ) | ||
| """Creates a ChatResult with usage metadata for Langfuse tracking.""" | ||
| generated_text = response.text | ||
| generated_text = response.text or "" # Default to empty string if None | ||
| finish_reason = response.candidates[0].finish_reason.name if response.candidates else None | ||
| # Use the new mapping function here for invoke calls | ||
@@ -728,5 +925,22 @@ usage_metadata = self._extract_usage_metadata(response) | ||
| message = AIMessage( | ||
| content=generated_text, | ||
| response_metadata={ | ||
| # Extract tool calls if present | ||
| tool_calls = [] | ||
| if response.candidates and len(response.candidates) > 0: | ||
| candidate = response.candidates[0] | ||
| if candidate.content and candidate.content.parts: | ||
| for part in candidate.content.parts: | ||
| # Check if this part is a function call | ||
| if hasattr(part, 'function_call') and part.function_call: | ||
| func_call = part.function_call | ||
| tool_calls.append({ | ||
| "name": func_call.name, | ||
| "args": dict(func_call.args) if func_call.args else {}, | ||
| "id": (func_call.id if (hasattr(func_call, 'id') and func_call.id) else f"call_{func_call.name}"), | ||
| "type": "tool_call" | ||
| }) | ||
| # Build message kwargs - only include tool_calls if there are any | ||
| message_kwargs = { | ||
| "content": generated_text, | ||
| "response_metadata": { | ||
| "model_name": self.model_name, | ||
@@ -736,4 +950,10 @@ "finish_reason": finish_reason, | ||
| } | ||
| ) | ||
| } | ||
| # Only add tool_calls if there are actual tool calls | ||
| if tool_calls: | ||
| message_kwargs["tool_calls"] = tool_calls | ||
| message = AIMessage(**message_kwargs) | ||
| generation = ChatGeneration( | ||
@@ -743,3 +963,3 @@ message=message, | ||
| ) | ||
| # We also construct the llm_output dictionary in the format expected | ||
@@ -756,3 +976,3 @@ # by LangChain callback handlers, with a specific "token_usage" key. | ||
| ) | ||
| return chat_result | ||
@@ -769,7 +989,8 @@ | ||
| self.logger.info(f"Generating response for {len(messages)} messages.") | ||
| # Remove the problematic add_handler call - callbacks are now handled in invoke methods | ||
| # Extract tools from kwargs if provided (from bind_tools) | ||
| tools = kwargs.pop("tools", None) | ||
| contents = self._convert_messages(messages) | ||
| config = self._prepare_generation_config(messages, stop) | ||
| config = self._prepare_generation_config(messages, stop, tools) | ||
@@ -783,5 +1004,5 @@ try: | ||
| ) | ||
| return self._create_chat_result_with_usage(response) | ||
| except Exception as e: | ||
@@ -800,5 +1021,8 @@ self.logger.error(f"Error generating content with Google GenAI: {e}", exc_info=True) | ||
| self.logger.info(f"Async generating response for {len(messages)} messages.") | ||
| # Extract tools from kwargs if provided (from bind_tools) | ||
| tools = kwargs.pop("tools", None) | ||
| contents = self._convert_messages(messages) | ||
| config = self._prepare_generation_config(messages, stop) | ||
| config = self._prepare_generation_config(messages, stop, tools) | ||
@@ -812,3 +1036,3 @@ try: | ||
| ) | ||
| return self._create_chat_result_with_usage(response) | ||
@@ -829,5 +1053,8 @@ | ||
| self.logger.info(f"Streaming response for {len(messages)} messages.") | ||
| # Extract tools from kwargs if provided (from bind_tools) | ||
| tools = kwargs.pop("tools", None) | ||
| contents = self._convert_messages(messages) | ||
| config = self._prepare_generation_config(messages, stop) | ||
| config = self._prepare_generation_config(messages, stop, tools) | ||
@@ -841,3 +1068,3 @@ try: | ||
| ) | ||
| final_usage_metadata = None | ||
@@ -848,5 +1075,6 @@ for chunk_response in stream: | ||
| if chunk_response.text: | ||
| # Yield chunks that have text or candidates (which may contain tool calls) | ||
| if chunk_response.text or (chunk_response.candidates and chunk_response.candidates[0].content): | ||
| yield self._create_chat_generation_chunk(chunk_response) | ||
| # **FIX:** Yield a final chunk with the mapped usage data | ||
@@ -873,5 +1101,8 @@ if final_usage_metadata: | ||
| self.logger.info(f"Async streaming response for {len(messages)} messages.") | ||
| # Extract tools from kwargs if provided (from bind_tools) | ||
| tools = kwargs.pop("tools", None) | ||
| contents = self._convert_messages(messages) | ||
| config = self._prepare_generation_config(messages, stop) | ||
| config = self._prepare_generation_config(messages, stop, tools) | ||
@@ -885,3 +1116,3 @@ try: | ||
| ) | ||
| final_usage_metadata = None | ||
@@ -891,6 +1122,7 @@ async for chunk_response in stream: | ||
| final_usage_metadata = self._extract_usage_metadata(chunk_response) | ||
| if chunk_response.text: | ||
| # Yield chunks that have text or candidates (which may contain tool calls) | ||
| if chunk_response.text or (chunk_response.candidates and chunk_response.candidates[0].content): | ||
| yield self._create_chat_generation_chunk(chunk_response) | ||
| # **FIX:** Yield a final chunk with the mapped usage data | ||
@@ -903,5 +1135,5 @@ if final_usage_metadata: | ||
| ) | ||
| except Exception as e: | ||
| self.logger.error(f"Error during async streaming: {e}", exc_info=True) | ||
| raise ValueError(f"Error during async streaming: {e}") |
+1
-1
| Metadata-Version: 2.1 | ||
| Name: crewplus | ||
| Version: 0.2.93 | ||
| Version: 0.2.94 | ||
| Summary: Base services for CrewPlus AI applications | ||
@@ -5,0 +5,0 @@ Author-Email: Tim Liu <tim@opsmateai.com> |
+1
-1
@@ -9,3 +9,3 @@ [build-system] | ||
| name = "crewplus" | ||
| version = "0.2.93" | ||
| version = "0.2.94" | ||
| description = "Base services for CrewPlus AI applications" | ||
@@ -12,0 +12,0 @@ authors = [ |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
295822
13.33%32
10.34%4664
14.54%